diff --git a/func.lua b/func.lua index 8a287e6..238fc2e 100644 --- a/func.lua +++ b/func.lua @@ -1,5 +1,6 @@ -- Localize globals -local error, modlib, unpack = error, modlib, unpack +local error, coroutine, modlib, unpack + = error, coroutine, modlib, unpack -- Set environment local _ENV = {} @@ -47,6 +48,27 @@ function iterate(callback, iterator, ...) return _iterate(iterator(...)) end +function for_generator(caller, ...) + local co = coroutine.create(function(...) + return caller(function(...) + return coroutine.yield(...) + end, ...) + end) + local args = {...} + return function() + if coroutine.status(co) == "dead" then + return + end + local function _iterate(status, ...) + if not status then + error((...)) + end + return ... + end + return _iterate(coroutine.resume(co, unpack(args))) + end +end + -- Does not use select magic, stops at the first nil value function aggregate(binary_func, total, ...) if total == nil then return end diff --git a/test.lua b/test.lua index 2efc621..9b93834 100644 --- a/test.lua +++ b/test.lua @@ -19,10 +19,22 @@ setfenv(1, setmetatable({}, { -- func do local tab = {a = 1, b = 2} - func.iterate(function(key, value) + local function check_entry(key, value) assert(tab[key] == value) tab[key] = nil - end, pairs, tab) + end + func.iterate(check_entry, pairs, tab) + assert(next(tab) == nil) + + tab = {a = 1, b = 2} + local function pairs_callback(callback, tab) + for k, v in pairs(tab) do + callback(k, v) + end + end + for k, v in func.for_generator(pairs_callback, tab) do + check_entry(k, v) + end assert(next(tab) == nil) assert(func.aggregate(func.add, 1, 2, 3) == 6) end