# Kd Tree Revisited
After giving it a bit of thought, I've found a way to simplify the nearest
neighbour search (i.e., the `closest` method) for the `KdTree` I implemented in
my previous post.
## The improvement
That post implemented the nearest neighbour search by keeping track of the
tree's boundaries (through `AABB`) and those of each of its sub-trees (through
`AABB.split`), testing for the early exit condition by computing the distance
from the search's origin to each sub-tree's boundaries.
Instead of explicitly keeping track of each sub-tree's boundaries, we can
compute them implicitly while recursing down the tree. To check the distance
between the queried point and the splitting plane of an inner node, we simply
project the origin onto that plane, which gives us a lower bound on the
distance to the points stored on the other side. This is easily computed from
the `axis` and `mid` values stored in the inner nodes: to project the point
onto the plane, we simply replace its coordinate for that axis with `mid`.
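For example, here's a minimal sketch of that projection and the bound it
yields. `Point`, `replace`, and `dist` are illustrative stand-ins for the
previous post's helpers (in the tree's code below, `replace` is a method on
`Point`):

```python
import math

# Illustrative stand-ins for the previous post's helpers
type Point = tuple[float, ...]


def replace(p: Point, axis: int, value: float) -> Point:
    # Copy `p`, substituting its coordinate along `axis` with `value`
    return p[:axis] + (value,) + p[axis + 1:]


def dist(a: Point, b: Point) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# A vertical splitting plane at x = 5.0 (axis=0, mid=5.0), queried from (2, 3)
origin = (2.0, 3.0)
projection = replace(origin, 0, 5.0)  # -> (5.0, 3.0)
# No point stored on the far side of the plane can be closer than this:
assert dist(origin, projection) == 3.0
```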
## Simplified search
With that out of the way, let's now see how `closest` can be implemented
without needing to track the tree's `AABB` at the root:
```python
import dataclasses
from dataclasses import field


# Wrapper type for closest points, ordered by `distance`
@dataclasses.dataclass(order=True)
class ClosestPoint[T]:
    point: Point = field(compare=False)
    value: T = field(compare=False)
    distance: float


class KdTree[T]:
    def closest(self, point: Point, n: int = 1) -> list[ClosestPoint[T]]:
        assert n > 0
        res = MaxHeap()
        # Instead of passing an `AABB`, we give an initial projection point,
        # the query origin itself (since we haven't visited any split node yet)
        self._root.closest(point, res, n, point)
        return sorted(res)


class KdNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        projection: Point,
    ) -> None:
        # Same implementation
        self.inner.closest(point, out, n, projection)


class KdLeafNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        projection: Point,
    ) -> None:
        # Same implementation
        for p, val in self.points.items():
            item = ClosestPoint(p, val, dist(p, point))
            if len(out) < n:
                out.push(item)
            elif out.peek().distance > item.distance:
                out.pushpop(item)


class KdSplitNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        projection: Point,
    ) -> None:
        index = self._index(point)
        # Recurse on the side of the split that contains the query point
        self.children[index].closest(point, out, n, projection)
        # Project onto the splitting plane, for a minimum distance to its points
        projection = projection.replace(self.axis, self.mid)
        # If we're at capacity and can't possibly find any closer points, exit
        if len(out) == n and dist(point, projection) > out.peek().distance:
            return
        # Otherwise recurse on the other side to check for nearer neighbours
        self.children[1 - index].closest(point, out, n, projection)
```
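A quick aside on the `MaxHeap` used above: the search only needs `push`,
`pushpop`, `peek`, `len`, and iteration from it. Here's a minimal sketch built
on top of `heapq` by reversing comparisons; an illustrative stand-in, not
necessarily how the previous post implements it:

```python
import heapq
from typing import Iterator


class _Rev[T]:
    """Wrapper reversing comparisons, so a `heapq` min-heap acts as a max-heap."""

    __slots__ = ("item",)

    def __init__(self, item: T) -> None:
        self.item = item

    def __lt__(self, other: "_Rev[T]") -> bool:
        return other.item < self.item  # note the swapped operands


class MaxHeap[T]:
    def __init__(self) -> None:
        self._heap: list[_Rev[T]] = []

    def __len__(self) -> int:
        return len(self._heap)

    def __iter__(self) -> Iterator[T]:
        # Iterates in the heap's internal order, not sorted order
        return (rev.item for rev in self._heap)

    def push(self, item: T) -> None:
        heapq.heappush(self._heap, _Rev(item))

    def peek(self) -> T:
        # Largest item, i.e. the farthest of the current `n` closest points
        return self._heap[0].item

    def pushpop(self, item: T) -> T:
        # Push `item` then pop the largest, more efficiently than push + pop
        return heapq.heappushpop(self._heap, _Rev(item)).item
```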
Back in the tree itself, as you can see, the main difference is in
`KdSplitNode`'s implementation, where we can now cheaply compute a lower bound
on the distance between the search's origin and all potential points in the
other sub-tree's subspace.
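To wrap up, here's how a search might look in practice. Note that the `KdTree`
constructor below is hypothetical, shown only for illustration; building the
tree is covered in the previous post:

```python
# Hypothetical usage; the constructor shown here is assumed for illustration
# and may not match the previous post's actual construction API.
tree = KdTree({(0.0, 0.0): "a", (1.0, 1.0): "b", (5.0, 5.0): "c"})

# The two nearest neighbours of (0.9, 0.9), nearest first:
for res in tree.closest((0.9, 0.9), n=2):
    print(res.point, res.value, res.distance)
# (1.0, 1.0) b 0.1414...
# (0.0, 0.0) a 1.2727...
```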