K-D Tree

“纵深浩瀚与无穷,于深空之中窥视世界。”

K-D Tree

全称为 ,用于高效处理多维空间信息的数据结构,当 的时候, 的优势极其明显。

其空间复杂度为线性 ,而单次操作的时间复杂度为

实现

首先, 是一棵 ,二叉搜索树,我们熟知的平衡树就是二叉搜索树。

建树

我们考虑像线段树或者平衡树那样的操作——递归建树,即实现一个函数 build(int l,int r) ,表示一棵由 维结点 组成的一棵 子树。

有以下操作:

  • l==r 时,直接将这个结点复制下来。
  • 选择一个维度 ,并选择一个结点
  • 内第 维小于 的结点划分至 的左子树,其余的划分至右子树。
  • 递归。

那么,使得这棵树平衡的关键就在于我们选择的维度 和划分结点 是否正确了。

首先是维度的选择,我们要求当前选择的维度 中结点分布的差异度最大,对于划分结点而言,我们则要求左右子树的大小差最小。

这里有几种优化。

方差优化(维度)

前面我们已经提到,我们要求差异度最大,而能够体现出差异的那就是方差,所以我们考虑将 结点在 维的方差全部算出来,然后以方差最大的维度划分。


轮换优化(维度)

即设当前我们选择的是第 维,那下一次我们就选择 维来划分。这个有随机的成分在里面,不太保险。


中位数优化(结点)

对于求取划分结点,我们选择第 维的中位数所在结点作为当前子树的根结点。


分析

容易发现,当我们使用了中位数优化之后,整个 的树高被限制在了 内,那整个划分操作一共只有 次。但我们考虑求中位数的时间复杂度,如果使用 std::sort 的话,时间复杂度会升高为 ,这是极其不彳亍的。

但实际上,求中位数的时间复杂度是线性 的,这要求我们明晰 std::sort 的实现过程,这里不明讲,感兴趣的读者可以自行了解,《进阶指北》里也提到过,还好 STL 已经为我们实现了这个功能:调用 <algorithm> 里的 std::nth_element() 即可,使用方式与 std::sort 一致。

关于 std::nth_element()

保证时间复杂度为线性,空间复杂度为你传入的数组范围。

需要注意的是,当你传入区间 之后并调用,它只会保证中间那个一定是该段的中位数,且左边比中位数小右边比中位数大,其余部分是不保证有序的。(且很可能是乱序)所以调用之后的编号顺序与调用之前多半不一样。

这样的话,我们保证了建树操作的时间复杂度为


插入

的操作是在线的,这点优于

因为一个结点会包含 维信息(以及其他信息),所以我们传参一般传编号。实现 insert(int &rt,int id) 函数,其中 rt 表示当前根结点的编号。

我们会记录一个数组 表示当前结点的划分维度,那我们按照建树那样向下直到叶结点即可。


区间查询

这里的区间被重定义为以 为根的子树所包含的一个超长方体内。

对于每一个结点,维护一个 表示其在第 维的值域,这是可以通过 pushup 实现的。然后类似于线段树的查询,我们递归至查询区间即可。

二维单次查询的时间复杂度范围是 ,扩展到 维均摊为

KDT的扩展

容易发现, 的性质与 有很大部分类似,且用法也很类似,这给了 很大的扩展。也正是这样, 也支持区间操作。


删除 / 重构

的删除极其麻烦,所以干脆就不删除了。

我们定义一个阈值 ,如果某一个值(比如删除权值)超过了 ,我们就进行重构,一般而言, 左右(这里是删除结点与全结点的比例)。

一般重构

实现 rebuild(int &rt) ,将以 为根结点的子树压缩成序列,并遍历整棵子树,将结点编号序列存储到一个数组 中,然后执行 build(1,tot) ,这里的 是剩余结点的个数。


删除重构

每一次删除结点,就按照插入那样找到那个结点,然后打上标记 ,然后当删除权值超过了设定阈值,则进行重构,忽略打上标记的结点即可。


例题

P4148 简单题

带修在线二维区间权值和,是 的板子。 卡掉了二维树状数组,强制在线卡掉了 的空间卡掉了树套树,那就只能用 来解决了。

考虑熟悉操作。不得不说,码量还是有些大。这是替罪羊树式暴力重构 ,还有线段树式和其他方式的,之后提及。

AC Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// ----- Eternally question-----
// Problem: P4148 简单题
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4148
// Memory Limit: 20 MB
// Time Limit: 8000 ms
// Written by: Eternity
// Time: 2022-12-16 17:26:30
// ----- Endless solution-------

#include<bits/stdc++.h>
#define re register
typedef long long ll;
template<class T>
inline void read(T &x)
{
x=0;
char ch=getchar(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
if(t) x=-x;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1){ read(x),read(x1...); }
template<class T>
inline void write(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
template<>
inline void write(bool x){ std::cout<<x; }
template<>
inline void write(char c){ putchar(c); }
template<>
inline void write(char *s){ while(*s!='\0') putchar(*s++); }
template<>
inline void write(const char *s){ while(*s!='\0') putchar(*s++); }
template<class T,class ...T1>
inline void write(T x,T1 ...x1){ write(x),write(x1...); }
template<class T>
inline bool checkMax(T &x,T y){ return x<y?x=y,1:0; }
template<class T>
inline bool checkMin(T &x,T y){ return x>y?x=y,1:0; }
#define ls(p) Tr[p].lc
#define rs(p) Tr[p].rc
const int MAXN=1e5+10,K=2;
const double alpha=0.75;
struct KDT
{
int lc,rc,sum,val,siz,l[K],r[K],Dim[K];
}Tr[MAXN];
int N,lastans,Rt,Len;
int Mdi[K],Qry[K][2],Val;
int Idx,Num,g[MAXN];
inline bool cmp(int a,int b){return Tr[a].Dim[Idx]<Tr[b].Dim[Idx]; }
inline void pushUp(int p)
{
Tr[p].siz=Tr[ls(p)].siz+Tr[rs(p)].siz+1;
Tr[p].sum=Tr[ls(p)].sum+Tr[rs(p)].sum+Tr[p].val;
for(int k=0;k<K;++k)
{
if(ls(p)) checkMax(Tr[p].r[k],Tr[ls(p)].r[k]),checkMin(Tr[p].l[k],Tr[ls(p)].l[k]);
if(rs(p)) checkMax(Tr[p].r[k],Tr[rs(p)].r[k]),checkMin(Tr[p].l[k],Tr[rs(p)].l[k]);
}
}
void build(int &rt,int l,int r,int k)
{
if(l>r) return ;
int mid=(l+r)>>1;Idx=k;
std::nth_element(g+l,g+mid+1,g+r+1,cmp);
rt=g[mid];
Tr[rt].sum=Tr[rt].val;
for(int i=0;i<K;++i) Tr[rt].l[i]=Tr[rt].r[i]=Tr[rt].Dim[i];
build(ls(rt),l,mid-1,(k+1)%K),
build(rs(rt),mid+1,r,(k+1)%K);
pushUp(rt);
}
void erase(int &x)
{
if(!x) return ;
g[++Num]=x;
erase(ls(x)),erase(rs(x));
x=0;
}
inline void rebuild(int &x,int d)
{
g[Num=1]=++Len;
Tr[Len].siz=1;
for(int k=0;k<K;++k) Tr[Len].Dim[k]=Mdi[k];
Tr[Len].val=Tr[Len].sum=Val;
erase(x),build(x,1,Num,d);
}
void insert(int &x,int k)
{
if(!x)
{
Tr[x=++Len].siz=1,Tr[x].val=Tr[x].sum=Val;
for(int i=0;i<K;++i) Tr[x].r[i]=Tr[x].l[i]=Tr[x].Dim[i]=Mdi[i];
return ;
}
if(Mdi[k]<Tr[x].Dim[k])
{
if(Tr[ls(x)].siz>Tr[x].siz*alpha) rebuild(x,k);
else insert(ls(x),(k+1)%K);
}
else
{
if(Tr[rs(x)].siz>Tr[x].siz*alpha) rebuild(x,k);
else insert(rs(x),(k+1)%K);
}
pushUp(x);
}
inline bool checkRange(int x)
{
if(!x) return 0;
for(int k=0;k<K;++k) if(Tr[x].l[k]<Qry[k][0]||Tr[x].r[k]>Qry[k][1]) return 0;
return 1;
}
inline bool checkPoint(int x)
{
if(!x) return 0;
for(int k=0;k<K;++k) if(Tr[x].Dim[k]<Qry[k][0]||Tr[x].Dim[k]>Qry[k][1]) return 0;
return 1;
}
inline bool check(int x)
{
if(!x) return 0;
for(int k=0;k<K;++k) if(Qry[k][1]<Tr[x].l[k]||Tr[x].r[k]<Qry[k][0]) return 0;
return 1;
}
int query(int x)
{
if(checkRange(x)) return Tr[x].sum;
int res=0;
if(checkPoint(x)) res+=Tr[x].val;
if(check(ls(x))) res+=query(ls(x));
if(check(rs(x))) res+=query(rs(x));
return res;
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(N);
while(true)
{
int opt;read(opt);
if(opt==3) break;
if(opt==1)
{
for(int k=0;k<K;++k) read(Mdi[k]),Mdi[k]^=lastans;
read(Val),Val^=lastans;
insert(Rt,0);
}
else
{
for(int i=0;i<=1;++i)
for(int k=0;k<K;++k)
read(Qry[k][i]),Qry[k][i]^=lastans;
write(lastans=query(Rt),'\n');
}
}
return 0;
}
/*

*/

P1429 平面最近点对(加强版)

并不是这道题的正解做法,因为 本质上还是 的剪枝。因此 过不了加强加强版,这个时候就需要充分发挥人类智慧了。

枚举每个结点,对于每个结点找到不等于该结点且距离最小的点,即可求出答案。每次暴力遍历 上的每个结点的时间复杂度是 的,需要剪枝。我们可以维护一个子树中的所有结点在每一维上的坐标的最小值和最大值。假设当前已经找到的最近点对的距离是 ,如果查询点到子树内所有点都包含在内的长方形的最近距离大于等于 ,则在这个子树内一定没有答案,搜索时不进入这个子树。

AC Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// ----- Eternally question-----
// Problem: P1429 平面最近点对(加强版)
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P1429
// Memory Limit: 256 MB
// Time Limit: 1000 ms
// Written by: Eternity
// Time: 2022-12-19 09:54:53
// ----- Endless solution-------

#include<bits/stdc++.h>
#define re register
typedef long long ll;
template<class T>
inline void read(T &x)
{
x=0;
char ch=getchar(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
if(t) x=-x;
}
template<>
inline void read(double &x){ scanf("%lf",&x); }
template<>
inline void read(float &x){ scanf("%f",&x); }
template<class T,class ...T1>
inline void read(T &x,T1 &...x1){ read(x),read(x1...); }
template<class T>
inline void write(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
template<>
inline void write(bool x){ std::cout<<x; }
template<>
inline void write(char c){ putchar(c); }
template<>
inline void write(char *s){ while(*s!='\0') putchar(*s++); }
template<>
inline void write(const char *s){ while(*s!='\0') putchar(*s++); }
template<class T,class ...T1>
inline void write(T x,T1 ...x1){ write(x),write(x1...); }
template<class T>
inline bool checkMax(T &x,T y){ return x<y?x=y,1:0; }
template<class T>
inline bool checkMin(T &x,T y){ return x>y?x=y,1:0; }
#define ls(p) Tr[p].lc
#define rs(p) Tr[p].rc
const int MAXN=2e5+10;
const ll INF=1e18;
int N,Ct[MAXN];
double ans=INF;
struct KDT
{
int lc,rc;
double x,y,l,r,d,u;
}Tr[MAXN];
inline double getDist(int i,int j){ return (Tr[i].x-Tr[j].x)*(Tr[i].x-Tr[j].x)+(Tr[i].y-Tr[j].y)*(Tr[i].y-Tr[j].y); }
inline bool cmpX(KDT a,KDT b){ return a.x<b.x; }
inline bool cmpY(KDT a,KDT b){ return a.y<b.y; }
inline void pushUp(int p)
{
Tr[p].l=Tr[p].r=Tr[p].x;
Tr[p].u=Tr[p].d=Tr[p].y;
if(ls(p)) checkMin(Tr[p].l,Tr[ls(p)].l),checkMax(Tr[p].r,Tr[ls(p)].r),checkMin(Tr[p].d,Tr[ls(p)].d),checkMax(Tr[p].u,Tr[ls(p)].u);
if(rs(p)) checkMin(Tr[p].l,Tr[rs(p)].l),checkMax(Tr[p].r,Tr[rs(p)].r),checkMin(Tr[p].d,Tr[rs(p)].d),checkMax(Tr[p].u,Tr[rs(p)].u);
}
int build(int l,int r)
{
if(l>r) return 0;
if(l==r) return pushUp(l),l;
int mid=(l+r)>>1;
double avx=0,avy=0,svx=0,svy=0;
for(int i=l;i<=r;++i) avx+=Tr[i].l,avy+=Tr[i].r;
avx/=1.0*(r-l+1),avy/=1.0*(r-l+1);
for(int i=l;i<=r;++i) svx+=(Tr[i].x-avx)*(Tr[i].x-avx),svy+=(Tr[i].y-avy)*(Tr[i].y-avy);
if(svx>=svy) Ct[mid]=1,std::nth_element(Tr+l,Tr+mid,Tr+r+1,cmpX);
else Ct[mid]=2,std::nth_element(Tr+l,Tr+mid,Tr+r+1,cmpY);
ls(mid)=build(l,mid-1),rs(mid)=build(mid+1,r);
return pushUp(mid),mid;
}
inline double otDist(int s,int id)
{
double ret=0;
if(Tr[id].l>Tr[s].x) ret+=(Tr[id].l-Tr[s].x)*(Tr[id].l-Tr[s].x);
if(Tr[id].r<Tr[s].x) ret+=(Tr[id].r-Tr[s].x)*(Tr[id].r-Tr[s].x);
if(Tr[id].d>Tr[s].y) ret+=(Tr[id].d-Tr[s].y)*(Tr[id].d-Tr[s].y);
if(Tr[id].u<Tr[s].y) ret+=(Tr[id].u-Tr[s].y)*(Tr[id].u-Tr[s].y);
return ret;
}
void query(int l,int r,int x)
{
if(l>r) return ;
int mid=(l+r)>>1;
if(mid!=x) checkMin(ans,getDist(x,mid));
if(l==r) return ;
double dl=otDist(x,ls(mid)),dr=otDist(x,rs(mid));
if(dl<ans&&dr<ans)
{
if(dl<dr)
{
query(l,mid-1,x);
if(dr<ans) query(mid+1,r,x);
}
else
{
query(mid+1,r,x);
if(dl<ans) query(l,mid-1,x);
}
}
else
{
if(dl<ans) query(l,mid-1,x);
if(dr<ans) query(mid+1,r,x);
}
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(N);
for(int i=1;i<=N;++i) read(Tr[i].x,Tr[i].y);
build(1,N);
for(int i=1;i<=N;++i) query(1,N,i);
printf("%.4lf",std::sqrt(ans));
return 0;
}
/*

*/

参考文献