参考资料:
斯坦福CS106L
算法目的:
在一个N维空间中找到某个点最近的一个点(距离用平方来衡量)。
算法:
算法的大致过程如下(记住还有一部分不一样),跟二叉搜索树差不多,比较的树跟每个节点进行比较来预测要向左还是向右,但是这里有N维,每一层从根节点开始一次是第0,1,2...N-1,0,1,2...个数字进行比较,图中粗体的数字就是那个节点要比较的数字,此数字要比左子树对应的数字(序号相同)大,比右子树对应的数字小或者相等。插入的时候也是一样的。
直觉:
以二叉树为例,每次把数据切分两半进行查找。
线性切分二维数据,随机选择一个数组,随机做一条直线,把数据且成两部分,不断地分割下去( binary space partitioning trees or BSP trees)。多维也是一样,超平面。
KDTree是上面的一种特殊情况,首先根节点是竖直切分平面,多次之后如下所示。
但是这样分并不能我们保证我们可以得到最近的点,如下图所示,绿色的点是目标点,如上进行切分,得到最近的点是跟其在同一个区域的灰色点,但是根据距离公式,蓝色的点才是最近的点。
但在另外一种情况下,我们可以确定跟绿点最近的点一定是根据上面规则划分的同一个区域内。我们在缩小范围的时候(假设是第二次画的横线),每次都会更新最小的距离r和点,我们知道距离更小的点只能是在目标点及r所形成的圆内,而如果我们知道圆又在我们缩小的这个区域里面(这个计算d = |b(i) – a(i)| > r 即可知道),那么横线所切的另外一个方向是不能有更近的点的(局部与整体的关系,圆外(整体)都不可能,更不用说横线之上(局部))。这里我们切了两刀(图中序号标出),假设我们知道最小的距离至少是r,切了两刀,但实际区域并没有大概减少至1/4,因为第一道切完,距离一算,发现包含圆的一部分,因此并不说这一刀的左边的黄色点可以排除,第二刀如上所述可以排除蓝色的点。
伪代码如下,除了最后两句,上面的都是在二分区域,最后两句是为了补救那些不可排除的区域。其中bestDist可以是在上面递归里面减少,不必保持不变。
进一步推广——寻找K个最近的点(k-NN search)
bounded priority queue (or BPQ for short):固定容量的队列,元素有一定的优先级,如果队列满了,新的元素优先级更高,则会有元素会被弹出,并塞入新元素,新的元素优先级不够高,则队列保持原状。
实现步骤:
1、实现KDTree的基本操作
主要是插入,规则跟二叉搜索树差不多。其他也只是数的一些相关操作。
2、实现kNN,主要是在lambda函数recursiveFun部分,这部分是上一张图片描述的算法实现的。
要注意一点是Lambda函数的回调调用,首先要提前声明,然后赋值的时候要把函数名也作为captures。
// ElemType kNNValue(const Point<N>& key, size_t k) const
// Usage: cout << kd.kNNValue(v, 3) << endl;
// ----------------------------------------------------
// Given a point v and an integer k, finds the k points in the KDTree
// nearest to v and returns the most common value associated with those
// points. In the event of a tie, one of the most frequent value will be
// chosen.
ElemType kNNValue(const Point<N>& key, size_t k) const {
BoundedPQueue<ElemType> q(k);
double distance;
//定义递归函数
std::function<void(Node*)> recursiveFun;
recursiveFun = [&recursiveFun,&q,&distance,&key](Node *node) {
if(node==nullptr)
return ;
distance = 0;
const Point<N>& currentKey = node->key;
for(size_t i=0;i<N;i++) {
distance += (key[i] - currentKey[i])*(key[i] - currentKey[i]);
}
q.enqueue(node->value,distance);
Node* other;
size_t index = (node->level) % N;
if(key[index] < node->key[index]){
recursiveFun(node->lc);
other = node->rc;
} else {
recursiveFun(node->rc);
other = node->lc;
}
if((!q.full()) || abs(key[index] - node->key[index]) < q.worst())
recursiveFun(other);
};
//开始运行
recursiveFun(root);
// 统计每个ElemType类型变量出现频次
map<ElemType,int> cntMap;
ElemType elem;
while(!q.empty()){
elem = q.dequeueMin();
if(cntMap.find(elem)==cntMap.end()){
cntMap[elem] = 1;
} else {
cntMap[elem] += 1;
}
}
//选出频次最大的
ElemType maxElem;
int Max = 0;
for(auto pair:cntMap) {
if(pair.second > Max){
maxElem = pair.first;
Max = pair.second;
}
}
return maxElem;
}
3、复制构造器和赋值符号
这里我一开始一次性分配全部内存可能会快一点。
Node *copyNodes(Node *buff,Node *node){
if(node == nullptr)
return nullptr;
Node *NodeCopyed = buff + _cnt;
_cnt ++;
NodeCopyed->lc = copyNodes(buff,node->lc);
NodeCopyed->rc = copyNodes(buff,node->rc);
NodeCopyed->key = node->key;
NodeCopyed->level = node->level;
NodeCopyed->value = node->value;
return NodeCopyed;
};
// KDTree(const KDTree& rhs);
// KDTree& operator=(const KDTree& rhs);
// Usage: KDTree<3, int> one = two;
// Usage: one = two;
// -----------------------------------------------------
// Deep-copies the contents of another KDTree into this one.
KDTree(const KDTree& rhs) {
n_elems = rhs.n_elems;
Node *buff = new Node[rhs.n_elems];
_cnt = 0;
root = copyNodes(buff,rhs.root);
}
KDTree& operator=(const KDTree& rhs){
n_elems = rhs.n_elems;
Node *buff = new Node[rhs.n_elems];
_cnt = 0;
root = copyNodes(buff,rhs.root);
return *this;
}