diff --git a/kdtree.lua b/kdtree.lua index 8ea8e1c..bf9935e 100644 --- a/kdtree.lua +++ b/kdtree.lua @@ -21,15 +21,15 @@ function new(vectors, distance) end local self = builder(vectors, 1) self.distance = distance - return self + return setmetatable(self, metatable) end function get_nearest_neighbor(self, vector) local min_distance = math.huge local nearest_neighbor local distance_func = self.distance - local axis = tree.axis local function visit(tree) + local axis = tree.axis if tree.value ~= nil then local distance = distance_func(tree.value, vector) if distance < min_distance then @@ -42,7 +42,7 @@ function get_nearest_neighbor(self, vector) if vector[axis] < tree.pivot[axis] then this_side, other_side = other_side, this_side end visit(this_side) if tree.pivot then - local dist = math.abs(tree.pivot[axis] - color[axis]) + local dist = math.abs(tree.pivot[axis] - vector[axis]) if dist <= min_distance then visit(other_side) end end end