takuroooのブログ

勉強したこととか

転倒数とBinary Indexed Treeについて考える

前回のAtCoderのコンテストで転倒数というものが出題された。 atcoder.jp 転倒数というものを知らなかったので転倒数についてまとめてみる。またコンテストの問題をBinary Indexed Treeで解いてどういう処理の流れになっているかを整理する。コンテストの問題はBITを使わなくても解けるがBITを使うと高速に転倒数をカウントできるとのこと。(コンテストのテストケースで7倍速度の差があった。)

転倒数とは

ある数列に関して以下の条件を満たす組の数のこと
「自分より右にあって かつ 自分より値が小さい」

例えば、[3 1 2]という数列Aがあったとする。この数列の転倒数は2になる。 数え方は以下の通り。

数列Aの左端から見ていくと、
* 3からみたときの転倒数は、1と2の2つ
* 次に1から見たときの転倒数は一つもないので0
* 最後の2には右側に数字がないので転倒数は0
ということで転倒数は合計2つになる。

表っぽく表現するとこんな感じになる。
一番上の行が数列になっていて、その下の行が転倒数の組み合わせになっている。

f:id:takuroooooo:20190826083631p:plain

問題B - Kleene Inversion

前回のAtCoderの問題「B - Kleene Inversion」では、この数列AをK回繰り返したときの数列Bの転倒数を求めるというものだった。

先の数列A=[3 1 2]を使って表現すると、
K=1: B=[3 1 2]
K=2: B=[3 1 2 3 1 2]
K=3: B=[3 1 2 3 1 2 3 1 2]
となる。(先ほどの例はK=1のときの転倒数を求めたことになる。)

この問題は解説動画の33分あたりにある通り、数列A内で発生する転倒数2つの数列A間で発生する転倒数を数えることで解くことができる。

www.youtube.com

ここで数列A内で発生する転倒数2つの数列A間で発生する転倒数とはどういうものかを整理してみる。

数列A=[3 1 2]内で発生する転倒数

これは数列A=[3 1 2]の中で数えた転倒数の数のことで、最初の例で数えた転倒数のことを指す。

f:id:takuroooooo:20190826083631p:plain

2つの数列A間で発生する転倒数

これはK=2のときに発生する転倒数の中で2つの数列を跨っているもののこと。
数列A=[3 1 2]を2回繰り返したときの数列[3 1 2 3 1 2]の転倒数の合計は以下のように7つになる。

f:id:takuroooooo:20190826085835p:plain

この中で2つの数列を跨っている転倒数は3つある。
(以下の表の青枠の3つ)

f:id:takuroooooo:20190826090757p:plain

他の4つは、数列A=[3 1 2]内で発生する転倒数のこと。
(以下の表の赤枠)

f:id:takuroooooo:20190826091055p:plain

こうやってみると赤枠はKの数だけ発生することが分かる。
また青枠の数はK=2なら1つ、K=3なら3つ、K=4なら6つというように組み合わせの公式で求めることができる。

2つの数列A間で発生する転倒数というのは数列A=[3 1 2]を降順にソートした状態[3 2 1]の転倒数と一致する。

f:id:takuroooooo:20190826093125p:plain

Binary Indexed Treeによる転倒数のカウントについて考えてみる

転倒数はBITを使ってカウントすることができる。
以下のコードはB - Kleene InversionをBITで解いたもの。言語はPython

MOD = 10 ** 9 + 7
 
 
def MAP():
    return list(map(int, input().split()))
 
 
class BinaryIndexedTree:
    def __init__(self, tree_size):
        self.tree_size = tree_size
        self.tree = [0] * (self.tree_size + 1)
 
    def add(self, i, x):
        while 0 < i <= self.tree_size:
            self.tree[i] += x
            i += i & -i
 
    def sum(self, i):
        s = 0
        while 0 < i:
            s += self.tree[i]
            i -= i & -i
        return s
 
 
def count_inversion(A, max_num):
    bit = BinaryIndexedTree(max_num)
    x = 0
    for i, a in enumerate(A, start=1):
        bit.add(a, 1)
        x += i - bit.sum(a)
    return x
 
 
def main():
    N, K = MAP()
    A = MAP()
 
    tento_int = count_inversion(A, 2000)  # Aの内部で発生する転倒数
    tento_ext = count_inversion(sorted(A, reverse=True), 2000)  # AiとAjの間で発生する転倒数
 
    x = tento_int * K
    y = tento_ext * K * (K - 1) // 2
    print((x + y) % MOD)
 
 
if __name__ == "__main__":
    main()

数列[3 1 2]の転倒数を数える場合、上記コードの標準入力には、

3 1
3 1 2

を与える。

BITについては、Binary indexed treeを参照。

BITメソッド

BIT.add(i,x)

BITはadd(i,x)でインデックスiに値xを登録して、sum(i)でインデックス1〜iまでの累積和を計算するもの。 これをうまく使うと、 sum(i)
「自分と自分より左にある かつ 自分以下の値」
をカウントできる。
自分とはインデックスiのことで、例えばsum(3)は数字1と2と3がそれまで何個登場したか(何回add(i,x)されているか)を計算する。

BIT.sum(i)

これまでadd(i,x)で登録した数字の数をtotalとすると、total-sum(i)
「自分と自分より左にある かつ 自分を超えている値」
をカウントできる。
これは転倒数の定義と同じ意味のことを言っているので、BITで転倒数が求まることが分かる。

count_inversion()

コード中のcount_inversion()が実際にどうやってBITで転倒数をカウントしているかを整理する。
整理しやすいようにBITのサイズを4とする。(コードでは2000になっている。)

BITの最初の状態は 以下のように0初期化されている。
インデックスは1始まりで青マスがBITを表現している。 f:id:takuroooooo:20190826093741p:plain

1回目のループ(数列[3 1 2]の3を処理する)

最初のループでi=1a=3が入り、bit.add(a, 1)を実行するとBITクラスのself.tree

f:id:takuroooooo:20190826094303p:plain

となる。
緑マスは更新されたマス。(1が加算された。)
次に x += i - bit.sum(a)
で転倒数を求めている。

bit.sum(a)は、ループ1回目はa=3なのでbit.sum(3)になる。これは図の赤い数字の総和を返すので0+1=1が返る。

これは先にも触れた通り、 「自分と自分より左にある かつ 自分以下の値」 をカウントした結果になる。これは数列[3 1 2]の赤文字の中で3以下の数字をカウントしたものという意味なので答えは1。

i=1なのでx += i - bit.sum(a)xに0を加算している。
iはこれまでbit.add(a, 1)で登録した数字の総数を意味している。
i - bit.sum(a)「自分と自分より左にある かつ 自分を超えている値」を表しているので、0という結果が合っていることが分かる。([3 1 2]赤文字の中で3を超えている数字はない。)

2回目のループ(数列[3 1 2]の1を処理する)

2回目のループでi=2a=1が入り、bit.add(a, 1)を実行するとBITクラスのself.tree

f:id:takuroooooo:20190826102153p:plain

となる。
1回目のループと同様にx += i - bit.sum(a)で転倒数を求めている。

bit.sum(1)は赤い数字の合計なので1が返る。
これは数列[3 1 2]の赤文字の中で1以下の数字をカウントしたものという意味なので答えは1。

i=2なのでx += i - bit.sum(a)xに1を加算している。([3 1 2]赤文字の中で1を超えているものは3だけなので1つ。)

3回目のループ(数列[3 1 2]の2を処理する)

3回目のループでi=3a=2が入り、bit.add(a, 1)を実行するとBITクラスのself.tree

f:id:takuroooooo:20190826103159p:plain

となる。
1,2回目のループと同様にx += i - bit.sum(a)で転倒数を求めている。

bit.sum(2)は赤い数字の合計なので2が返る。
これは数列[3 1 2 ]の赤文字の中で2以下の数字をカウントしたものという意味なので答えは2。

i=3なのでx += i - bit.sum(a)xに1を加算している。([3 1 2 ]赤文字の中で2を超えているもの3だけなので1つ。)

count_inversionが返す値

3回のループでx=2になっている。
これは数列[3 1 2]の転倒数と一致する。

参考リンク