stMind

about Tech, Computer vision and Machine learning

kd tree and Nearest neighbor search

My first code of the new year: a kd tree and nearest neighbor search, in Ruby.

Given a set of points P in a multidimensional space, a kd tree and nearest neighbor search are used to answer the question: which point of P is closest to a query point x?
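Without any tree, the straightforward answer is a linear scan that measures the distance from x to every point in P. The sketch below is only a baseline for comparison; the name `brute_force_nearest` is hypothetical and not part of the post's code. It compares squared Euclidean distances, which pick the same winner as true distances without a square root.

```ruby
# Squared Euclidean distance between two points of equal dimension
def square_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hypothetical baseline: examine every point, O(|P|) work per query
def brute_force_nearest(points, x)
  points.min_by { |pt| square_distance(pt, x) }
end

list = [[1,2,3],[4,0,1],[5,3,1],[10,5,4],[9,8,9],[4,2,4]]
p brute_force_nearest(list, [1, 1, 1])  # => [1, 2, 3]
```

A kd tree aims to avoid touching every point by discarding whole regions that cannot contain a closer point.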

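A kd tree is a binary tree that recursively splits the search space in two at the median point, cycling through the dimensions (one axis per level of the tree).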

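# Node of the kd tree. The post does not show this class; this is a
# minimal definition inferred from how the code uses it.
Node = Struct.new(:location, :left_child, :right_child) do
  # a node with no children is a leaf
  def is_leaf
    left_child.nil? && right_child.nil?
  end
end

# Build a kd tree, splitting at the median along axis = depth % dim
def kdtree_build(pointList, depth=0)
  return nil if pointList.empty?

  dim = pointList[0].size
  axis = depth % dim
  sortedList = pointList.sort{|p,q| p[axis] <=> q[axis]}
  median = sortedList.size / 2

  node = Node.new
  node.location = sortedList[median]                                  # median point splits this region
  node.left_child = kdtree_build(sortedList[0, median], depth+1)      # points before the median
  node.right_child = kdtree_build(sortedList[median+1..-1], depth+1)  # points after the median
  return node
end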
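The test driver later calls `show_tree`, which the post never lists. Judging only from the printed tree (`rt`/`lc`/`rc` for root/left child/right child, then the depth and the point), a pre-order traversal like the following would produce that format. This is a guess at its behavior, not the author's code; `tree_lines` is a hypothetical helper, and the builder is repeated (passing the children straight to `Node.new`) so the sketch runs on its own.

```ruby
Node = Struct.new(:location, :left_child, :right_child)

# Same median-split construction, with children passed to Node.new
def kdtree_build(points, depth = 0)
  return nil if points.empty?
  axis = depth % points[0].size
  sorted = points.sort { |p, q| p[axis] <=> q[axis] }
  median = sorted.size / 2
  Node.new(sorted[median],
           kdtree_build(sorted[0, median], depth + 1),
           kdtree_build(sorted[median + 1..-1], depth + 1))
end

# Hypothetical: collect "label, depth:point" lines in pre-order
# (the node itself, then its left subtree, then its right subtree)
def tree_lines(node, label = "rt", depth = 0, out = [])
  return out if node.nil?
  out << "#{label}, #{depth}:#{node.location}"
  tree_lines(node.left_child, "lc", depth + 1, out)
  tree_lines(node.right_child, "rc", depth + 1, out)
  out
end

# Hypothetical show_tree printing the framed listing seen in the sample run
def show_tree(node)
  puts "===tree==="
  puts tree_lines(node)
  puts "=========="
end

tree = kdtree_build([[1,2,3],[4,0,1],[5,3,1],[10,5,4],[9,8,9],[4,2,4]])
show_tree(tree)
```

Because Ruby's `sort` is not stable, points that tie on the current axis (here `[1, 2, 3]` and `[4, 2, 4]` on the y axis) may end up in a different order than in the post's run.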

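Nearest neighbor search first walks down the tree to a leaf, at each node comparing the query with the node's point along that node's splitting dimension. On the way back up, it checks whether the other region of each parent node (the subtree it did not descend into) could contain a closer point. Once the root has been processed, the search is finished.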
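The "could the other region contain a closer point" check rests on one inequality: every point on the far side of a splitting plane is at least |qt[axis] - split[axis]| away from the query, so the far subtree only needs a visit when the best squared distance found so far exceeds that gap squared. A small sketch of the test follows (the helper name `needs_far_side?` is hypothetical), evaluated at two nodes of the example tree from the sample run later in the post.

```ruby
def square_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hypothetical helper: must the subtree across the splitting plane be searched?
# Any point over there is at least (qt[axis] - split[axis]).abs away from qt,
# so it can only be closer if the current best is farther than that.
def needs_far_side?(qt, best, split, axis)
  square_distance(qt, best) > (qt[axis] - split[axis])**2
end

qt = [1, 1, 1]
# root [5, 3, 1], axis 0: best [1, 2, 3] gives 5, gap^2 = (1 - 5)^2 = 16
p needs_far_side?(qt, [1, 2, 3], [5, 3, 1], 0)  # => false, right subtree skipped
# node [1, 2, 3], axis 1: near-side best [4, 0, 1] gives 10, gap^2 = 1
p needs_far_side?(qt, [4, 0, 1], [1, 2, 3], 1)  # => true, right subtree searched
```

This is the same condition that `kdtree_search` evaluates before recursing into the second child.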

def kdtree_search(qt, node, depth=0)
  if node.nil? then
    return nil
  end

  # if we have reached a leaf
  if node.is_leaf then
    return node.location
  end

  # this node is not a leaf

  dim = node.location.size
  axis = depth % dim

  # compare query point and point of current node in selected dimension
  if qt[axis] < node.location[axis] then

    if node.left_child then
      if node.left_child.is_leaf then
        leaf = node.left_child.location
      else
        leaf = kdtree_search(qt, node.left_child, depth+1)
      end
    else
      leaf = kdtree_search(qt, node.right_child, depth+1)
    end

    if node.left_child && node.right_child && 
        square_distance(qt, leaf) > (qt[axis]-node.location[axis])**2 then
      leaf = closer(qt, leaf, kdtree_search(qt, node.right_child, depth+1))
    end
                    
  else
    if node.right_child then
      if node.right_child.is_leaf then
        leaf = node.right_child.location
      else
        leaf = kdtree_search(qt, node.right_child, depth+1)
      end
    else
      leaf = kdtree_search(qt, node.left_child, depth+1)
    end

    if node.right_child && node.left_child &&
        square_distance(qt, leaf) > (qt[axis]-node.location[axis])**2 then
      leaf = closer(qt, leaf, kdtree_search(qt, node.left_child, depth+1))
    end
    
  end

  return closer(qt, leaf, node.location)
end
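`kdtree_search` relies on `square_distance` and `closer`, which the post does not show. Below are assumed definitions inferred from the call sites; the `comp:` lines in the sample output suggest that the author's `closer` also logged each comparison, so this sketch does the same.

```ruby
# Assumed helper: squared Euclidean distance (no sqrt needed to compare)
def square_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Assumed helper: return whichever of p1, p2 is nearer to qt (p1 on a tie),
# printing the pair like the "comp:" lines in the sample run
def closer(qt, p1, p2)
  puts "comp: #{p1} and #{p2}"
  square_distance(qt, p1) <= square_distance(qt, p2) ? p1 : p2
end

p closer([1, 1, 1], [4, 0, 1], [1, 2, 3])
```

With these definitions the call above prints `comp: [4, 0, 1] and [1, 2, 3]` and then `[1, 2, 3]`.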

I tried it on 3-dimensional data.

if __FILE__ == $0
  list = [[1,2,3],[4,0,1],[5,3,1],[10,5,4],[9,8,9],[4,2,4]]
  tree = kdtree_build(list)
  show_tree(tree)

  target = [1,1,1]
  result = kdtree_search(target, tree)
  print "target:", target, ", nearest:", result, "\n"
end

It doesn't quite click for me yet, but it seems to work.
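For comparison, here is a more compact way to structure the same kd tree search. It is a sketch, not the post's code: the function name `nearest` is hypothetical, and instead of comparing results after each recursive call it carries the best point found so far down the recursion, so the pruning test can also use candidates found elsewhere in the tree. It checks itself against a linear scan on random points.

```ruby
Node = Struct.new(:location, :left_child, :right_child)

def kdtree_build(points, depth = 0)
  return nil if points.empty?
  axis = depth % points[0].size
  sorted = points.sort { |p, q| p[axis] <=> q[axis] }
  median = sorted.size / 2
  Node.new(sorted[median],
           kdtree_build(sorted[0, median], depth + 1),
           kdtree_build(sorted[median + 1..-1], depth + 1))
end

def square_distance(a, b)
  a.zip(b).sum { |ai, bi| (ai - bi)**2 }
end

# Hypothetical variant: thread the current best through the recursion
def nearest(node, qt, depth = 0, best = nil)
  return best if node.nil?
  best = node.location if best.nil? ||
                          square_distance(qt, node.location) < square_distance(qt, best)
  axis = depth % qt.size
  diff = qt[axis] - node.location[axis]
  near, far = diff < 0 ? [node.left_child, node.right_child] : [node.right_child, node.left_child]
  best = nearest(near, qt, depth + 1, best)  # descend the query's side first
  # cross the splitting plane only if it is closer than the best so far
  best = nearest(far, qt, depth + 1, best) if diff**2 < square_distance(qt, best)
  best
end

# Cross-check against a linear scan (compare distances, since ties are possible)
points = Array.new(200) { Array.new(3) { rand(0..20) } }
tree = kdtree_build(points)
50.times do
  qt = Array.new(3) { rand(0..20) }
  expected = points.map { |pt| square_distance(qt, pt) }.min
  raise "mismatch" unless square_distance(qt, nearest(tree, qt)) == expected
end

sample = kdtree_build([[1,2,3],[4,0,1],[5,3,1],[10,5,4],[9,8,9],[4,2,4]])
p nearest(sample, [1, 1, 1])  # => [1, 2, 3]
```

Swapping the near and far children with a single comparison removes the mirrored left/right branches, and the random cross-check is a cheap way to gain confidence that pruning never discards the true nearest point.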

$ ruby kdtree.rb 
===tree===
rt, 0:[5, 3, 1]
lc, 1:[1, 2, 3]
lc, 2:[4, 0, 1]
rc, 2:[4, 2, 4]
rc, 1:[9, 8, 9]
lc, 2:[10, 5, 4]
==========
comp: [4, 0, 1] and [4, 2, 4]
comp: [4, 0, 1] and [1, 2, 3]
comp: [1, 2, 3] and [5, 3, 1]
target:[1, 1, 1], nearest:[1, 2, 3]