赤帽子のWARush

プログラミング関係のメモを中心としたサイト

4 August 2024

累積和の演算を一般化する (with XORの性質の話)

AtCoder ABC365 E - Xor Sigma Problemにやられてしまい、レートを溶かしました。今回はその解説の行間を埋めつつ、ついでに累積和を一般化して議論してみます。累積和そのものの丁寧な解説を見たいという方は、別の記事や文献を参照してください。

累積和の一般化(理論編)

まずは実装を抜きにして、累積和の一般化について考えましょう。長さ$N$の数列$A = [ a_1, a_2, \cdots , a_N ]$の累積和$P_x \ (1 \leq x \leq N)$は以下のように定義されます。

\[\begin{align*} S_x = \sum_{i=1}^x a_i \end{align*}\]

また、数列$A$の部分和$A_{ij} = a_i + a_{i+1} + \cdots + a_{j-1} \ (1 \leq i < j \leq N)$ は以下のように表されます。

\[\begin{align*} A_{ij} = \sum_{i=1}^{j-1} a_i = P_{j-1} - P_{i-1} \end{align*}\]

このことから、累積和をすべて事前計算しておけば、区間和も$O(1)$で算出できます。

ここまでの内容については既知とし、この性質を、和に限らない二項演算について一般化することを考えます。$A$の可能な要素をすべて含む集合$S$について以下の性質を満たす演算 $\cdot : S \times S \rightarrow S$を考えます。

知っている方も多いと思われますが、このような集合$S$と演算 $\cdot$ の組を、数学的には「半群」と呼びます。です。ここで、$S$の元からなる列$A = {a_1, a_2, \cdots , a_N}$について、その「累積」$P_x \ (1 \leq x \leq N)$を以下のように定義します。

\[\begin{align*} P_x = a_1 \cdot a_2 \cdot a_3 \cdot \ \cdots \ \cdot a_x \end{align*}\]

後述の通り、これは累積和と同じ要領で高速に計算できます。

また、演算$\cdot$に対し、さらに以下の制約を要請した場合を考えます。

演算がこの2つの制約を加えた3つの条件を満たすとき、$S$と$\cdot$の組を「群」と呼びます。このとき、$A$の部分積

\[\begin{align*} A_{ij} = a_i \cdot a_{i+1} \cdot \ \cdots \ \cdot a_{j-1} \end{align*}\]

は、上記の結合法則を用いることで、以下のように計算することができます。

\[\begin{align*} A_{ij} &= a_i \cdot a_{i+1} \cdot \ \cdots \ \cdot a_{j-1} \\ &= (a_{i-1}^{-1} \cdot a_{i-2}^{-1} \cdot \ \cdots \ \cdot a_1^{-1}) \cdot (a_1 \cdot a_2 \cdot a_3 \cdot \ \cdots \ \cdot a_{j-1}) \\ &= (a_1 \cdot a_2 \cdot \ \cdots \ \cdot a_{i-1})^{-1} \cdot (a_1 \cdot a_2 \cdot a_3 \cdot \ \cdots \ \cdot a_{j-1}) \\ &= P_{i-1}^{-1} \cdot P_{j-1} \end{align*}\]

この結果は、逆元の存在が保証される演算であれば、列$A$の「累積」を事前にすべて求めておくことで、任意の区間積を$O(1)$で導出できることを意味しています。逆に言えば、この方法では半群の性質を満たす演算についての区間積を効率的に求めることはできません。これよりもゆるい制約、たとえばモノイドの性質を満たす演算などにおいて区間積の計算を高速に行えるデータ構造に、セグメント木やDisjoint Sparse Tableなどがありますが、本記事の内容からは逸れるためここでは扱いません。

累積を一般化したプログラムコード

通常の累積和と、演算を一般化したコードを比較してみます。例示にはPythonを使いますが、他言語でも本質は同じはず。

一般化前

数列$A$について、前から順に要素を足した値を管理することで、すべての$x$についての累積和 \(S_x = \sum_{i=1}^x a_i\) を$O(N)$で導出可能です。この事前計算をあらかじめ行うことで、以後は$S_x$を$O(1)$で参照可能です。

def prefix_sum(L): #1次元リストの累積和を返す
	ret = [L[0]]
	for i in range(1, len(L)):
		ret.append(ret[i-1]+L[i])
	return ret

出力retは、引数にとった数列$L$の長さをNとしたとき、各$1 \leq x \leq N$について、累積和$S_x$が格納された長さ$N$の数列です。$S_x = S_{x-1} + L_i$を利用して、$S_x$を$x$の小さい方から順番に$O(1)$で導出しています。

一般化後

前述のコードは、いわゆる「足し算」に関するものに限定した実装でした。これを一般の集合と演算についての累積に対応した関数に書き換えると、以下のようになります。

def prefix(L, op): #1次元リストの累積を返す
	ret = [L[0]]
	for i in range(1, len(L)):
		ret.append(op(ret[i-1], L[i]))
	return ret

前述の通り、演算$op$は半群の性質を満たす必要があります。

区間積$A_{ij}$を導出する場合、前述の通り$ret[i-i]^{-1} \cdot ret[j-1]$を求めれば十分です。ただしこれを行う場合、演算$op$は群の性質を満たす必要があります。

例題

アルゴリズムコンテストサイト「AtCoder」上の問題を例に、上述した累積の一般化関数を用いた解答を示します。本当はもう少し簡単な問題がよかったけど見つかりませんでした

ABC125 C - GCD on Blackboard

問題文はこちら

一般化した累積に載る演算として、整数の最大公約数gcdを挙げることができます。一般に、gcdに逆元が存在しない(モノイドの演算としての性質は満たすが、群の演算としての性質は満たさない)ため、これを素直に用いて区間gcdを高速に求めるといった処理は不可能です。

この問題は、列の左右両方から累積GCDを事前計算することで高速に求めることができます。現代競技プログラマーなら、セグメント木などを用いることで思考停止で解く方も多いでしょう。

以下に、ごく簡単な方針の説明を含めた解答例を示します。

解答例
''' 方針
A_iを変更する場合のgcdの最大値は、「A_i以外のすべての要素の最大公約数」に等しい。
よって、A_1からA_{i-1}までの最大公約数と、A_{i+1}からA_Nまでの最大公約数がO(1)でわかればよい。
'''

from math import gcd

def prefix(L, op): #1次元リストの累積を返す
	ret = [L[0]]
	for i in range(1, len(L)):
		ret.append(op(ret[i-1], L[i]))
	return ret

#入力
N = int(input())
A = list(map(int, input().split()))

#Aを逆順にした列
Arev = list(reversed(A))

#それぞれ累積gcdを取る
pA = prefix(A, gcd)
pArev = prefix(Arev, gcd)

#答えを求める
ans = max(pA[-2], pArev[-2]) 
for i in range(N-2):
    ans = max(ans, gcd(pA[i], pArev[-i-3])) #A[i+1]以外の累積gcdで更新

print(ans)

ABC365 E - Xor Sigma Problem

今度は累積xorを考える問題です。正答率から見た集計上の難易度は先程の問題の方がわずかに高いことになっていますが、実際にはこちらの方がかなり難しいと思います。

問題文はこちら

じっくり考えましょう。まず、この問題のキモは「bitごとに分けて考える」というところにあります。制約より$max(X_A) < 2^{27}$なので、0から27までの28ビットを考えれば十分です。$k \ (0 \leq k \leq 27)$ビット目について、区間xorの総和を求め、それに$2^k$をかけた値を求めることで、本問題の答えを導出することができます。

制約的には、この「区間xorの総和」を$O(N)$くらいで求められると望ましいですが、愚直にすべてを計算していると間に合いません。そこで使うのが累積xorです。

与えられた整数列$A$について、「kビット目の累積xorを求める」というアルゴリズムは、以下のように実装可能です。

from operator import xor

#入力
N = int(input())
A = list(map(int, input().split()))

def prefix(L, op): #1次元リストの累積和を返す
	ret = [L[0]]
	for i in range(1, len(L)):
		ret.append(op(ret[i-1], L[i]))
	return ret

for k in range(28):
	B = []
	for a in A:
		B.append((a >> i) & 1) # aのiビット目を列Bに追加
	pB = prefix(B, xor) #Bの累積xorを導出

xorには必ず逆元が存在する(自分自身が逆元となる)ため、演算xorは群の演算としての性質を満たします。このことから、今回求めたい区間積は、kビット目の累積xorの最初にゼロを足した長さ$N+1$の数列を$pB_k$と書き、xor演算を$\oplus$で表現すると、以下のように書き換えられます。

\[\begin{align*} \sum_{k = 0}^{27} \sum_{i = 0}^{N-2} \sum_{j = i+2}^{N} (pB_k[j] \oplus pB_k[i]) \times 2^k \end{align*}\]

ところが、これでもまだ計算量は$O(28\times N^2)$ほどあり、本問題の制約では間に合いません。そこで、$pB_k$の任意の要素が0または1からなることに注目します。1と0からなる数列$L$に含まれる0と1の個数をそれぞれ$c_1$, $c_2$とおくと、数列$L$の任意の異なる二項間のxorの和は、以下のように計算することができます。

\[\begin{align*} \sum_{i = 1}^{c_0+c_1-1} \sum_{j = i+1}^{c_0+c_1} L_i \oplus L_j &= \frac{1}{2} \sum_{i = 1}^{c_0+c_1} \sum_{j = 1}^{c_0+c_1} L_i \oplus L_j \\ &= \frac{1}{2} [ (c_0*c_0)*(0 \oplus 0) + (c_0*c_1)*(0 \oplus 1) \\ & \ \ \ \ \ + (c_1*c_0)*(1 \oplus 0) + (c_1*c_1)*(1 \oplus 1) ] \\ &= c_0*c_1 \end{align*}\]

ここで、1項目の式変換では、$L_i \oplus L_i = 0$を利用しました。この結果は極めて重要で、0と1の個数さえわかれば、それを用いて$L$の任意の二項間のxorの総和を$O(1)$で求めることができます。さて、この結果を、先の問題の区間積に代入してみましょう。$pB_{k}$における0の個数を$c_{k0}$と置くと、

\[\begin{align*} \sum_{k = 0}^{27} \sum_{i = 0}^{N-2} \sum_{j = i+2}^{N} (pB_k[j-1] \oplus pB_k[i-1]) \times 2^k &= \sum_{k = 0}^{27} \left[ \sum_{i = 0}^{N-1} \sum_{j = i+1}^{N} (pB_k[j-1] \oplus pB_k[i-1]) - \sum_{i = 1}^{N} pB_k[i] \right] \times 2^k \\ &= \sum_{k = 0}^{27} \left[ c_{k0}*(N+1-c_{k0}) - \sum_{i = 0}^{N-1} B_k[i] \right] \times 2^k \end{align*}\]

となります。ただし、最後の$B_k$は、累積xorを出す前の、$A$の各要素の$k$ビット目の値を取り出した数列$B$に対応します。 配列内の0の個数は$O(N)$で求められますから、この結果は、本問題が$O(28 \times N)$で求められることを示したものにほかなりません。(記述のルール上は28を定数とみなすべきかもしれませんが、今回は値が大きいのと、その由来が$\log max(A)$であることから、あえて$O$の内部で明示しています。)

最後に、以上の結果を反映した、この問題の答えを求めるコード全体を提示します。

解答例
from operator import xor

#入力
N = int(input())
A = list(map(int, input().split()))

def prefix(L, op): #1次元リストの累積和を返す
    ret = [L[0]]
    for i in range(1, len(L)):
        ret.append(op(ret[i-1], L[i]))
    return ret

ans = 0
for k in range(28): #すべてのビットについて
    B = []
    for a in A:
        B.append((a >> k) & 1) # aのkビット目を列Bに追加
    pB = [0] + prefix(B, xor) #Bの累積xorを導出
    c = pB.count(0)
    ans += (c*(N+1-c)-sum(B))*2**k

print(ans)

xorの性質を丁寧にうまく活用することが求められる問題で、累積和一般化の入門的な問題とは言い難いですが、このようなレベルの問題もシンプルに解けるということで。

(やっぱ最初の問題よりこっちの方がはるかに難しいと思う)

注意

この記事についてのご質問、間違いのご指摘等を歓迎します。Twitter: @E_Z_Marioまでどうぞ。