Splay 非指针模板(普通平衡树)

踏上平衡树征程的第一步…

前言

Splay 是一种简单且功能丰富的平衡树结构,其算法核心 Splay 操作能维持其均摊复杂度维持在 $O(logn)$ 。

定义

我们将整棵 Splay 定义在结构体中。

并定义结构体node来表示 Splay 的每一个节点。

宏定义e[0].ch[1]为根节点, $1e9 + 10$ 为INF,并在结构体尾取消定义。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
struct Splay {
#define root (e[0].ch[1])
#define inf (1e9+10)
struct node {
int val; // 当前节点存储的值
int cnt; // 当前节点存储的值出现的次数
int siz; // 当前节点包括其左右子树中包含的数的个数(不是节点数!)
int father; // 当前节点的父亲节点
int ch[2]; // 当前节点的孩子节点,ch[0]表示左孩子,ch[1]表示右孩子
}
// your code goes here...
#undef root
#undef inf
}

基础操作

update()操作用于更新当前节点的siz值;

connect()操作用于连接节点;

identify()操作用于确认当前节点是其父亲的左孩子还是右孩子。

1
2
3
4
5
6
7
8
9
10
void update(int x) {
e[x].siz = e[e[x].ch[0]].siz + e[e[x].ch[1]].siz + e[x].cnt;
}
void connect(int x, int f, int son) {
e[x].father = f;
e[f].ch[son] = x;
}
int identify(int x) {
return x == e[e[x].father].ch[0] ? 0 : 1;
}

旋转

平衡树的必备知识。

rotate(x)表示把x节点上旋到其父亲的位置。

1
2
3
4
5
6
7
8
9
10
void rotate(int x) {
int f = e[x].father, fson = identify(x);
int ff = e[f].father, ffson = identify(f);
int y = e[x].ch[fson ^ 1];
connect(y, f, fson);
connect(f, x, fson ^ 1);
connect(x, ff, ffson);
update(f);
update(x);
}

Splay

Splay 是 Splay 的核心操作。用于把一个节点旋转到指定位置。

需要注意的是,Splay在每次完成查询操作后都要将被查询的节点 Splay 到根。

1
2
3
4
5
6
7
8
9
10
void splay(int at, int to) {
if (!at) return;
to = e[to].father;
while (e[at].father != to) {
int up = e[at].father;
if (e[up].father == to) rotate(at);
else if (identify(at) == identify(up)) rotate(up), rotate(at);
else rotate(at), rotate(at);
}
}

新建 / 擦除节点

为 插入 / 删除 操作提供铺垫。

1
2
3
4
5
6
7
8
9
10
11
12
13
void crepoint(int val, int father) {
int x = ++pos;
e[x].val = val;
e[x].father = father;
e[x].cnt = e[x].siz = 1;
e[x].ch[0] = e[x].ch[1] = 0;
}
void delpoint(int x) {
e[x].val = 0;
e[x].father = 0;
e[x].cnt = e[x].siz = 0;
e[x].ch[0] = e[x].ch[1] = 0;
}

插入

如果是空节点需要特判;

如果可以在树中找到一个值相同的节点那么直接使其cnt++

其余情况根据平衡树的性质找到一个可行位置并插入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void insert(int val) {
int u = root;
points++;
if (points == 1) { // 特判无点状态(看个人写法?)
crepoint(val, 0);
root = pos;
return;
}
while (u) {
e[u].siz++;
if (e[u].val == val) {
e[u].cnt++;
splay(u, root);
return;
}
int son = val < e[u].val ? 0 : 1;
if (!e[u].ch[son]) {
crepoint(val, u);
e[u].ch[son] = pos;
splay(pos, root);
return;
}
u = e[u].ch[son];
}
}

删除

首先将要删除的节点旋转到根节点的位置。

如果要被删除的节点(注意现在它在根的位置)没有左孩子,那么直接摧毁这个节点,并将它的右孩子变成根。

如果自己有左孩子,那么就先把左子树中值最大的元素旋转到根的左孩子位置,然后将根节点的右孩子变成根节点的左孩子的右孩子,然后摧毁节点,并将左孩子变成根。

这样子做是为了使删除节点后的树维持平衡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void erase(int val) {
int u = find(val);
points--;
if (e[u].cnt > 1) {
e[u].cnt--;
e[u].siz--;
return;
}
if (!e[u].ch[0]) {
connect(e[u].ch[1], 0, 1);
root = e[u].ch[1];
} else {
int lft = e[u].ch[0], rit = e[u].ch[1];
while (e[lft].ch[1]) lft = e[lft].ch[1];
splay(lft, e[u].ch[0]);
connect(rit, lft, 1);
connect(lft, 0, 1);
update(lft);
}
delpoint(u);
}

排名

  1. rank()查询 x 数的排名(定义为比当前数小的数的个数 +1 )
  2. atrank()查询排名为 x 的数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
int rank(int val) {
int u = root, ans = 0;
while (u) {
if (val == e[u].val) {
ans += e[e[u].ch[0]].siz + 1;
splay(u, root);
return ans;
}
if (val < e[u].val) u = e[u].ch[0];
else ans += e[e[u].ch[0]].siz + e[u].cnt, u = e[u].ch[1];
}
}
int atrank(int x) {
int u = root;
while (u) {
if (x <= e[e[u].ch[0]].siz) u = e[u].ch[0];
else if (x <= e[e[u].ch[0]].siz + e[u].cnt) {
splay(u, root);
return e[u].val;
} else x -= e[e[u].ch[0]].siz + e[u].cnt, u = e[u].ch[1];
}
}

前驱 & 后继

根据平衡树的性质即可。

需要注意的是相等时仍需要继续查找,那么等号的用法就特别讲究。

此份代码中将其特别突出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int lower(int val) {
int u = root, result = -inf, cho = 0;
while (u) {
if (e[u].val < val && e[u].val > result) result = e[u].val, cho = u;
u = e[u].ch[val <= e[u].val ? 0 : 1];
}
splay(cho, root);
return result;
}
int upper(int val) {
int u = root, result = inf, cho = 0;
while (u) {
if (e[u].val > val && e[u].val < result) result = e[u].val, cho = u;
u = e[u].ch[val >= e[u].val ? 1 : 0];
}
splay(cho, root);
return result;
}

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// ==============================
// author: memset0
// website: https://memset0.cn
// ==============================
#include <bits/stdc++.h>
#define ll long long
using namespace std;

int read() {
int x = 0; bool m = 0; char c = getchar();
while (!isdigit(c) && c != '-') c = getchar();
if (c == '-') m = 1, c = getchar();
while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
if (m) return -x; else return x;
}

const int maxn = 100010;
int n, opt;

struct Splay {
#define root (e[0].ch[1])
#define inf (1e9 + 10)
struct node {
int cnt, siz, val;
int father, ch[2];
} e[maxn];
int pos, points;
void update(int x) {
e[x].siz = e[e[x].ch[0]].siz + e[e[x].ch[1]].siz + e[x].cnt;
}
void connect(int x, int f, int son) {
e[x].father = f;
e[f].ch[son] = x;
}
int identify(int x) {
return x == e[e[x].father].ch[0] ? 0 : 1;
}
void rotate(int x) {
int f = e[x].father, fson = identify(x);
int ff = e[f].father, ffson = identify(f);
int y = e[x].ch[fson ^ 1];
connect(y, f, fson);
connect(f, x, fson ^ 1);
connect(x, ff, ffson);
update(f);
update(x);
}
void splay(int at, int to) {
if (!at) return;
to = e[to].father;
while (e[at].father != to) {
int up = e[at].father;
if (e[up].father == to) rotate(at);
else if (identify(at) == identify(up)) rotate(up), rotate(at);
else rotate(at), rotate(at);
}
}
void crepoint(int val, int father) {
int x = ++pos;
e[x].val = val;
e[x].father = father;
e[x].cnt = e[x].siz = 1;
e[x].ch[0] = e[x].ch[1] = 0;
}
void delpoint(int x) {
e[x].val = 0;
e[x].father = 0;
e[x].cnt = e[x].siz = 0;
e[x].ch[0] = e[x].ch[1] = 0;
}
int find(int val) {
int u = root;
while (u) {
if (val == e[u].val) {
splay(u, root);
return u;
}
u = e[u].ch[val < e[u].val ? 0 : 1];
}
}
void insert(int val) {
int u = root;
points++;
if (points == 1) {
crepoint(val, 0);
root = pos;
return;
}
while (u) {
e[u].siz++;
if (e[u].val == val) {
e[u].cnt++;
splay(u, root);
return;
}
int son = val < e[u].val ? 0 : 1;
if (!e[u].ch[son]) {
crepoint(val, u);
e[u].ch[son] = pos;
splay(pos, root);
return;
}
u = e[u].ch[son];
}
}
void erase(int val) {
int u = find(val);
points--;
if (e[u].cnt > 1) {
e[u].cnt--;
e[u].siz--;
return;
}
if (!e[u].ch[0]) {
connect(e[u].ch[1], 0, 1);
root = e[u].ch[1];
} else {
int lft = e[u].ch[0], rit = e[u].ch[1];
while (e[lft].ch[1]) lft = e[lft].ch[1];
splay(lft, e[u].ch[0]);
connect(rit, lft, 1);
connect(lft, 0, 1);
update(lft);
}
delpoint(u);
}
int rank(int val) {
int u = root, ans = 0;
while (u) {
if (val == e[u].val) {
ans += e[e[u].ch[0]].siz + 1;
splay(u, root);
return ans;
}
if (val < e[u].val) u = e[u].ch[0];
else ans += e[e[u].ch[0]].siz + e[u].cnt, u = e[u].ch[1];
}
}
int atrank(int x) {
int u = root;
while (u) {
if (x <= e[e[u].ch[0]].siz) u = e[u].ch[0];
else if (x <= e[e[u].ch[0]].siz + e[u].cnt) {
splay(u, root);
return e[u].val;
} else x -= e[e[u].ch[0]].siz + e[u].cnt, u = e[u].ch[1];
}
}
int lower(int val) {
int u = root, result = -inf, cho = 0;
while (u) {
if (e[u].val < val && e[u].val > result) result = e[u].val, cho = u;
u = e[u].ch[val <= e[u].val ? 0 : 1];
}
splay(cho, root);
return result;
}
int upper(int val) {
int u = root, result = inf, cho = 0;
while (u) {
if (e[u].val > val && e[u].val < result) result = e[u].val, cho = u;
u = e[u].ch[val >= e[u].val ? 1 : 0];
}
splay(cho, root);
return result;
}
#undef root
#undef inf
} s;

int main() {

n = read();
while (n--) {
opt = read();
switch(opt) {
case 1:
s.insert(read());
break;
case 2:
s.erase(read());
break;
case 3:
printf("%d\n", s.rank(read()));
break;
case 4:
printf("%d\n", s.atrank(read()));
break;
case 5:
printf("%d\n", s.lower(read()));
break;
case 6:
printf("%d\n", s.upper(read()));
break;
}
}

return 0;
}

备注 & 参考资料

本文只是模板向的 Splay 教程,请在理解 Splay 后查看。

题目链接:

参考资料:

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×