树状数组

拜读了大佬的讲解博文(树状数组(详细分析+应用),看不懂打死我!),写一篇Python版的笔记巩固消化,附带蓝桥杯历年真题作为例题演示

一、作用

用于快速读取列表中 某个区间内所有元素的和 实现单点修改区间查询
若以差分数组作为a[]则可实现 区间修改单点查询 操作,是一个常用技巧

二、时间复杂度

传统方式

  1. 访问某个元素:O(1)O(1)
  2. 获得某区间元素和:O(n)O(n)

树状数组

  1. 访问某个元素:O(logn)O(logn)
  2. 获得某区间元素和:O(logn)O(logn)

三、规则

通过创建一个列表t,记录以二进制划分的区间内元素的和,其中lowbit(x)的位数决定本节点所处的层数,t[x]保存了以x为根的子树中叶节点的值(即区间的元素和)
通过观察,
a数组具有以下性质:

  1. 下标索引从1开始(切记!!!)
  2. 长度为n
    t数组具有以下性质:
  3. t[x]t[x] 节点覆盖的长度是 lowbit(x)lowbit(x)
  4. t[x]t[x] 的父节点是 t[x+lowbit(x)]t[x + lowbit(x)]
  5. 树的深度为 log2n+1log_2n + 1
  6. t[x]t[x] 节点覆盖的区间为 [xlowbit(x)+1,x][x-lowbit(x)+1, x]t[x]t[x] 也即等于 t[x]t[x] 的子节点区间以后到$ a[x]$ 的所有元素之和!
    t[x]i=xlowbit(x)+1xa[i]t[x] \equiv \sum_{i = x-lowbit(x)+1}^x a[i]

四、创建和维护树状数组的三个基本函数

树状数组不是标准库中的数据结构,而是一个通过特殊函数维护和操作的一维数组。要想在题目中使用树状数组,首先需要创建三个操作函数。以下是这三个函数的详解。

(1)取最低二进制位函数 lowbit()

lowbit()函数用于获取一个正整数在二进制表示下最低位的1与其右侧所有的0所构成的二进制数的数值。
例如 12=2b1100,lowbit(12)=2b100=412 = 2'b1100, lowbit(12) = 2'b100 = 4

1
2
3
# 正负x按位与
def lowbit(x):
return (-x)&x

(2)单点修改函数 add()

为了实现树状数组的单点修改操作,需要创建一个函数add()。
由于每一个树上节点的祖先节点的值都包含了该节点的值,所以在修改某一个点的时候需要从叶子节点开始逐级向上递归修改它所有祖先节点的值。
这里就需要根据当前节点的序号 ii 找出其双亲节点的序号,由树状数组的性质可知其双亲结点的序号为 i+lowbit(i)i+lowbit(i)(见规则2

1
2
3
4
5
def add(x,v):
global n # n = len(t)
while x < n:
t[x] += v
x += lowbit(x)

(3)区间查询函数 ckeck()

建立树状数组后,就可以利用其性质进行快速的区间查询了,由 规则4 可推知,区间[1,x]的元素和等于 t[1]++t[xlowbit(x)]+t[x]t[1] + ··· + t[x-lowbit(x)] + t[x],由此可以使用递推求出区间和

1
2
3
4
5
6
7
# 求出区间[1:x+1]内所有元素的和
def check(x):
ans = 0
while x > 0:
ans += t[x]
x -= lowbit(x)
return ans

以上函数无法指定区间的左端点,为了求出指定端点的区间和,可以使用类似于前缀和的方法求出指定区间的和值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 求出区间[x:y]内所有元素的和
def check(left,right):
ans = 0
x = right
# 先使用原方法求出区间[1:right]的区间和
while x > 0:
ans += t[x]
x -= lowbit(x)
# 然后减去区间[1:left]的元素和,即可获得答案
x = left-1
while x > 0:
ans -= t[x]
x -= lowbit(x)
return ans

五、树状数组整体模板

(1)单点修改、区间查询模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def lowbit(x):
return x&(-x)

def add(x,v):
global n,t
while x < n:
t[x] += v
x += lowbit(x)

def check(left, right):
global t
x = right
ans = 0
while x > 0:
ans += t[x]
x -= lowbit(x)
x = left - 1
while x > 0:
ans -= t[x]
x -= lowbit(x)
return ans

# 创建原数组和树状数组
# 注意树状数组的序号从1开始
a = [0] + [int(i) for i in input().split()]
n = len(a)
t = [0]*n
# 初始化树状数组
# 方法和初始化前缀和数组类似,将每一位的元素加到t[]中
for i in range(1,n):
add(i,a[i])
# 查询修改前的区间和
print(check(2,6))
# 修改原数组中某一元素的值(单点修改)
index,value = map(int,input().split())
add(index, value)
# 查询修改后的区间和(区间查询)
print(check(2,6))
# 具体功能(略),按照题目要求编写

(2)区间修改、单点查询模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def lowbit(x):
return x&(-x)
def add(x,v):
global x,t
while x < n:
t[x] += v
def check(left,right):
global n,t
x = right
while x > 0:
ans += t[x]
x -= lowbit(x)
x = left - 1
while x > 0:
ans -= t[x]
x -= lowbit(x)
return
# 初始化原数组和树状数组
a = [0] + [int(i) for i in input().split()]
n = len(a)
t = [0]*(n+1)
d = [0]*(n+1)
# 用树状数组维护差分数组
for i in range(1,n):
d[i] = a[i] - a[i-1]
add(i,d[i])
# 区间修改
l,r,v = map(int,input().split())
# 结合差分数组修改的原理在树状数组上进行单点修改
# 修改d[l],d[r+1]
add(l,v)
add(r+1,-v)
# 单点查询
# 查询原数组第三个元素的值
print(check(3))

六、例题

例题一:异或和(蓝桥杯第14届省赛)

问题描述:

给一棵含有 nn 个结点的有根树,根结点为 11 ,编号为 ii 的点有点权 aia_i i[1,n](i \in [1,n])。现在有两种操作,格式如下:
1xy1 x y :该操作表示将点 xx 的点权改为 yy
2x2 x :该操作表示查询以结点 xx 为根的子树内的所有点的点权的异或和。
现有长度为 mm 的操作序列,请对于每个第二类操作给出正确的结果。

输入格式:

输入的第一行包含两个正整数 n,mn,m 用一个空格分隔。第二行包含 nn 个整数 $a_1, a_2, … ,a_n
,相邻整数之间使用一个空格分隔。接下来 n1n−1 行,每行包含两个正整数 ui,viu_i, v_i,表示结点 uiu_iviv_i之间有一条边。接下来 mm 行,每行包含一个操作。

输出格式:

输出若干行,每行对应一个查询操作的答案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# 求 DFS 序,以便建立树状数组
cnt = 0
def dfs(cur,pre):
# cur 是当前节点的序号,pre是上一个节点的序号
global cnt
cnt += 1
# 记录将当前节点压入栈中的时间戳
seq[cur][0] = cnt
for i in tree[cur]:
if pre != i:
dfs(i,cur)
# 记录将当前元素出栈的时间戳,自此以后的时间戳均与以cur为根节点的树无关
seq[cur][1] = cnt

# 树状数组函数三件套
def lowbit(x):
return x&(-x)

def modify(x,v):
global n
while x <= n:
t[x] ^= v # 计算异或和
x += lowbit(x)

def query(x):
global n
ans = 0
while x > 0:
ans ^= t[x]
x -= lowbit(x)
return ans

# 接收输入,创建数据结构
n,m = map(int,input().split())
# a[]存储每一个点的权值
a = [0] + [int(i) for i in input().split()]

# 用邻接表存储树结构
tree = [[] for i in range(n+1)]
for _ in range(n-1):
u,v = map(int,input().split())
# 注意没说方向,是一个无向边
tree[u].append(v)
tree[v].append(u)

# 创建一个二维数组seq[][]记录DFS序
# 其中seq[i]是有2个元素的列表,两个元素分别是第i个节点入栈和出栈的时间戳
seq = [[0,0] for i in range(n+1)]
dfs(1,0)

# 为DFS序数组创建树状数组,并用a[]的值初始化
t = [0]*(n+1)
for i in range(1,n+1):
modify(seq[i][0], a[i])
for _ in range(m):
instr = [int(i) for i in input().split()]
if instr[0] == 1:
# 修改元素,注意到需要在赋值的同时清除原有元素,所以将原值与新值异或,相当于清除原值
modify(seq[instr[1]][0], a[instr[1]]^instr[2])
# 维护a[],确保其始终保存着每一个节点的当前值
a[instr[1]] = instr[2]
else:
# 输出单点查询结果
print(query(seq[instr[1]][1]) ^ query(seq[instr[1]][0] - 1))

例题 2:逆序对

给定一个数组,求出其中的逆序对数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 树状数组
MAX_SIZE = 1000000
def lowbit(x):
return x&(-x)

def add(x,v):
global MAX_SIZE
while x < MAX_SIZE:
t[x] += v
x += lowbit(x)

def ask(left,right):
ans = 0
x = right
while x > 0:
ans += t[x]
x -= lowbit(x)
x = left - 1
while x > 0:
ans -= t[x]
x -= lowbit(x)
return ans

# 主程序
n = int(input())
a = [int(i) for i in input().split()]
t = [0]*MAX_SIZE
ans = 0
for i in range(n):
ans += ask(a[i]+1,MAX_SIZE-1)
add(a[i],1)
print(ans)

例题 3:压制二元组的总价值

题目链接:压制二元组的总价值(算法赛)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def lowbit(x):
return x&(-x)
def add(tree,x,v):
n = len(tree)
while x < n:
tree[x] += v
x += lowbit(x)
def ask(tree,x):
ans = 0
while x > 0:
ans += tree[x]
x -= lowbit(x)
return ans

# 主程序
n = int(input())
a = [0] + [int(i) for i in input().split()]
b = [0] + [int(i) for i in input().split()]
# 创建一个数组存储每一个a[i]的下标i
a1 = [0]*(n+1)
for i in range(1,n+1):
a1[a[i]] = i
# 创建一个数组存储b[i]在a中的下标
b1 = [0]*(n+1)
for i in range(1,n+1):
b1[i] = a1[b[i]]
# 创建树状数组
t1 = [0]*(n+1) # 计量树状数组
t2 = [0]*(n+1) # 计数树状数组,和求逆序对的思路类似
# 求解过程
ans = 0
for i in range(1,n+1):
add(t1,b1[i],b1[i]) # 记录b数组中下标小于当前元素的元素的和值
add(t2,b1[i],1) # 记录b数组中下标小于当前元素的元素的个数
ans += b1[i]*ask(t2,b1[i]-1) - ask(t1,b1[i]-1)
print(ans)