并查集 + Tarjan 算法

并查集是一种用于找出一个森林(图)中树(连通分支)的个数的算法,也可用于判断两个节点是否在同一棵树上。它在每一棵树(连通分支)上选择一个节点作为本棵树(连通分支)的代表。对于给定两个节点,如果他们具有相同的代表节点,则说明两个节点在同一个节点上。

一、并查集的简单应用

例题 1:城市群的数量

题目描述:

魔法大陆上有 n 个城市,编号为 1 到 n。城市与城市之间的道路均为双向道路,共有 m 条双向道路,并非任意两个城市之间都有双向道路。问,魔法大陆上有多少个城市群?
若两个城市之间存在一条双向道路,则两个城市属于同一个城市群。任意两个城市之间最多只有一条双向道路。

输入格式:

第一行包含两个整数 n,m,含义与问题描述中相同。接下来 m 行,每行包含两个整数 u,v,表示城市 u 和城市 v 之间存在一条双向道路。

输出格式:

输出共一行,包含一个整数,表示城市群的数量。

代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# DFS 暴力做法
from sys import setrecursionlimit
setrecursionlimit(1000000)
def dfs(cur):
# traversal from the city whose index is cur
global n,g,visited
visited[cur] = True
# search next city
for i in g[cur]:
if visited[i]:
continue
dfs(i)

def counter():
# traversal all if the city,everytime it uses the dfs() function,add one to the answer
global n,visited
ans = 0

for i in range(1,n+1):
if visited[i]:
continue
# hasn't been visited yet
ans += 1
# mark all of the cities that belong to the same group
dfs(i)
return ans

# main part
n,m = map(int,input().split())
# create a list to record the roads
g = [[] for i in range(n+1)]
for _ in range(m):
u,v = map(int,input().split())
g[u].append(v)
g[v].append(u)
# create a visited list to record whether the city has been visited or not
visited = [False for i in range(n+1)]
visited[0] = True
print(counter())
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 并查集模板题
# 找出x所在树的根节点
def find(x):
if pre[x] != x:
return find(pre[x])
return x
# 判断两个节点是否在同一个树(城市群)上,如果不在,则合并两个城市群并计数
def join(x,y):
global n
x_root = find(x) # 找出x所在树的根节点
y_root = find(y) # 找出y所在树的根节点
if x_root != y_root:
# 两个节点不在同一个树上(城市群),将两个城市群合并,以后再碰到这两个树的节点就不会重复计数了,保证每一颗树只计数一次
pre[x_root] = y_root # 将x_root变成y_root的子节点,合并两树
# 初始时有n个节点,彼此没有路径关系,视为n个城市群
# 随着道路关系的引入,城市群不断合并,n就是城市群的数量
n -= 1

# 主程序
n,m = map(int, input().split())
# 注意序号从1开始
pre = [i for i in range(n+1)]
for _ in range(m):
u,v = map(int,input().split())
join(u,v)
print(n)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 优化后的并查集,在找到x的根节点后直接将x的前驱节点改为根节点,缩短x的子节点查找根节点的路径长度
n, m = map(int, input().split())
p = list(range(n + 1))
def find_root(x):
if p[x] == x:
return x
p[x] = find_root(p[x])
return pre[x]
for i in range(m):
u, v = map(int, input().split())
u_root = find_root(u)
v_root = find_root(v)
if u_root != v_root:
p[u_root] = v_root
n-=1
print(n)

例题 2:修改数组(第10届蓝桥杯省赛真题)

题目描述:

给定一个长度为 NN 的数组 A=[A1,A2,,AN]A = [A_1, A_2, · · ·, A_N],数组中有可能有重复出现的整数。
现在小明要按以下方法将其修改为没有重复整数的数组。小明会依次修改 A2,A3,,ANA_2, A_3, · · ·, A_N
当修改 AiA_i 时,小明会检查 AiA_i 是否在 A1Ai1A_1 ∼ A_{i−1} 中出现过。如果出现过,则小明会给 AiA_i 加上 1 ;如果新的 AiA_i 仍在之前出现过,小明会持续给 AiA_i 加 1 ,直到 AiA_i 没有在 A1Ai1A_1 ∼ A_{i-1} 中出现过。
ANA_N 也经过上述修改之后,显然 AA 数组中就没有重复的整数了。现在给定初始的 AA 数组,请你计算出最终的 AA 数组。

输入格式:

第一行包含一个整数 N。
第二行包含 N 个整数 A1,A2,,ANA_1, A_2, · · ·, A_N

输出格式:

输出 N 个整数,依次是最终的 A1,A2,,ANA_1, A_2, · · ·, A_N

代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 并查集
def find(x):
if x == f[x]:
# 找到还没有出现过的元素
return x
p = x
while p != f[p]:
p = f[p]
f[x] = p
return p

n = int(input())
a = [int(i) for i in input().split()]
f = [i for i in range(1000001)] # 使用a[]中元素的最大值作为并查集数组容量
for i in range(n):
# 更新a[]
a[i] = find(a[i])
# 更新并查集
f[a[i]] = find(a[i]+1)
print(' '.join(list(map(str, a))))

二、Tarjan 算法

(1)算法作用

TarjanTarjan 算法是DFS序和并查集的结合应用,可以高效地求出树上两点的最近公共祖先(LCALCA),求出的 LCALCA 可以用于求树上两点之间的最短距离、树上差分等问题。

(2)算法思路

  1. 由于 TarjanTarjan 算法是一种离线算法,所以要先将所有的查询操作存储起来,等待一并处理。
  2. 以树的根节点作为入口进行DFS遍历,同时利用并查集维护当前节点的父节点
  3. 在遍历当前节点时,标记当前节点
  4. 先遍历当前节点的所有孩子节点,如果未被访问,DFS这个孩子,然后调用并查集将这个孩子的父节点标记为当前节点
  5. 遍历以当前节点为主节点的所有询问请求,如果当前节点的询问请求的另一个节点已经有标记了,那么这个询问的答案就是另一个节点此时的父节点,记录这个答案

(3)算法模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def find(x):
if x == fa[x]:
return x
p = x
while p != fa[p]:
p = fa[p]
# 合并路径版的并查集
fa[x] = p
return p

def tarjan(x):
visited[x] = True
# 遍历所有子节点
for i in e[x]:
if visited[i]:
continue
tarjan(i)
fa[i] = x
# 检查以x为主元素的查询
for t in query[x]:
if visited[t[0]]:
ans[t[1]] = find(t[0])

n,m,s = map(int,input().split())
e = [[] for _ in range(n+1)] # 存储边的关系
visited = [False]*(n+1)
# 并查集
fa = [i for i in range(n+1)]
# 存储查询,元素类型是二元组
query = [[] for _ in range(n+1)]
# 存储结果
ans = [-1]*m
# 接收输入,构造树
for _ in range(n-1):
x,y = map(int,input().split())
e[x].append(y)
e[y].append(x)
# 接收查询信息
for i in range(m):
x,y = map(int,input().split())
# 在记录每一组查询的同时记录查询的次序
query[x].append((y,i))
query[y].append((x,i))

tarjan(s)
# 输出结果
for i in ans:
if i == -1:
continue
print(i)

扩展应用:计算树上两点之间的最短距离

LCALCA 的最基本应用是求树上两点之间的最短距离,公式如下:

dist(x,y)=deep[x]+deep[y]2deep[LCA(x,y)]dist(x,y) = deep[x] + deep[y] - 2*deep[LCA(x,y)]