From 7f58887ae33893c981fbdff23d4e1fa4a11c32e4 Mon Sep 17 00:00:00 2001
From: Jude Melton-Houghton <jwmhjwmh@gmail.com>
Date: Tue, 10 May 2022 16:37:33 -0400
Subject: [PATCH] Support packing arbitrary graphs (#12289)

---
 games/devtest/mods/unittests/async_env.lua |  35 +++++--
 src/script/common/c_packer.cpp             | 103 ++++++++++++---------
 src/script/common/c_packer.h               |   5 +-
 3 files changed, 89 insertions(+), 54 deletions(-)

diff --git a/games/devtest/mods/unittests/async_env.lua b/games/devtest/mods/unittests/async_env.lua
index aff1fc4d9..3a21bd9e2 100644
--- a/games/devtest/mods/unittests/async_env.lua
+++ b/games/devtest/mods/unittests/async_env.lua
@@ -60,15 +60,34 @@ local function test_object_passing()
 	local tmp = core.serialize_roundtrip(test_object)
 	assert(deepequal(test_object, tmp))
 
-	-- Circular key, should error
-	tmp = {"foo", "bar"}
-	tmp[tmp] = true
-	assert(not pcall(core.serialize_roundtrip, tmp))
+	local circular_key = {"foo", "bar"}
+	circular_key[circular_key] = true
+	tmp = core.serialize_roundtrip(circular_key)
+	assert(tmp[1] == "foo")
+	assert(tmp[2] == "bar")
+	assert(tmp[tmp] == true)
 
-	-- Circular value, should error
-	tmp = {"foo"}
-	tmp[2] = tmp
-	assert(not pcall(core.serialize_roundtrip, tmp))
+	local circular_value = {"foo"}
+	circular_value[2] = circular_value
+	tmp = core.serialize_roundtrip(circular_value)
+	assert(tmp[1] == "foo")
+	assert(tmp[2] == tmp)
+
+	-- Two-segment cycle
+	local cycle_seg_1, cycle_seg_2 = {}, {}
+	cycle_seg_1[1] = cycle_seg_2
+	cycle_seg_2[1] = cycle_seg_1
+	tmp = core.serialize_roundtrip(cycle_seg_1)
+	assert(tmp[1][1] == tmp)
+
+	-- Duplicated value without a cycle
+	local acyclic_dup_holder = {}
+	tmp = ItemStack("")
+	acyclic_dup_holder[tmp] = tmp
+	tmp = core.serialize_roundtrip(acyclic_dup_holder)
+	for k, v in pairs(tmp) do
+		assert(rawequal(k, v))
+	end
 end
 unittests.register("test_object_passing", test_object_passing)
 
diff --git a/src/script/common/c_packer.cpp b/src/script/common/c_packer.cpp
index fc5277330..ede00c758 100644
--- a/src/script/common/c_packer.cpp
+++ b/src/script/common/c_packer.cpp
@@ -123,6 +123,7 @@ namespace {
 		size_t idx;
 		VectorRef(std::vector<T> *vec, size_t idx) : vec(vec), idx(idx) {}
 	public:
+		constexpr VectorRef() : vec(nullptr), idx(0) {}
 		static VectorRef<T> front(std::vector<T> &vec) {
 			return VectorRef(&vec, 0);
 		}
@@ -131,6 +132,7 @@ namespace {
 		}
 		T &operator*() { return (*vec)[idx]; }
 		T *operator->() { return &(*vec)[idx]; }
+		operator bool() const { return vec != nullptr; }
 	};
 
 	struct Packer {
@@ -252,38 +254,27 @@ static bool find_packer(lua_State *L, int idx, PackerTuple &out)
 // Packing implementation
 //
 
-// recursively goes through the structure and ensures there are no circular references
-static void pack_validate(lua_State *L, int idx, std::unordered_set<const void*> &seen)
+static VectorRef<PackedInstr> record_object(lua_State *L, int idx, PackedValue &pv,
+		std::unordered_map<const void *, s32> &seen)
 {
-#ifndef NDEBUG
-	StackChecker checker(L);
-	assert(idx > 0);
-#endif
-
-	if (lua_type(L, idx) != LUA_TTABLE)
-		return;
-
 	const void *ptr = lua_topointer(L, idx);
 	assert(ptr);
-
-	if (seen.find(ptr) != seen.end())
-		throw LuaError("Circular references cannot be packed (yet)");
-	seen.insert(ptr);
-
-	lua_checkstack(L, 5);
-	lua_pushnil(L);
-	while (lua_next(L, idx) != 0) {
-		// key at -2, value at -1
-		pack_validate(L, absidx(L, -2), seen);
-		pack_validate(L, absidx(L, -1), seen);
-
-		lua_pop(L, 1);
+	auto found = seen.find(ptr);
+	if (found == seen.end()) {
+		seen[ptr] = pv.i.size();
+		return VectorRef<PackedInstr>();
 	}
-
-	seen.erase(ptr);
+	s32 ref = found->second;
+	assert(ref < (s32)pv.i.size());
+	// reuse the value from first time
+	auto r = emplace(pv, INSTR_PUSHREF);
+	r->ref = ref;
+	pv.i[ref].keep_ref = true;
+	return r;
 }
 
-static VectorRef<PackedInstr> pack_inner(lua_State *L, int idx, int vidx, PackedValue &pv)
+static VectorRef<PackedInstr> pack_inner(lua_State *L, int idx, int vidx, PackedValue &pv,
+		std::unordered_map<const void *, s32> &seen)
 {
 #ifndef NDEBUG
 	StackChecker checker(L);
@@ -313,10 +304,17 @@ static VectorRef<PackedInstr> pack_inner(lua_State *L, int idx, int vidx, Packed
 			r->sdata.assign(str, len);
 			return r;
 		}
-		case LUA_TTABLE:
+		case LUA_TTABLE: {
+			auto r = record_object(L, idx, pv, seen);
+			if (r)
+				return r;
 			break; // execution continues
+		}
 		case LUA_TFUNCTION: {
-			auto r = emplace(pv, LUA_TFUNCTION);
+			auto r = record_object(L, idx, pv, seen);
+			if (r)
+				return r;
+			r = emplace(pv, LUA_TFUNCTION);
 			call_string_dump(L, idx);
 			size_t len;
 			const char *str = lua_tolstring(L, -1, &len);
@@ -326,11 +324,14 @@ static VectorRef<PackedInstr> pack_inner(lua_State *L, int idx, int vidx, Packed
 			return r;
 		}
 		case LUA_TUSERDATA: {
+			auto r = record_object(L, idx, pv, seen);
+			if (r)
+				return r;
 			PackerTuple ser;
 			if (!find_packer(L, idx, ser))
 				throw LuaError("Cannot serialize unsupported userdata");
 			pv.contains_userdata = true;
-			auto r = emplace(pv, LUA_TUSERDATA);
+			r = emplace(pv, LUA_TUSERDATA);
 			r->sdata = ser.first;
 			r->ptrdata = ser.second.fin(L, idx);
 			return r;
@@ -360,8 +361,8 @@ static VectorRef<PackedInstr> pack_inner(lua_State *L, int idx, int vidx, Packed
 		// check if we can use a shortcut
 		if (can_set_into(ktype, vtype) && suitable_key(L, -2)) {
 			// push only the value
-			auto rval = pack_inner(L, absidx(L, -1), vidx, pv);
-			rval->pop = vtype != LUA_TTABLE;
+			auto rval = pack_inner(L, absidx(L, -1), vidx, pv, seen);
+			rval->pop = rval->type != LUA_TTABLE;
 			// and where to put it:
 			rval->set_into = vi_table;
 			if (ktype == LUA_TSTRING)
@@ -375,9 +376,9 @@ static VectorRef<PackedInstr> pack_inner(lua_State *L, int idx, int vidx, Packed
 			}
 		} else {
 			// push the key and value
-			pack_inner(L, absidx(L, -2), vidx, pv);
+			pack_inner(L, absidx(L, -2), vidx, pv, seen);
 			vidx++;
-			pack_inner(L, absidx(L, -1), vidx, pv);
+			pack_inner(L, absidx(L, -1), vidx, pv, seen);
 			vidx++;
 			// push an instruction to set them
 			auto ri1 = emplace(pv, INSTR_SETTABLE);
@@ -400,13 +401,9 @@ PackedValue *script_pack(lua_State *L, int idx)
 	if (idx < 0)
 		idx = absidx(L, idx);
 
-	std::unordered_set<const void*> seen;
-	pack_validate(L, idx, seen);
-	assert(seen.size() == 0);
-
-	// Actual serialization
 	PackedValue pv;
-	pack_inner(L, idx, 1, pv);
+	std::unordered_map<const void *, s32> seen;
+	pack_inner(L, idx, 1, pv, seen);
 
 	return new PackedValue(std::move(pv));
 }
@@ -417,18 +414,21 @@ PackedValue *script_pack(lua_State *L, int idx)
 
 void script_unpack(lua_State *L, PackedValue *pv)
 {
+	lua_newtable(L); // table at index top to track ref indices -> objects
 	const int top = lua_gettop(L);
 	int ctr = 0;
 
-	for (auto &i : pv->i) {
+	for (size_t packed_idx = 0; packed_idx < pv->i.size(); packed_idx++) {
+		auto &i = pv->i[packed_idx];
+
 		// If leaving values on stack make sure there's space (every 5th iteration)
 		if (!i.pop && (ctr++) >= 5) {
 			lua_checkstack(L, 5);
 			ctr = 0;
 		}
 
-		/* Instructions */
 		switch (i.type) {
+			/* Instructions */
 			case INSTR_SETTABLE:
 				lua_pushvalue(L, top + i.sidata1); // key
 				lua_pushvalue(L, top + i.sidata2); // value
@@ -448,12 +448,12 @@ void script_unpack(lua_State *L, PackedValue *pv)
 				if (i.sidata2 > 0)
 					lua_remove(L, top + i.sidata2);
 				continue;
-			default:
+			case INSTR_PUSHREF:
+				lua_pushinteger(L, i.ref);
+				lua_rawget(L, top);
 				break;
-		}
 
-		/* Lua types */
-		switch (i.type) {
+			/* Lua types */
 			case LUA_TNIL:
 				lua_pushnil(L);
 				break;
@@ -479,11 +479,18 @@ void script_unpack(lua_State *L, PackedValue *pv)
 				i.ptrdata = nullptr; // ownership taken by callback
 				break;
 			}
+
 			default:
 				assert(0);
 				break;
 		}
 
+		if (i.keep_ref) {
+			lua_pushinteger(L, packed_idx);
+			lua_pushvalue(L, -2);
+			lua_rawset(L, top);
+		}
+
 		if (i.set_into) {
 			if (!i.pop)
 				lua_pushvalue(L, -1);
@@ -501,6 +508,7 @@ void script_unpack(lua_State *L, PackedValue *pv)
 	pv->contains_userdata = false;
 	// leave exactly one value on the stack
 	lua_settop(L, top+1);
+	lua_remove(L, top);
 }
 
 //
@@ -541,6 +549,9 @@ void script_dump_packed(const PackedValue *val)
 			case INSTR_POP:
 				printf(i.sidata2 ? "POP(%d, %d)" : "POP(%d)", i.sidata1, i.sidata2);
 				break;
+			case INSTR_PUSHREF:
+				printf("PUSHREF(%d)", i.ref);
+				break;
 			case LUA_TNIL:
 				printf("nil");
 				break;
@@ -574,6 +585,8 @@ void script_dump_packed(const PackedValue *val)
 			else
 				printf(", into=%d", i.set_into);
 		}
+		if (i.keep_ref)
+			printf(", keep_ref");
 		if (i.pop)
 			printf(", pop");
 		printf(")\n");
diff --git a/src/script/common/c_packer.h b/src/script/common/c_packer.h
index 8bccca98d..ee732be86 100644
--- a/src/script/common/c_packer.h
+++ b/src/script/common/c_packer.h
@@ -36,6 +36,7 @@ extern "C" {
 
 #define INSTR_SETTABLE (-10)
 #define INSTR_POP      (-11)
+#define INSTR_PUSHREF  (-12)
 
 /**
  * Represents a single instruction that pushes a new value or works with existing ones.
@@ -44,6 +45,7 @@ struct PackedInstr
 {
 	s16 type; // LUA_T* or INSTR_*
 	u16 set_into; // set into table on stack
+	bool keep_ref; // is referenced later by INSTR_PUSHREF?
 	bool pop; // remove from stack?
 	union {
 		bool bdata; // boolean: value
@@ -60,6 +62,7 @@ struct PackedInstr
 			s32 sidata1, sidata2;
 		};
 		void *ptrdata; // userdata: implementation defined
+		s32 ref; // PUSHREF: index of referenced instr
 	};
 	/*
 		- string: value
@@ -69,7 +72,7 @@ struct PackedInstr
 	*/
 	std::string sdata;
 
-	PackedInstr() : type(0), set_into(0), pop(false) {}
+	PackedInstr() : type(0), set_into(0), keep_ref(false), pop(false) {}
 };
 
 /**