Add various iterator.* functions

This commit is contained in:
Lars Mueller 2022-10-01 12:56:24 +02:00
parent d72e589a62
commit 0d864f0065

@ -1,12 +1,97 @@
--[[
Iterators are always the *last* argument(s) to all functions here,
which differs from other modules which take what they operate on as first argument.
This is because iterators consist of three variables - iterator function, state & control variable -
and wrapping them (using a table, a closure or the like) would be rather inconvenient.
Having them as the last argument allows to just pass in the three variables returned by functions such as `[i]pairs`.
Additionally, putting functions first - although syntactically inconvenient - is consistent with Python and Lisp.
]]
local coroutine_create, coroutine_resume, coroutine_yield, coroutine_status, unpack, select local coroutine_create, coroutine_resume, coroutine_yield, coroutine_status, unpack, select
= coroutine.create, coroutine.resume, coroutine.yield, coroutine.status, unpack, select = coroutine.create, coroutine.resume, coroutine.yield, coroutine.status, unpack, select
local add = modlib.func.add local identity, not_, add = modlib.func.identity, modlib.func.not_, modlib.func.add
--+ For all functions which aggregate over single values, use modlib.table.ivalues - not ipairs - for lists! --+ For all functions which aggregate over single values, use modlib.table.ivalues - not ipairs - for lists!
--+ Otherwise they will be applied to the indices. --+ Otherwise they will be applied to the indices.
local iterator = {} local iterator = {}
function iterator.wrap(iterator, state, control_var)
local function update_control_var(...)
control_var = ...
return ...
end
return function()
return update_control_var(iterator(state, control_var))
end
end
iterator.closure = iterator.wrap
iterator.make_stateful = iterator.wrap
function iterator.filter(predicate, iterator, state, control_var)
local function _filter(...)
local cvar = ...
if cvar == nil then
return
end
if predicate(...) then
return ...
end
return _filter(iterator(state, cvar))
end
return function(state, control_var)
return _filter(iterator(state, control_var))
end, state, control_var
end
function iterator.truthy(...)
return iterator.filter(identity, ...)
end
function iterator.falsy(...)
return iterator.filter(not_, ...)
end
function iterator.map(map_func, iterator, state, control_var)
local function _map(...)
control_var = ... -- update control var
if control_var == nil then return end
return map_func(...)
end
return function()
return _map(iterator(state, control_var))
end
end
function iterator.map_values(map_func, iterator, state, control_var)
local function _map_values(cvar, ...)
if cvar == nil then return end
return cvar, map_func(...)
end
return function(state, control_var)
return _map_values(iterator(state, control_var))
end, state, control_var
end
-- Iterator must be restartable
function iterator.rep(times, iterator, state, control_var)
times = times or 1
if times == 1 then
return iterator, state, control_var
end
local function _rep(cvar, ...)
if cvar == nil then
times = times - 1
if times == 0 then return end
return _rep(iterator(state, control_var))
end
return cvar, ...
end
return function(state, control_var)
return _rep(iterator(state, control_var))
end, state, control_var
end
-- Equivalent to `for x, y, z in iterator, state, ... do callback(x, y, z) end` -- Equivalent to `for x, y, z in iterator, state, ... do callback(x, y, z) end`
function iterator.foreach(callback, iterator, state, ...) function iterator.foreach(callback, iterator, state, ...)
local function loop(...) local function loop(...)
@ -75,6 +160,8 @@ function iterator.reduce(binary_func, iterator, state, control_var)
end end
iterator.fold = iterator.reduce iterator.fold = iterator.reduce
-- TODO iterator.find(predicate, iterator, state, control_var)
function iterator.any(...) function iterator.any(...)
for val in ... do for val in ... do
if val then return true end if val then return true end
@ -101,6 +188,62 @@ end
-- TODO iterator.max -- TODO iterator.max
function iterator.empty(iterator, state, control_var)
return iterator(state, control_var) == nil
end
function iterator.first(iterator, state, control_var)
return iterator(state, control_var)
end
function iterator.last(iterator, state, control_var)
-- Storing a vararg in a table seems to be necessary: https://stackoverflow.com/questions/73914273/
-- This could be optimized further for memory by keeping the same table across calls,
-- but that might cause issues with multiple coroutines calling this
local last, last_n = {}, 0
local function _last(...)
local cvar = ...
if cvar == nil then
return unpack(last, 1, last_n)
end
-- Write vararg to table: Avoid the creation of a garbage table every iteration by reusing the same table
last_n = select("#", ...)
for i = 1, last_n do
last[i] = select(i, ...)
end
return _last(iterator(state, cvar))
end
return _last(iterator(state, control_var))
end
-- Converts a vararg starting with `nil` (end of loop control variable) into nothing
local function nil_to_nothing(...)
if ... == nil then return end
return ...
end
function iterator.select(n, iterator, state, control_var)
for _ = 1, n - 1 do
control_var = iterator(state, control_var)
if control_var == nil then return end
end
-- Either all values returned by the n-th call iteration
-- or nothing if the iterator holds fewer than `n` values
return nil_to_nothing(iterator(state, control_var))
end
function iterator.limit(count, iterator, state, control_var)
return function(state, control_var)
count = count - 1
if count < 0 then return end
return iterator(state, control_var)
end, state, control_var
end
function iterator.count(...) function iterator.count(...)
local count = 0 local count = 0
for _ in ... do for _ in ... do