k-d Tree

The k-d Tree is a space-partitioning data structure: it organizes points in a k-dimensional space so that they can be queried efficiently.

I ran into them during my studies in graphics, as they are one of the possible acceleration structures for ray-casting operations.

Implementation

As usual, this will be in Python, though its lack of proper discriminated enums makes it more verbose than would otherwise be necessary.

Pre-requisites

Let’s first define what kind of space our k-d Tree is dealing with. In this instance $k = 3$, just like in the physical world.

from __future__ import annotations

from enum import IntEnum
from typing import NamedTuple

class Point(NamedTuple):
    x: float
    y: float
    z: float

class Axis(IntEnum):
    X = 0
    Y = 1
    Z = 2

    def next(self) -> Axis:
        # Each level of the tree is split along a different axis
        return Axis((self + 1) % 3)
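
As a quick sanity check of the axis cycling, here is a small standalone sketch. It restates `Axis` so that it runs on its own, and the helper `axis_at_depth` is a made-up name for this example, not part of the implementation.

```python
from enum import IntEnum


class Axis(IntEnum):
    X = 0
    Y = 1
    Z = 2

    def next(self) -> "Axis":
        return Axis((self + 1) % 3)


def axis_at_depth(depth: int) -> Axis:
    # Hypothetical helper: start at X on the root, advance once per level
    axis = Axis.X
    for _ in range(depth):
        axis = axis.next()
    return axis


# The split axis wraps around after Z
levels = [axis_at_depth(d).name for d in range(5)]

# An IntEnum member can index directly into a tuple of coordinates
point = (1.0, 2.0, 3.0)
z_coord = point[Axis.Z]
```

Because `Axis` is an `IntEnum`, `point[axis]` works on any tuple-like point, which is what the rest of the code relies on.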

Representation

The tree is represented by KdTree; each of its leaf nodes is a KdLeafNode and each of its inner nodes is a KdSplitNode.

For each point in space, the tree can also keep track of an associated value, similar to a dictionary or other mapping data structure. Hence we will make our KdTree generic to this mapped type T.

Leaf node

A leaf node contains a number of points that were added to the tree. For each point, we also track its mapped value, hence the dict[Point, T].

class KdLeafNode[T]:
    points: dict[Point, T]

    def __init__(self):
        self.points = {}

Split node

An inner node must partition the space into two sub-spaces along a given axis and mid-point (thus defining a plane). All points that are “to the left” of the plane will be kept in one child, while all the points “to the right” will be in the other, similar to a Binary Search Tree’s inner nodes.

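class KdSplitNode[T]:
    axis: Axis
    mid: float
    children: tuple[KdNode[T], KdNode[T]]

    def __init__(self, axis: Axis, mid: float, children: tuple[KdNode[T], KdNode[T]]):
        self.axis = axis
        self.mid = mid
        self.children = children

    # Convenience function to index into the child which contains `point`
    def _index(self, point: Point) -> int:
        return 0 if point[self.axis] <= self.mid else 1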

Tree

The tree itself is merely a wrapper around its inner nodes.

One annoying issue with writing this in Python is the lack of proper discriminated enum types, so we need to create a wrapper type for the nodes (KdNode) to allow a leaf to be replaced by a split node when updating the tree.

class KdNode[T]:
    # Wrapper around leaf/inner nodes, the poor man's discriminated enum
    inner: KdLeafNode[T] | KdSplitNode[T]

    def __init__(self):
        self.inner = KdLeafNode()

    # Convenience constructor used when splitting a node
    @classmethod
    def from_items(cls, items: Iterable[tuple[Point, T]]) -> KdNode[T]:
        res = cls()
        res.inner.points.update(items)
        return res

class KdTree[T]:
    _root: KdNode[T]

    def __init__(self):
        # Tree starts out empty
        self._root = KdNode()

Inserting a point

To add a point to the tree, we simply recurse from node to node, similar to a BST’s insertion algorithm. Once we’ve found the correct leaf node to insert our point into, we simply do so.

If that leaf node goes over the maximum number of points it can store, we must then split it along an axis, cycling between X, Y, and Z at each level of the tree (i.e: splitting along the X axis on the first level, then Y on the second, then Z after that, and then X, etc…).

from collections.abc import Callable, Iterable

# How many points should be stored in a leaf node before being split
MAX_CAPACITY = 32

def median(values: Iterable[float]) -> float:
    sorted_values = sorted(values)
    mid_point = len(sorted_values) // 2
    if len(sorted_values) % 2 == 1:
        return sorted_values[mid_point]
    # Even count: take the point halfway between the two middle values
    a, b = sorted_values[mid_point - 1], sorted_values[mid_point]
    return a + (b - a) / 2

def partition[T](
    pred: Callable[[T], bool],
    iterable: Iterable[T]
) -> tuple[list[T], list[T]]:
    truths, falses = [], []
    for v in iterable:
        (truths if pred(v) else falses).append(v)
    return truths, falses

def split_leaf[T](node: KdLeafNode[T], axis: Axis) -> KdSplitNode[T]:
    # Find the median value for the given axis
    mid = median(p[axis] for p in node.points)
    # Split into left/right children according to the mid-point and axis
    left, right = partition(lambda kv: kv[0][axis] <= mid, node.points.items())
    return KdSplitNode(
        axis,
        mid,
        (KdNode.from_items(left), KdNode.from_items(right)),
    )

class KdTree[T]:
    def insert(self, point: Point, val: T) -> bool:
        # Forward to the root node, choose `X` as the first split axis
        return self._root.insert(point, val, Axis.X)

class KdLeafNode[T]:
    def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
        # Check whether we're overwriting a previous value
        was_mapped = point in self.points
        # Store the corresponding value
        self.points[point] = val
        # Return whether we've performed an overwrite
        return was_mapped

class KdSplitNode[T]:
    def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
        # Find the child which contains the point
        child = self.children[self._index(point)]
        # Recurse into it, choosing the next split axis
        return child.insert(point, val, split_axis.next())

class KdNode[T]:
    def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
        # Add the point to the wrapped node...
        res = self.inner.insert(point, val, split_axis)
        # ... And take care of splitting leaf nodes when necessary
        if (
            isinstance(self.inner, KdLeafNode)
            and len(self.inner.points) > MAX_CAPACITY
        ):
            self.inner = split_leaf(self.inner, split_axis)
        return res
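
To see what a median split does to a leaf, here is a standalone sketch on a handful of points. It makes two simplifications for the sake of brevity: it uses `statistics.median` from the standard library instead of the hand-written `median`, and plain tuples instead of `Point`.

```python
from statistics import median

points = [
    (1.0, 5.0, 0.0),
    (4.0, 2.0, 0.0),
    (7.0, 8.0, 0.0),
    (9.0, 1.0, 0.0),
]
axis = 0  # split along X

# The median of the X coordinates defines the splitting plane
mid = median(p[axis] for p in points)

# Points on the plane itself go to the left child, matching `_index`
left = [p for p in points if p[axis] <= mid]
right = [p for p in points if p[axis] > mid]
```

With distinct coordinates, the two halves end up the same size (give or take one point), which is what keeps the tree balanced as leaves get split.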

Searching for a point

Looking up a given point in the tree looks very similar to a BST’s search: each split node divides the space into two sub-spaces, only one of which can contain the point.

class KdTree[T]:
    def lookup(self, point: Point) -> T | None:
        # Forward to the root node
        return self._root.lookup(point)

class KdNode[T]:
    def lookup(self, point: Point) -> T | None:
        # Forward to the wrapped node
        return self.inner.lookup(point)

class KdLeafNode[T]:
    def lookup(self, point: Point) -> T | None:
        # Simply check whether we've stored the point in this leaf
        return self.points.get(point)

class KdSplitNode[T]:
    def lookup(self, point: Point) -> T | None:
        # Recurse into the child which contains the point
        return self.children[self._index(point)].lookup(point)
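
The descent can also be illustrated on a tiny hand-built tree. The sketch below is not the implementation from this post: it encodes nodes as plain tuples (an assumption made purely for this example) and walks down iteratively rather than recursively.

```python
def lookup(node, point):
    # Split nodes are ("split", axis, mid, left, right)
    # Leaf nodes are ("leaf", {point: value})
    while node[0] == "split":
        _, axis, mid, left, right = node
        # Only one side of the plane can contain the point
        node = left if point[axis] <= mid else right
    return node[1].get(point)


tree = (
    "split", 0, 5.0,
    ("leaf", {(1.0, 1.0, 1.0): "a", (4.0, 9.0, 2.0): "b"}),
    (
        "split", 1, 3.0,
        ("leaf", {(8.0, 2.0, 0.0): "c"}),
        ("leaf", {(6.0, 7.0, 5.0): "d"}),
    ),
)

found = lookup(tree, (6.0, 7.0, 5.0))
missing = lookup(tree, (6.0, 1.0, 5.0))
```

Each step discards one half of the remaining space, so a lookup only ever inspects a single leaf.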

Closest points

Now to look at the most interesting operation one can do on a k-d Tree: querying for the objects which are closest to a given point (i.e: the Nearest neighbour search).

This is a more complicated algorithm, which will also need some modifications to the current k-d Tree implementation in order to track a bit more information about the points it contains.

A notion of distance

To search for the closest points to a given origin, we first need to define which distance we are using in our space.

For this example, we’ll simply be using the usual definition of (Euclidean) distance.

from math import sqrt

def dist(point: Point, other: Point) -> float:
    return sqrt(sum((a - b) ** 2 for a, b in zip(point, other)))
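
As an aside, since Python 3.8 the standard library ships `math.dist`, which computes the same Euclidean distance between two coordinate sequences of equal length. A minimal comparison, using plain tuples rather than `Point`:

```python
from math import dist, isclose, sqrt

a = (0.0, 0.0, 0.0)
b = (1.0, 2.0, 2.0)

# The formula written out by hand: sqrt(1 + 4 + 4)
by_hand = sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# The standard library equivalent
builtin = dist(a, b)

same = isclose(by_hand, builtin)
```

Writing the function by hand keeps the definition visible, and makes it easy to swap in another metric later.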

Tracking the tree’s boundaries

To make the query efficient, we’ll need to track the tree’s boundaries: the bounding box of all points contained therein. This will allow us to stop the search early once we’ve found enough points and can be sure that the rest of the tree is too far away to qualify.

For this, let’s define the AABB (Axis-Aligned Bounding Box) class.

class Point(NamedTuple):
    # Convenience function to replace the coordinate along a given dimension
    def replace(self, axis: Axis, new_coord: float) -> Point:
        coords = list(self)
        coords[axis] = new_coord
        return Point(*coords)

class AABB(NamedTuple):
    # Lowest coordinates in the box
    low: Point
    # Highest coordinates in the box
    high: Point

    # An empty box
    @classmethod
    def empty(cls) -> AABB:
        return cls(
            Point(*(float("inf"),) * 3),
            Point(*(float("-inf"),) * 3),
        )

    # Split the box into two along a given axis for a given mid-point
    def split(self, axis: Axis, mid: float) -> tuple[AABB, AABB]:
        assert self.low[axis] <= mid <= self.high[axis]
        return (
            AABB(self.low, self.high.replace(axis, mid)),
            AABB(self.low.replace(axis, mid), self.high),
        )

    # Return a new box, extended to contain a given point
    def extend(self, point: Point) -> AABB:
        low = Point(*map(min, zip(self.low, point)))
        high = Point(*map(max, zip(self.high, point)))
        return AABB(low, high)

    # Return the shortest distance between a given point and the box
    def dist_to_point(self, point: Point) -> float:
        deltas = (
            max(self.low[axis] - point[axis], 0, point[axis] - self.high[axis])
            for axis in Axis
        )
        return dist(Point(0, 0, 0), Point(*deltas))
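
The point-to-box distance is the least obvious part, so here is the same computation as a standalone function with a few worked values. `box_distance` is a name made up for this sketch, and it takes plain tuples instead of `AABB` and `Point`.

```python
from math import sqrt


def box_distance(low, high, point):
    # Per-axis gap between the point and the box: 0 when the coordinate
    # falls within the box's extent along that axis
    deltas = [max(lo - p, 0.0, p - hi) for lo, hi, p in zip(low, high, point)]
    return sqrt(sum(d * d for d in deltas))


low, high = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

# Inside the box: every gap is 0
inside = box_distance(low, high, (0.5, 0.5, 0.5))
# Facing one side: only the X gap (3 - 1 = 2) contributes
beside = box_distance(low, high, (3.0, 0.5, 0.5))
# Off an edge: gaps of 3 along X and 4 along Y, a 3-4-5 triangle
corner = box_distance(low, high, (4.0, 5.0, 1.0))
```

No point stored inside the box can be closer to the query than this value, which is what makes it usable as a pruning bound.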

We then make the necessary modifications to the KdTree so that it stores the bounding box and updates it as we add new points.

class KdTree[T]:
    _root: KdNode[T]
    # New field: to keep track of the tree's boundaries
    _aabb: AABB

    def __init__(self):
        self._root = KdNode()
        # Initialize the empty tree with an empty bounding box
        self._aabb = AABB.empty()

    def insert(self, point: Point, val: T) -> bool:
        # Extend the AABB for our k-d Tree when adding a point to it
        self._aabb = self._aabb.extend(point)
        return self._root.insert(point, val, Axis.X)

MaxHeap

Python’s builtin heapq module provides the necessary functions to create and interact with a Priority Queue, in the form of a Binary Heap.

Unfortunately, Python’s library maintains a min-heap, which keeps the minimum element at the root. For this algorithm, we’re interested in having a max-heap, with the maximum at the root.

Thankfully, one can just reverse the comparison function for each element to convert between the two. Let’s write a MaxHeap class making use of this library, with a Reverse wrapper class to reverse the order of elements contained within it (similar to Rust’s Reverse).

import functools
import heapq
from collections.abc import Iterator

# Reverses the wrapped value's ordering
@functools.total_ordering
class Reverse[T]:
    value: T

    def __init__(self, value: T):
        self.value = value

    def __lt__(self, other: Reverse[T]) -> bool:
        return self.value > other.value

    def __eq__(self, other: Reverse[T]) -> bool:
        return self.value == other.value

class MaxHeap[T]:
    _heap: list[Reverse[T]]

    def __init__(self):
        self._heap = []

    def __len__(self) -> int:
        return len(self._heap)

    def __iter__(self) -> Iterator[T]:
        yield from (item.value for item in self._heap)

    # Push a value on the heap
    def push(self, value: T) -> None:
        heapq.heappush(self._heap, Reverse(value))

    # Peek at the current maximum value
    def peek(self) -> T:
        return self._heap[0].value

    # Pop and return the highest value
    def pop(self) -> T:
        return heapq.heappop(self._heap).value

    # Push a value onto the heap, then pop and return the highest value
    def pushpop(self, value: T) -> T:
        return heapq.heappushpop(self._heap, Reverse(value)).value
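
To illustrate why a max-heap is the right tool for the search, the standalone sketch below keeps the `n` smallest numbers seen so far. Note that it negates the keys instead of using the `Reverse` wrapper, a common alternative that only works for plain numbers; the function name `n_smallest` is made up for this example.

```python
import heapq


def n_smallest(values, n):
    # Stores negated values, so that -heap[0] is the largest value kept
    heap = []
    for v in values:
        if len(heap) < n:
            # Not full yet: keep everything
            heapq.heappush(heap, -v)
        elif -heap[0] > v:
            # `v` beats the worst candidate: replace it in one step
            heapq.heappushpop(heap, -v)
    return sorted(-x for x in heap)


result = n_smallest([5, 1, 9, 3, 7, 2], 3)
```

The root of the heap is always the worst of the current candidates, so a single peek tells us whether a new value deserves a place.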

The actual implementation

Now that we have written the necessary building blocks, let’s tackle the implementation of closest for our k-d Tree.

import dataclasses

# Wrapper type for closest points, ordered by `distance`
@dataclasses.dataclass(order=True)
class ClosestPoint[T]:
    point: Point = dataclasses.field(compare=False)
    value: T = dataclasses.field(compare=False)
    distance: float

class KdTree[T]:
    def closest(self, point: Point, n: int = 1) -> list[ClosestPoint[T]]:
        assert n > 0
        # Create the output heap
        res = MaxHeap()
        # Recurse onto the root node
        self._root.closest(point, res, n, self._aabb)
        # Return the resulting list, from closest to farthest
        return sorted(res)

class KdNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        bounds: AABB,
    ) -> None:
        # Forward to the wrapped node
        self.inner.closest(point, out, n, bounds)

class KdLeafNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        bounds: AABB,
    ) -> None:
        # At the leaf, simply iterate over all points and add them to the heap
        for p, val in self.points.items():
            item = ClosestPoint(p, val, dist(p, point))
            if len(out) < n:
                # If the heap isn't full, just push
                out.push(item)
            elif out.peek().distance > item.distance:
                # Otherwise, push and pop to keep the heap at `n` elements
                out.pushpop(item)

class KdSplitNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        bounds: AABB,
    ) -> None:
        index = self._index(point)
        children_bounds = bounds.split(self.axis, self.mid)
        # Iterate over the child which contains the point, then its neighbour
        for i in (index, 1 - index):
            child, bounds = self.children[i], children_bounds[i]
            # `min_dist` is 0 for the first child when it contains `point`;
            # for the second child it is a lower bound on the distance to
            # any point stored there
            min_dist = bounds.dist_to_point(point)
            # If the heap is at capacity and the child to inspect too far, stop
            if len(out) == n and min_dist > out.peek().distance:
                return
            # Otherwise, recurse
            child.closest(point, out, n, bounds)
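
When testing a search like this one, a brute-force scan makes a handy reference to compare results against. The sketch below is such a baseline and not part of the k-d Tree itself: `brute_force_closest` is a hypothetical name, it works on a plain `dict` of coordinate tuples, and it relies on `math.dist` (Python 3.8+).

```python
import heapq
from math import dist


def brute_force_closest(points, origin, n=1):
    # Look at every point and keep the `n` with the smallest distance,
    # returned from closest to farthest as (point, value) pairs
    return heapq.nsmallest(
        n,
        points.items(),
        key=lambda kv: dist(kv[0], origin),
    )


points = {
    (0.0, 0.0, 0.0): "origin",
    (1.0, 0.0, 0.0): "near",
    (0.0, 2.0, 0.0): "mid",
    (5.0, 5.0, 5.0): "far",
}

closest = brute_force_closest(points, (0.9, 0.0, 0.0), n=2)
names = [value for _, value in closest]
```

The scan inspects every point on every query, whereas the tree can skip whole sub-spaces thanks to the bounding boxes.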