树状数组 简介 树状数组和线段树具有相似的功能,但他俩毕竟还有一些区别:树状数组能有的操作,线段树一定有;线段树有的操作,树状数组不一定有。但是树状数组的代码要比线段树短,思维更清晰,速度也更快,在解决一些单点修改的问题时,树状数组是不二之选。
原理 下面这张图展示了树状数组的工作原理:
这个结构和线段树有些类似:用一个大节点表示一些小节点的信息,进行查询的时候只需要查询一些大节点而不是所有的小节点。
最上面的八个方块就代表数组 。
他们下面的参差不齐的剩下的方块就代表数组 的上级—— 数组。
从图中可以看出: 管理的是 , ; 管理的是 , , , ; 管理的是 , ; 则管理全部 个数。
如果要计算数组 的区间和,比如说要算 ~ 的区间和,可以采用类似倍增的思想:
从 开始往前跳,发现 ( 我也不确定是多少,算起来太麻烦,就意思一下)只管 这个点,那么你就会找 ,发现 管的是 & ;那么你就会直接跳到 , 就会管 ~ 这些数,下次查询从 往前找,以此类推。
用法及操作 那么问题来了,怎么知道 管理的数组 中的哪个区间呢? 这时,我们引入一个函数——lowbit
:
// C++ Version
int lowbit ( int x ) {
// x 的二进制表示中,最低位的 1 的位置。
// lowbit(0b10110000) == 0b00010000
// ~~~^~~~~
// lowbit(0b11100100) == 0b00000100
// ~~~~~^~~
return x & - x ;
}
# Python Version
def lowbit ( x ):
"""
x 的二进制表示中,最低位的 1 的位置。
lowbit(0b10110000) == 0b00010000
~~~^~~~~
lowbit(0b11100100) == 0b00000100
~~~~~^~~
"""
return x & - x
注释说明了 lowbit
的意思,对于 : 发现第一个 以及他后面的 组成的二进制是 对应的十进制是 ,所以 一共管理 个 数组中的元素。
在常见的计算机中,有符号数采用补码表示。在补码表示下,数 x
的相反数 -x = ~x + 1
。
使用 lowbit 函数,我们可以实现很多操作,例如单点修改,将 加上 ,只需要更新 的所有上级:
// C++ Version
void add ( int x , int k ) {
while ( x <= n ) { // 不能越界
c [ x ] = c [ x ] + k ;
x = x + lowbit ( x );
}
}
# Python Version
def add ( x , k ):
while x <= n : # 不能越界
c [ x ] = c [ x ] + k
x = x + lowbit ( x )
前缀求和:
// C++ Version
int getsum ( int x ) { // a[1]..a[x]的和
int ans = 0 ;
while ( x >= 1 ) {
ans = ans + c [ x ];
x = x - lowbit ( x );
}
return ans ;
}
# Python Version
def getsum ( x ): # a[1]..a[x]的和
ans = 0
while x >= 1 :
ans = ans + c [ x ]
x = x - lowbit ( x )
return ans
区间加 & 区间求和 若维护序列 的差分数组 ,此时我们对 的一个前缀 求和,即 ,由差分数组定义得
进行推导
区间和可以用两个前缀和相减得到,因此只需要用两个树状数组分别维护 和 ,就能实现区间求和。
代码如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 // C++ Version
int t1 [ MAXN ], t2 [ MAXN ], n ;
inline int lowbit ( int x ) { return x & ( - x ); }
void add ( int k , int v ) {
int v1 = k * v ;
while ( k <= n ) {
t1 [ k ] += v , t2 [ k ] += v1 ;
k += lowbit ( k );
}
}
int getsum ( int * t , int k ) {
int ret = 0 ;
while ( k ) {
ret += t [ k ];
k -= lowbit ( k );
}
return ret ;
}
void add1 ( int l , int r , int v ) {
add ( l , v ), add ( r + 1 , - v ); // 将区间加差分为两个前缀加
}
long long getsum1 ( int l , int r ) {
return ( r + 1l l ) * getsum ( t1 , r ) - 1l l * l * getsum ( t1 , l - 1 ) -
( getsum ( t2 , r ) - getsum ( t2 , l - 1 ));
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 # Python Version
t1 = [ 0 ] * MAXN , t2 = [ 0 ] * MAXN ; n = 0
def lowbit ( x ):
return x & ( - x )
def add ( k , v ):
v1 = k * v
while k <= n :
t1 [ k ] = t1 [ k ] + v ; t2 [ k ] = t2 [ k ] + v1
k = k + lowbit ( k )
def getsum ( t , k ):
ret = 0
while k :
ret = ret + t [ k ]
k = k - lowbit ( k )
return ret
def add1 ( l , r , v ):
add ( l , v )
add ( r + 1 , - v )
def getsum1 ( l , r ):
return ( r ) * getsum ( t1 , r ) - l * getsum ( t1 , l - 1 ) - \
( getsum ( t2 , r ) - getsum ( t2 , l - 1 ))
Tricks 建树:
每一个节点的值是由所有与自己直接相连的儿子的值求和得到的。因此可以倒着考虑贡献,即每次确定完儿子的值后,用自己的值更新自己的直接父亲。
// C++ Version
// O(n)建树
void init () {
for ( int i = 1 ; i <= n ; ++ i ) {
t [ i ] += a [ i ];
int j = i + lowbit ( i );
if ( j <= n ) t [ j ] += t [ i ];
}
}
# Python Version
def init ():
for i in range ( 1 , n + 1 ):
t [ i ] = t [ i ] + a [ i ]
j = i + lowbit ( i )
if j <= n :
t [ j ] = t [ j ] + t [ i ]
查询第 小/大元素。在此处只讨论第 小,第 大问题可以通过简单计算转化为第 小问题。
参考 "可持久化线段树" 章节中,关于求区间第 小的思想。将所有数字看成一个可重集合,即定义数组 表示值为 的元素在整个序列重出现了 次。找第 大就是找到最小的 恰好满足
因此可以想到算法:如果已经找到 满足 ,考虑能不能让 继续增加,使其仍然满足这个条件。找到最大的 后, 就是所要的值。 在树状数组中,节点是根据 2 的幂划分的,每次可以扩大 2 的幂的长度。令 表示当前的 所代表的前缀和,有如下算法找到最大的 :
求出 计算 如果 ,则此时扩展成功,将 累加到 上;否则扩展失败,对 不进行操作 将 减 1,回到步骤 2,直至 为 0 1
2
3
4
5
6
7
8
9
10
11
12
13 // C++ Version
// 权值树状数组查询第k小
int kth ( int k ) {
int cnt = 0 , ret = 0 ;
for ( int i = log2 ( n ); ~ i ; -- i ) { // i 与上文 depth 含义相同
ret += 1 << i ; // 尝试扩展
if ( ret >= n || cnt + t [ ret ] >= k ) // 如果扩展失败
ret -= 1 << i ;
else
cnt += t [ ret ]; // 扩展成功后 要更新之前求和的值
}
return ret + 1 ;
}
1
2
3
4
5
6
7
8
9
10
11
12 # Python Version
# 权值树状数组查询第 k 小
def kth ( k ):
cnt = 0 ; ret = 0
i = log2 ( n ) # i 与上文 depth 含义相同
while ~ i :
ret = ret + ( 1 << i ) # 尝试扩展
if ret >= n or cnt + t [ ret ] >= k : # 如果扩展失败
ret = ret - ( 1 << i )
else :
cnt = cnt + t [ ret ] # 扩展成功后 要更新之前求和的值
return ret + 1
时间戳优化:
对付多组数据很常见的技巧。如果每次输入新数据时,都暴力清空树状数组,就可能会造成超时。因此使用 标记,存储当前节点上次使用时间(即最近一次是被第几组数据使用)。每次操作时判断这个位置 中的时间和当前时间是否相同,就可以判断这个位置应该是 0 还是数组内的值。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 // C++ Version
// 时间戳优化
int tag [ MAXN ], t [ MAXN ], Tag ;
void reset () { ++ Tag ; }
void add ( int k , int v ) {
while ( k <= n ) {
if ( tag [ k ] != Tag ) t [ k ] = 0 ;
t [ k ] += v , tag [ k ] = Tag ;
k += lowbit ( k );
}
}
int getsum ( int k ) {
int ret = 0 ;
while ( k ) {
if ( tag [ k ] == Tag ) ret += t [ k ];
k -= lowbit ( k );
}
return ret ;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 # Python Version
# 时间戳优化
tag = [ 0 ] * MAXN ; t = [ 0 ] * MAXN ; Tag = 0
def reset ():
Tag = Tag + 1
def add ( k , v ):
while k <= n :
if tag [ k ] != Tag :
t [ k ] = 0
t [ k ] = t [ k ] + v
tag [ k ] = Tag
k = k + lowbit ( k )
def getsum ( k ):
ret = 0
while k :
if tag [ k ] == Tag :
ret = ret + t [ k ]
k = k - lowbit ( k )
return ret
例题 build 本页面最近更新:2022/3/29 23:37:24 ,更新历史 edit 发现错误?想一起完善? 在 GitHub 上编辑此页! people 本页面贡献者:ananbaobeichicun , HeRaNO , Ir1d , ouuan , ranwen , wangdehu , Xeonacid , Zhoier , Chrogeek , corchis-S , countercurrent-time , Enter-tainer , fafafa114 , H-J-Granger , iamtwz , ksyx , mcendu , NachtgeistW , Nemodontcry , sshwy , SukkaW , Suyun514 , Weijun-Lin , Ycrpro copyright 本页面的全部内容在 CC BY-SA 4.0 和 SATA 协议之条款下提供,附加条款亦可能应用