Add func.for_generator

This commit is contained in:
Lars Mueller 2021-07-14 10:26:57 +02:00
parent 2ba0b0555f
commit a86b24eef5
2 changed files with 37 additions and 3 deletions

@ -1,5 +1,6 @@
-- Localize globals
local error, modlib, unpack = error, modlib, unpack
local error, coroutine, modlib, unpack
= error, coroutine, modlib, unpack
-- Set environment
local _ENV = {}
@ -47,6 +48,27 @@ function iterate(callback, iterator, ...)
return _iterate(iterator(...))
end
function for_generator(caller, ...)
local co = coroutine.create(function(...)
return caller(function(...)
return coroutine.yield(...)
end, ...)
end)
local args = {...}
return function()
if coroutine.status(co) == "dead" then
return
end
local function _iterate(status, ...)
if not status then
error((...))
end
return ...
end
return _iterate(coroutine.resume(co, unpack(args)))
end
end
-- Does not use select magic, stops at the first nil value
function aggregate(binary_func, total, ...)
if total == nil then return end

@ -19,10 +19,22 @@ setfenv(1, setmetatable({}, {
-- func
do
local tab = {a = 1, b = 2}
func.iterate(function(key, value)
local function check_entry(key, value)
assert(tab[key] == value)
tab[key] = nil
end, pairs, tab)
end
func.iterate(check_entry, pairs, tab)
assert(next(tab) == nil)
tab = {a = 1, b = 2}
local function pairs_callback(callback, tab)
for k, v in pairs(tab) do
callback(k, v)
end
end
for k, v in func.for_generator(pairs_callback, tab) do
check_entry(k, v)
end
assert(next(tab) == nil)
assert(func.aggregate(func.add, 1, 2, 3) == 6)
end