一、原理

通过伸展操作Splay(x),将每一次查询的节点旋转到根部,以让BST尽可能平衡

二、伸展操作

  • LL型:两次右旋,先x->prt,后x
  • RR型:两次左旋,先x->prt,后x
  • LR型:先左旋,后右旋,均旋转x
  • RL型:先右旋,后左旋,均旋转x

三、删除

  • 如果x为叶节点,直接删除

  • 如果x只有一个子节点,直接删除

  • 如果x有两个子节点,在x的左子树中找到最大的点y,将y伸展到根节点,然后将x的右子树接在y的右子树上

四、代码实现

4.1 伸展

注意更新指针指向其新的父节点

SplayTree* RotateLeft(SplayTree* y) {
SplayTree* x = y->rch, *z = y->prt;
if (z != NULL) {
if (z->lch == y)z->lch = x;
else z->rch = x;
}
y->rch = x->lch; if(x->lch!=NULL)x->lch->prt = y;
x->lch = y; y->prt = x;
x->prt = z;

Update(y);
Update(x);
return x;
}
SplayTree* RotateRight(SplayTree* y) {
SplayTree* x = y->lch, * z = y->prt;
if (z != NULL) {
if (z->lch == y)z->lch = x;
else z->rch = x;
}
y->lch = x->rch; if(x->rch!=NULL)x->rch->prt = y;
x->rch = y; y->prt = x;
x->prt = z;

Update(y);
Update(x);
return x;
}
SplayTree* Splay(SplayTree* x) {
if (x == NULL)return NULL;
while (x->prt != NULL) {//每一步需要找x->prt->prt,x->prt,x
SplayTree* y = x->prt, * z = x->prt->prt;
if (z == NULL) {
if (y->lch == x)x = RotateRight(x->prt);
else x = RotateLeft(x->prt);
}
else {
if (z->lch == y && y->lch == x) {//LL型,右旋,先旋转父节点,再旋转x
x->prt = RotateRight(x->prt->prt);
x = RotateRight(x->prt);
}
else if (z->rch == y && y->rch == x) {//RR型,左旋,先旋转父节点,再旋转x
x->prt = RotateLeft(x->prt->prt);
x = RotateLeft(x->prt);
}
else if (z->lch == y && y->rch == x) {//LR型,左旋、右旋
x = RotateLeft(x->prt);
x = RotateRight(x->prt);
}
else {//RL型,右旋、左旋
x = RotateRight(x->prt);
x = RotateLeft(x->prt);
}
}
}
return x;
}

4.2 插入 = BST + Splay(x)

SplayTree* Insert(SplayTree* root, SplayTree* prt, int key) {
if (root == NULL) {
root = new SplayTree();
root->data = key;
root->lch = NULL;
root->rch = NULL;
root->prt = prt;
root->size = 1;
root->sum = 1;
now = root;
return root;
}
else if (key == root->data) {
root->sum++;
root->size++;
return root;
}
else if (key < root->data) {//插入到左子树
root->lch = Insert(root->lch, root, key);
Update(root);
}
else {//插入到右子树
root->rch = Insert(root->rch, root, key);
Update(root);
}
return root;
}

4.3 删除

SplayTree* Merge(SplayTree* left, SplayTree* right) {
if(left==NULL)return right;
if(right==NULL)return left;
SplayTree* temp = left;
while (temp->rch != NULL)temp = temp->rch;//找到左树中最大的数
temp = Splay(temp);//将左树中的最大值旋转到根节点,此时根节点没有右儿子
temp->rch = right;//直接将右树接在左树的右儿子处即可
right->prt = temp;
Update(temp);
return temp;
}
SplayTree* Delte(SplayTree* root, int key) {
if (root == NULL) return NULL;
while (root != NULL) {
if (key < root->data)root = root->lch;
else if (key > root->data)root = root->rch;
else {//找到该值,进行删除操作
root = Splay(root);//将该数旋转到根节点
root->sum--;
root->size--;
if (root->sum > 0)return root;
//根节点需要被删除,只需要将根节点的左右儿子合并即可
SplayTree* left = root->lch, *right = root->rch;
if(left!=NULL)left->prt = NULL;
if(right!=NULL)right->prt = NULL;
delete root;
root = Merge(left, right);
return root;
}
}
return NULL;
}

4.4 查询某数是第几大 = BST + Splay(x)

//在main函数调用时,使用Splay(x)
int GetRank(SplayTree* root, int key) {
if (root == NULL)return 0;
if (key == root->data)//key == 当前节点的值
return GetSize(root->lch) + 1;
else if (key > root->data)//key > 当前节点的值 ==> 在右子树里面
return GetSize(root->lch) + root->sum + GetRank(root->rch, key);
else //key < 当前节点的值 ==> 在左子树里面
return GetRank(root->lch, key);
}

4.5 第k大数 = BST + Splay(x)

//在main函数调用时,使用Splay(x)
SplayTree* FindByRank(SplayTree* root, int rank) {
if (rank > GetSize(root->lch) + root->sum)//比左子树+自己大 ==> 在右子树里
return FindByRank(root->rch, rank - GetSize(root->lch) - root->sum);
else if (rank > GetSize(root->lch))//比左子树大 ==> 就是自己
return root;
else//比左子树小 ==> 在左子树里
return FindByRank(root->lch, rank);
}

4.6 前驱 = BST + Splay(x)

//在main函数调用时,使用Splay(x)
SplayTree* FindPre(SplayTree* root, int key) {//找比key小的最大的数
SplayTree* ans = NULL;
while (root != NULL) {
if (key <= root->data) {//root>key,在左子树,注意等号成立时依旧需要向下找
root = root->lch;
}
else {//root<key,在右子树
if (ans == NULL || ans->data < root->data)ans = root;
root = root->rch;
}
}
return ans;
}

4.7 后继 = BST + Splay(x)

//在main函数调用时,使用Splay(x)
SplayTree* FindNext(SplayTree* root, int key) {//找比key大的最小的数
SplayTree* ans = NULL;
while (root != NULL) {
if (key < root->data) {//root>key,在左子树
if (ans == NULL || ans->data > root->data)ans = root;
root = root->lch;
}
else {//root<key,在右子树,注意等号成立时依旧需要向下找
root = root->rch;
}
}
return ans;
}

五、总代码

#pragma warning(disable:4996)
//Splay树
#include <iostream>
#include <fstream>
#include <algorithm>
using namespace std;
struct SplayTree {
int data;
int size, sum;
SplayTree* prt, * lch, * rch;
};
SplayTree* root, * now;
int GetSize(SplayTree*);
void Update(SplayTree*);
SplayTree* RotateLeft(SplayTree*);
SplayTree* RotateRight(SplayTree*);
SplayTree* Merge(SplayTree*, SplayTree*);
SplayTree* Splay(SplayTree*);


SplayTree* Insert(SplayTree*, SplayTree*, int);
SplayTree* Delte(SplayTree*, int);
int GetRank(SplayTree*, int);
SplayTree* FindByRank(SplayTree*, int);
SplayTree* FindPre(SplayTree*, int);
SplayTree* FindNext(SplayTree*, int);

int main() {
//freopen("1.in", "r", stdin);
int n = 0, op, x;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d%d", &op, &x);
SplayTree* temp = NULL;
if (op == 1) {
root = Insert(root, NULL, x);
root = Splay(now);
}
else if (op == 2) {
root = Delte(root, x);
}
else if (op == 3) {
printf("%d\n", GetRank(root, x));
}
else if (op == 4) {
temp = FindByRank(root, x);
root = Splay(temp);
}
else if (op == 5) {
temp = FindPre(root, x);
root = Splay(temp);
}
else if (op == 6) {
temp = FindNext(root, x);
root = Splay(temp);
}
if (op == 4 || op == 5 || op == 6)
printf("%d\n", root->data);
}
return 0;
}
int GetSize(SplayTree* root) {
if (root == NULL)return 0;
return root->size;
}
void Update(SplayTree* root) {
if (root == NULL)return;
root->size = GetSize(root->lch) + GetSize(root->rch) + root->sum;
}
SplayTree* RotateLeft(SplayTree* y) {
SplayTree* x = y->rch, * z = y->prt;
if (z != NULL) {
if (z->lch == y)z->lch = x;
else z->rch = x;
}
y->rch = x->lch; if (x->lch != NULL)x->lch->prt = y;
x->lch = y; y->prt = x;
x->prt = z;

Update(y);
Update(x);
return x;
}
SplayTree* RotateRight(SplayTree* y) {
SplayTree* x = y->lch, * z = y->prt;
if (z != NULL) {
if (z->lch == y)z->lch = x;
else z->rch = x;
}
y->lch = x->rch; if (x->rch != NULL)x->rch->prt = y;
x->rch = y; y->prt = x;
x->prt = z;

Update(y);
Update(x);
return x;
}
SplayTree* Splay(SplayTree* x) {
if (x == NULL)return NULL;
while (x->prt != NULL) {//每一步需要找x->prt->prt,x->prt,x
SplayTree* y = x->prt, * z = x->prt->prt;
if (z == NULL) {
if (y->lch == x)x = RotateRight(x->prt);
else x = RotateLeft(x->prt);
}
else {
if (z->lch == y && y->lch == x) {//LL型,右旋,先旋转父节点,再旋转x
x->prt = RotateRight(x->prt->prt);
x = RotateRight(x->prt);
}
else if (z->rch == y && y->rch == x) {//RR型,左旋,先旋转父节点,再旋转x
x->prt = RotateLeft(x->prt->prt);
x = RotateLeft(x->prt);
}
else if (z->lch == y && y->rch == x) {//LR型,左旋、右旋
x = RotateLeft(x->prt);
x = RotateRight(x->prt);
}
else {//RL型,右旋、左旋
x = RotateRight(x->prt);
x = RotateLeft(x->prt);
}
}
}
return x;
}
SplayTree* Merge(SplayTree* left, SplayTree* right) {
if (left == NULL)return right;
if (right == NULL)return left;
SplayTree* temp = left;
while (temp->rch != NULL)temp = temp->rch;//找到左树中最大的数
temp = Splay(temp);
temp->rch = right;
right->prt = temp;
Update(temp);
return temp;
}

SplayTree* Insert(SplayTree* root, SplayTree* prt, int key) {
if (root == NULL) {
root = new SplayTree();
root->data = key;
root->lch = NULL;
root->rch = NULL;
root->prt = prt;
root->size = 1;
root->sum = 1;
now = root;
return root;
}
else if (key == root->data) {
root->sum++;
root->size++;
return root;
}
else if (key < root->data) {//插入到左子树
root->lch = Insert(root->lch, root, key);
Update(root);
}
else {//插入到右子树
root->rch = Insert(root->rch, root, key);
Update(root);
}
return root;
}
SplayTree* Delte(SplayTree* root, int key) {
if (root == NULL) return NULL;
while (root != NULL) {
if (key < root->data)root = root->lch;
else if (key > root->data)root = root->rch;
else {//找到该值,进行删除操作
root = Splay(root);
root->sum--;
root->size--;
if (root->sum > 0)return root;
//根节点需要被删除,只需要将根节点的左右儿子合并即可
SplayTree* left = root->lch, * right = root->rch;
if (left != NULL)left->prt = NULL;
if (right != NULL)right->prt = NULL;
delete root;
root = Merge(left, right);
return root;
}
}
return NULL;
}
int GetRank(SplayTree* root, int key) {
if (root == NULL)return 0;
if (key == root->data)//key == 当前节点的值
return GetSize(root->lch) + 1;
else if (key > root->data)//key > 当前节点的值 ==> 在右子树里面
return GetSize(root->lch) + root->sum + GetRank(root->rch, key);
else //key < 当前节点的值 ==> 在左子树里面
return GetRank(root->lch, key);
}
SplayTree* FindByRank(SplayTree* root, int rank) {
if (rank > GetSize(root->lch) + root->sum)//比左子树+自己大 ==> 在右子树里
return FindByRank(root->rch, rank - GetSize(root->lch) - root->sum);
else if (rank > GetSize(root->lch))//比左子树大 ==> 就是自己
return root;
else//比左子树小 ==> 在左子树里
return FindByRank(root->lch, rank);
}
SplayTree* FindPre(SplayTree* root, int key) {//找比key小的最大的数
SplayTree* ans = NULL;
while (root != NULL) {
if (key <= root->data) {//root>key,在左子树
root = root->lch;
}
else {//root<key,在右子树
if (ans == NULL || ans->data < root->data)ans = root;
root = root->rch;
}
}
return ans;
}
SplayTree* FindNext(SplayTree* root, int key) {//找比key大的最小的数
SplayTree* ans = NULL;
while (root != NULL) {
if (key < root->data) {//root>key,在左子树
if (ans == NULL || ans->data > root->data)ans = root;
root = root->lch;
}
else {//root<key,在右子树
root = root->rch;
}
}
return ans;
}