競プロ典型 90 問 - PyPy コード

この記事は

  • 競プロ典型 90 問 のコード置き場です。
  • 基本的に PyPy3 で書いています。
  • コンテスト中は他人の Submission が見れないので何かのお役に立てば。
  • 他の人に見やすいようにとかはあまり考えずに書いてるので、通せないときの参考ぐらいで。
  • 難しめのものを優先的に置いていますが、要望があれば他の問題も置きます(Twitter 等で教えてください)。
  • 解法の解説も需要あれば書くかもです(公式解説が充実しているのでいらない気がしますが)。
  • AC コードのリンクも貼っていますが、コンテスト中は本人しか飛べないので注意(私の確認用に貼っています)。
  • もっときれいなコードが書けたら変えるかもです。
  • 他の方の解法

コード一覧

※ クリックで開きます

005 - Restricted Digits(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:762 Byte
  • 実行時間:480 ms
P = 10 ** 9 + 7
N, B, K = map(int, input().split())
C = [int(a) for a in input().split()]
A = [0] * B
for c in C:
    A[c%B] += 1

def prod(a, b):
    re = [0] * B
    for i, aa in enumerate(a):
        for j, bb in enumerate(b):
            k = (i + j) % B
            re[k] = (re[k] + aa * bb) % P
    return re

def poww(a, n):
    def rotate(_a, _r):
        _re = [0] * B
        for _i, _t in enumerate(_a):
            _re[_i*_r%B] = (_re[_i*_r%B] + _t) % P
        return _re
        
    re = [1] + [0] * (B - 1)
    aa = a[:]
    r = 10
    while n:
        if n % 2:
            re = prod(aa, rotate(re, r))
        aa = prod(aa, rotate(aa, r))
        n //= 2
        r = r * r % B
    return re

print(poww(A, N)[0])

017 - Crossing Segments(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:3310 Byte
  • 実行時間:842 ms
class SegmentTree():
    def __init__(self, init, unitX, f):
        self.f = f # (X, X) -> X
        self.unitX = unitX
        self.f = f
        if type(init) == int:
            self.n = init
            self.n = 1 << (self.n - 1).bit_length()
            self.X = [unitX] * (self.n * 2)
        else:
            self.n = len(init)
            self.n = 1 << (self.n - 1).bit_length()
            self.X = [unitX] * self.n + init + [unitX] * (self.n - len(init))
            for i in range(self.n-1, 0, -1):
                self.X[i] = self.f(self.X[i*2], self.X[i*2|1])
        
    def update(self, i, x):
        i += self.n
        self.X[i] = x
        i >>= 1
        while i:
            self.X[i] = self.f(self.X[i*2], self.X[i*2|1])
            i >>= 1
    
    def add(self, i, x=1):
        i += self.n
        self.X[i] += x
        i >>= 1
        while i:
            self.X[i] = self.f(self.X[i*2], self.X[i*2|1])
            i >>= 1
    
    def getvalue(self, i):
        return self.X[i + self.n]
    
    def getrange(self, l, r):
        l += self.n
        r += self.n
        al = self.unitX
        ar = self.unitX
        while l < r:
            if l & 1:
                al = self.f(al, self.X[l])
                l += 1
            if r & 1:
                r -= 1
                ar = self.f(self.X[r], ar)
            l >>= 1
            r >>= 1
        return self.f(al, ar)
    
    # Find r s.t. calc(l, ..., r-1) = True and calc(l, ..., r) = False
    def max_right(self, l, z):
        if l >= self.n: return self.n
        l += self.n
        s = self.unitX
        while 1:
            while l % 2 == 0:
                l >>= 1
            if not z(self.f(s, self.X[l])):
                while l < self.n:
                    l *= 2
                    if z(self.f(s, self.X[l])):
                        s = self.f(s, self.X[l])
                        l += 1
                return l - self.n
            s = self.f(s, self.X[l])
            l += 1
            if l & -l == l: break
        return self.n
    
    # Find l s.t. calc(l, ..., r-1) = True and calc(l-1, ..., r-1) = False
    def min_left(self, r, z):
        if r <= 0: return 0
        r += self.n
        s = self.unitX
        while 1:
            r -= 1
            while r > 1 and r % 2:
                r >>= 1
            if not z(self.f(self.X[r], s)):
                while r < self.n:
                    r = r * 2 + 1
                    if z(self.f(self.X[r], s)):
                        s = self.f(self.X[r], s)
                        r -= 1
                return r + 1 - self.n
            s = self.f(self.X[r], s)
            if r & -r == r: break
        return 0
    
    def debug(self):
        print("debug")
        print([self.getvalue(i) for i in range(min(self.n, 20))])

import sys
input = lambda: sys.stdin.readline().rstrip()
N, M = map(int, input().split())
f = lambda x, y: x + y
unit = 0
st = SegmentTree(N, unit, f)
X = []
for _ in range(M):
    l, r = map(int, input().split())
    X.append((l - 1, r - 1))

X.sort(key = lambda x: (x[1] << 20) + x[0])
ans = 0
for l, r in X:
    ans += st.getrange(l, r - 1)
    st.add(l, -1)
    st.add(r - 1, 1)
print(ans)

023 - Avoid War(★7)

問題

AC コード(コンテスト中は飛べません)

ネタバレ
公式解法ほぼそのままです。

dict を使って書くとかなりきれいに書ける(コードの calc )んですが、残念ながら TLE してしまいます(おそらく最悪ケースで 10~11s ぐらい)。

list で書き換えると 2.5 秒切るぐらいで通ります。

  • 使用言語:PyPy3 (7.3.0)
  • コード長:2837 Byte
  • 実行時間:2473 ms
# dict だときれいに書けるけど遅い
def calc(H, W, X):
    m = 1 << W
    D = {0: 1}
    for i in range(H):
        for j in range(W):
            mm = m + 7 if 0 < j < W - 1 else 6 if j == 0 else m + 3
            nD = {}
            for a, d in D.items():
                aa = a >> 1
                if aa in nD:
                    nD[aa] += d % P
                else:
                    nD[aa] = d % P
            if not X[i][j]:
                for a, d in D.items():
                    if not a & mm:
                        aa = (a >> 1) ^ m
                        if aa in nD:
                            nD[aa] += d % P
                        else:
                            nD[aa] = d % P
            D = nD
    ans = 0
    for d in D.values():
        ans += d
    return ans % P

# dict は最初だけにして後は list にする
def calc2(H, W, X):
    S = {0}
    T = set()
    for j in range(W):
        m = 1 << j + 1
        nT = set()
        for s in S:
            nT.add(s ^ m)
        S |= T
        T = nT
    S |= T
    nS = set()
    for s in S:
        nS.add(s ^ 1)
    S |= nS
    SS = sorted(S)
    DS = {a: i for i, a in enumerate(SS)}
    M = len(SS)
    
    # 遷移を最初に求めておく
    D = {0: 1}
    Z = [[] for _ in range(W)]
    for j, z in enumerate(Z):
        mm = 6 if j == 0 else 1 ^ (1 << W - 1) ^ (1 << W) if j == W - 1 else 1 ^ (1 << j) ^ (1 << j + 1) ^ (1 << j + 2)
        j1 = (j - 1) % W + 1
        m = 1 << j1
        im = ((1 << W + 1) - 2) ^ m
        for k, a in enumerate(SS):
            na1 = (a & im) ^ m if a & 1 else a & im
            naa1 = DS[na1] if na1 in DS else -1
            if a & mm:
                naa2 = -1
            else:
                na2 = ((a & im) ^ m if a & 1 else a & im) ^ 1
                naa2 = DS[na2] if na2 in DS else -1
            z.append((naa1, naa2))
    
    # あとは list で処理
    Y = [0] * M
    Y[0] = 1
    for i in range(H):
        for j, z in enumerate(Z):
            nY = [0] * M
            if X[i][j]:
                for k, (z1, z2) in enumerate(z):
                    if Y[k]:
                        if z1 >= 0:
                            nY[z1] = (nY[z1] + Y[k]) % P
            else:
                for k, (z1, z2) in enumerate(z):
                    if Y[k]:
                        if z1 >= 0:
                            nY[z1] = (nY[z1] + Y[k]) % P
                        if z2 >= 0:
                            nY[z2] = (nY[z2] + Y[k]) % P
            Y = nY
    ans = 0
    return sum(Y) % P

P = 10 ** 9 + 7
H, W = map(int, input().split())
X = [[1 if a == "#" else 0 for a in input()] for _ in range(H)]

if W <= 10:
    print(calc(H, W, X))
else:
    print(calc2(H, W, X))

025 - Digit Product Equation(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:612 Byte
  • 実行時間:83 ms
def f(n):
    s = 1
    while n:
        s *= n % 10
        n //= 10
    return s

N, B = map(int, input().split())
m = N - B
ans = 1 if f(B) == 0 and B <= N else 0
s2 = 1
for i2 in range(34):
    if s2 > m: break
    s3 = s2
    for i3 in range(23):
        if s3 > m: break
        s5 = s3
        for i5 in range(12):
            if s5 > m: break
            s7 = s5
            for i7 in range(12):
                if s7 > m: break
                if f(B + s7) == s7:
                    ans += 1
                s7 *= 7
            s5 *= 5
        s3 *= 3
    s2 *= 2
print(ans)

029 - Long Bricks(★5)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:9316 Byte
  • 実行時間:1349 ms
class LazySegmentTree():
    def __init__(self, init, unitX, unitA, f, g, h, Z = None):
        self.f = f # (X, X) -> X
        self.g = g # (X, A, size) -> X
        self.h = h # (A, A) -> A
        self.unitX = unitX
        self.unitA = unitA
        self.f = f
        if type(init) == int:
            self.n = init
            self.n = 1 << (self.n - 1).bit_length()
            self.X = [unitX] * (self.n * 2)
            if not Z:
                self.size = [1] * (self.n * 2)
            else:
                self.size = [0] * self.n + [b - a for a, b in zip(Z, Z[1:])]
                self.size += [0] * (self.n * 2 - len(self.size))
        else:
            self.n = len(init)
            self.n = 1 << (self.n - 1).bit_length()
            self.X = [unitX] * self.n + init + [unitX] * (self.n - len(init))
            if not Z:
                self.size = [0] * self.n + [1] * len(init) + [0] * (self.n - len(init))
            else:
                self.size = [0] * self.n + [b - a for a, b in zip(Z, Z[1:])]
                self.size += [0] * (self.n * 2 - len(self.size))
            for i in range(self.n-1, 0, -1):
                self.X[i] = self.f(self.X[i*2], self.X[i*2|1])
    
        for i in range(self.n - 1, 0, -1):
            self.size[i] = self.size[i*2] + self.size[i*2|1]
        
        self.A = [unitA] * (self.n * 2)
        
    def update(self, i, x):
        i += self.n
        self.propagate_above(i)
        self.X[i] = x
        self.A[i] = unitA
        self.calc_above(i)
    
    def calc(self, i):
        return self.g(self.X[i], self.A[i], self.size[i])
    
    def calc_above(self, i):
        i >>= 1
        while i:
            self.X[i] = self.f(self.calc(i*2), self.calc(i*2|1))
            i >>= 1
    
    def propagate(self, i):
        self.X[i] = self.g(self.X[i], self.A[i], self.size[i])
        self.A[i*2] = self.h(self.A[i*2], self.A[i])
        self.A[i*2|1] = self.h(self.A[i*2|1], self.A[i])
        self.A[i] = self.unitA
        
    def propagate_above(self, i):
        H = i.bit_length()
        for h in range(H, 0, -1):
            self.propagate(i >> h)
    
    def propagate_all(self):
        for i in range(1, self.n):
            self.propagate(i)
    
    def getrange(self, l, r):
        l += self.n
        r += self.n
        l0, r0 = l // (l & -l), r // (r & -r) - 1
        self.propagate_above(l0)
        self.propagate_above(r0)
        
        al = self.unitX
        ar = self.unitX
        while l < r:
            if l & 1:
                al = self.f(al, self.calc(l))
                l += 1
            if r & 1:
                r -= 1
                ar = self.f(self.calc(r), ar)
            l >>= 1
            r >>= 1
        return self.f(al, ar)
    
    def getrange_l(self, r):
        if r == self.n: return self.calc(1)
        r += self.n
        r //= r & -r
        r0 = r
        self.propagate_above(r0)
        ar = self.unitX
        while r > 1:
            r -= 1
            ar = self.f(self.calc(r), ar)
            r //= r & -r
        return ar
    
    def getvalue(self, i):
        i += self.n
        self.propagate_above(i)
        return self.calc(i)
    
    def operate_range(self, l, r, a):
        l += self.n
        r += self.n
        l0, r0 = l // (l & -l), r // (r & -r) - 1
        self.propagate_above(l0)
        self.propagate_above(r0)
        while l < r:
            if l & 1:
                self.A[l] = self.h(self.A[l], a)
                l += 1
            if r & 1:
                r -= 1
                self.A[r] = self.h(self.A[r], a)
            l >>= 1
            r >>= 1
        
        self.calc_above(l0)
        self.calc_above(r0)
    
    def operate_range_l(self, r, a):
        if r == self.n:
            self.A[1] = self.h(self.A[1], a)
            return
        r += self.n
        r //= r & -r
        r0 = r - 1
        self.propagate_above(r0)
        while r > 1:
            r -= 1
            self.A[r] = self.h(self.A[r], a)
            r //= r & -r
        
        self.calc_above(r0)
    
    def operate_range_r(self, l, a):
        if l >= self.n: return
        if not l:
            self.A[1] = self.h(self.A[1], a)
            return
        l += self.n
        l //= l & -l
        l0 = l
        self.propagate_above(l0)
        while l > 1:
            self.A[l] = self.h(self.A[l], a)
            l += 1
            l //= l & -l
        self.calc_above(l0)

    def check(self, randX, randA, maxs, rep):
        from random import randrange
        f, g, h = self.f, self.g, self.h
        for _ in range(rep):
            x = randX()
            y = randX()
            z = randX()
            a = randA()
            b = randA()
            c = randA()
            s = randrange(1, maxs + 1)
            t = randrange(1, maxs + 1)
            err = 0
            if not f(x, unitX) == f(unitX, x) == x:
                err = 1
                print("!!!!! unitX Error !!!!!")
                print("unitX =", unitX)
                print("x =", x)
                print("f(x, unitX) =", f(x, unitX))
                print("f(unitX, x) =", f(unitX, x))
            
            if not h(a, unitA) == h(unitA, a) == a:
                err = 1
                print("!!!!! unitA Error !!!!!")
                print("unitA =", unitA)
                print("a =", a)
                print("h(a, unitA) =", h(a, unitA))
                print("h(unitA, a) =", h(unitA, a))
                
            if not f(f(x, y), z) == f(x, f(y, z)):
                err = 1
                print("!!!!! Associativity Error X !!!!!")
                print("x, y, z, f(x, y), f(y, x) =", x, y, z, f(x, y), f(y, x))
                print("f(f(x, y), z) =", f(f(x, y), z))
                print("f(x, f(y, z)) =", f(x, f(y, z)))
            
            if not h(h(a, b), c) == h(a, h(b, c)):
                err = 1
                print("!!!!! Associativity Error A !!!!!")
                print("a, b, c, h(a, b), h(b, c) =", a, b, c, h(a, b), h(b, c))
                print("h(h(a, b), c) =", h(h(a, b), c))
                print("h(a, h(b, c)) =", h(a, h(b, c)))
            
            if not g(x, unitA, s) == x:
                err = 1
                print("!!!!! Identity Error !!!!!")
                print("unitA, x, s =", unitA, x, s)
                print("g(x, unitA, s) =", g(x, unitA, s))
            
            if not g(g(x, a, s), b, s) == g(x, h(a, b), s):
                err = 1
                print("!!!!! Act Error 1 !!!!!")
                print("x, a, b, s, g(x, a, s), h(a, b) =", x, a, b, s, g(x, a, s), h(a, b))
                print("g(g(x, a, s), b, s) =", g(g(x, a, s), b, s))
                print("g(x, h(a, b), s)    =", g(x, h(a, b), s))
            
            if not g(f(x, y), a, s + t) == f(g(x, a, s), g(y, a, t)):
                err = 1
                print("!!!!! Act Error 2 !!!!!")
                print("x, y, a, s, t, f(x, y), g(x, a, s), g(y, a, t) =", x, y, a, s, t, f(x, y), g(x, a, s), g(y, a, t))
                print("g(f(x, y), a, s + t)      =", g(f(x, y), a, s + t))
                print("f(g(x, a, s), g(y, a, t)) =", f(g(x, a, s), g(y, a, t)))
            
            if err:
                break
                assert f(x, unitX) == f(unitX, x) == x
                assert h(a, unitA) == h(unitA, a) == a
                assert f(f(x, y), z) == f(x, f(y, z))
                assert h(h(a, b), c) == h(a, h(b, c))
                assert g(x, unitA, s) == x
                assert g(g(x, a, s), b, s) == g(x, h(a, b), s)
                assert g(f(x, y), a, s + t) == f(g(x, a, s), g(y, a, t))
        else:
            pass
            print("Monoid Check OK!")
    
    def debug1(self):
        print("self.n =", self.n)
        deX = []
        deA = []
        deS = []
        a, b = self.n, self.n * 2
        while b:
            deX.append(self.X[a:b])
            deA.append(self.A[a:b])
            deS.append(self.size[a:b])
            a, b = a//2, a
        print("--- debug ---")
        for d in deX[::-1]:
            print(d)
        print("--- ---")
        for d in deA[::-1]:
            print(d)
        print("--- ---")
        for d in deS[::-1]:
            print(d)
        print("--- ---")
    
    def debug(self, k = 10):
        print("--- debug ---")
        print("point")
        for i in range(min(self.n - 1, k)):
            print(i, self.getvalue(i))
        print("prod")
        for i in range(min(self.n, k)):
            print(i, self.getrange(0, i))
        print("--- ---")
    
    def debug(self):
        print([self.getvalue(i) for i in range(self.n)])

import sys
input = lambda: sys.stdin.readline().rstrip()

g = lambda x, a, s: max(x, a)
f = max
h = max
unitX = 0
unitA = 0
W, N = map(int, input().split())
st = LazySegmentTree(W, unitX, unitA, f, g, h)

ANS = []
for _ in range(N):
    l, r = map(int, input().split())
    l -= 1
    a = st.getrange(l, r)
    st.operate_range(l, r, a + 1)
    ANS.append(str(a + 1))

print("\n".join(ANS))

035 - Preserve Connectivity(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:4370 Byte
  • 実行時間:486 ms
import sys
input = lambda: sys.stdin.readline().rstrip()

class SegmentTree():
    def __init__(self, init, unitX, f):
        self.f = f # (X, X) -> X
        self.unitX = unitX
        self.f = f
        if type(init) == int:
            self.n = init
            self.n = 1 << (self.n - 1).bit_length()
            self.X = [unitX] * (self.n * 2)
        else:
            self.n = len(init)
            self.n = 1 << (self.n - 1).bit_length()
            self.X = [unitX] * self.n + init + [unitX] * (self.n - len(init))
            for i in range(self.n-1, 0, -1):
                self.X[i] = self.f(self.X[i*2], self.X[i*2|1])
        
    def update(self, i, x):
        i += self.n
        self.X[i] = x
        i >>= 1
        while i:
            self.X[i] = self.f(self.X[i*2], self.X[i*2|1])
            i >>= 1
    
    def getvalue(self, i):
        return self.X[i + self.n]
    
    def getrange(self, l, r):
        l += self.n
        r += self.n
        al = self.unitX
        ar = self.unitX
        while l < r:
            if l & 1:
                al = self.f(al, self.X[l])
                l += 1
            if r & 1:
                r -= 1
                ar = self.f(self.X[r], ar)
            l >>= 1
            r >>= 1
        return self.f(al, ar)
    
    # Find r s.t. calc(l, ..., r-1) = True and calc(l, ..., r) = False
    def max_right(self, l, z):
        if l >= self.n: return self.n
        l += self.n
        s = self.unitX
        while 1:
            while l % 2 == 0:
                l >>= 1
            if not z(self.f(s, self.X[l])):
                while l < self.n:
                    l *= 2
                    if z(self.f(s, self.X[l])):
                        s = self.f(s, self.X[l])
                        l += 1
                return l - self.n
            s = self.f(s, self.X[l])
            l += 1
            if l & -l == l: break
        return self.n
    
    # Find l s.t. calc(l, ..., r-1) = True and calc(l-1, ..., r-1) = False
    def min_left(self, r, z):
        if r <= 0: return 0
        r += self.n
        s = self.unitX
        while 1:
            r -= 1
            while r > 1 and r % 2:
                r >>= 1
            if not z(self.f(self.X[r], s)):
                while r < self.n:
                    r = r * 2 + 1
                    if z(self.f(self.X[r], s)):
                        s = self.f(self.X[r], s)
                        r -= 1
                return r + 1 - self.n
            s = self.f(self.X[r], s)
            if r & -r == r: break
        return 0
    
    def debug(self):
        print("debug")
        print([self.getvalue(i) for i in range(min(self.n, 20))])

N = int(input())
X = [[] for i in range(N)]
for i in range(N-1):
    x, y = map(int, input().split())
    x, y = x-1, y-1
    X[x].append(y)
    X[y].append(x)

def EulerTour(n, X, i0):
    # Xは破壊してXとPができる
    P = [-1] * n
    Q = [~i0, i0]
    ct = -1
    ET = []
    ET1 = [0] * n
    ET2 = [0] * n
    DE = [0] * n
    de = -1
    while Q:
        i = Q.pop()
        if i < 0:
            # ↓ 戻りも数字を足す場合はこれを使う
            ct += 1
            # ↓ 戻りもETに入れる場合はこれを使う
            ET.append(P[~i])
            ET2[~i] = ct
            de -= 1
            continue
        if i >= 0:
            ET.append(i)
            ct += 1
            if ET1[i] == 0: ET1[i] = ct
            de += 1
            DE[i] = de
        for a in X[i][::-1]:
            if a != P[i]:
                P[a] = i
                for k in range(len(X[a])):
                    if X[a][k] == i:
                        del X[a][k]
                        break
                Q.append(~a)
                Q.append(a)
    return (ET, ET1, ET2, DE)

ET, ET1, ET2, DE = EulerTour(N, X, 0)
ET.pop()
X = [DE[a] for a in ET] + [1 << 17]

f = min
unit = 1 << 17
st = SegmentTree(X, unit, f)
Q = int(input())
for _ in range(Q):
    k, *V = map(int, input().split())
    V = sorted([ET1[a-1] for a in V])
    a, b = V[0], V[-1]
    ans = X[a] + X[b] - st.getrange(a, b + 1) * 2
    for a, b in zip(V, V[1:]):
        ans += X[a] + X[b] - st.getrange(a, b + 1) * 2
    
    print(ans // 2)

040 - Get More Money(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:2848 Byte
  • 実行時間:348 ms
from collections import deque
class maxflow():
    def __init__(self, n):
        self.pos = []
        self.g = [[] for _ in range(n)]
        self.n = n
        self.cap = {}
   
    def add_edge(self, v, w, cap):
        e1 = [w, 0, cap]
        e2 = [v, e1, 0]
        e1[1] = e2
        self.g[v].append(e1)
        self.g[w].append(e2)
        self.cap[(v, w)] = [cap, e1]
    
    def bfs(self, s, t):
        L = [-1] * self.n
        L[s] = 0
        Q = deque([s])
        while Q:
            v = Q.popleft()
            for w, _, cap in self.g[v]:
                if cap == 0 or L[w] >= 0: continue
                L[w] = L[v] + 1
                if w == t:
                    self.L = L
                    return
                Q.append(w)
        self.L = L

    def dfs(self, s, v, up):
        if v == s: return up
        res = 0
        lv = self.L[v]
        for i in range(self.it[v], len(self.g[v])):
            w, rev, cap = self.g[v][i]
            if lv > self.L[w] and rev[2] > 0:
                d = self.dfs(s, w, min(up - res, rev[2]))
                if d > 0: 
                    self.g[v][i][2] += d
                    rev[2] -= d
                    res += d
                    if res == up: break
            self.it[v] += 1
        return res
    
    def flow(self, s, t, flow_limit = -1):
        if flow_limit < 0: flow_limit = 10 ** 100
        
        flow = 0
        while flow < flow_limit:
            self.bfs(s, t)
            if self.L[t] == -1: break
            self.it = [0] * self.n
            while flow < flow_limit:
                f = self.dfs(s, t, flow_limit - flow)
                if not f: break
                flow += f
        return flow, self.g
    
    def min_cut(self, s):
        visited = [0] * self.n
        Q = [s]
        while Q:
            p = Q.pop()
            visited[p] = 1
            for e in self.g[p]:
                if e[2] and visited[e[0]] == 0:
                    visited[e[0]] = 1
                    Q.append(e[0])
        return visited
    
    def flow_all(self):
        L = []
        for a in self.cap:
            v, w = a
            c, e = self.cap[a]
            L.append((v, w, c - e[2]))
        return L

import sys
input = sys.stdin.readline
inf = 10 ** 10
N, W = map(int, input().split())
A = [int(a) for a in input().split()]
mf = maxflow(N + 2)
s = N
t = s + 1
ans = 0
for i, a in enumerate(A):
    if a > W:
        mf.add_edge(i, t, a - W)
        ans += a - W
    elif W > a:
        mf.add_edge(s, i, W - a)

for i in range(N):
    k, *C = map(int, input().split())
    for j in C:
        mf.add_edge(i, j-1, inf)
    
fl, g = mf.flow(s, t)
# print("ans =", ans)
# print("fl =", fl)
# print("g =", g)
print(ans - fl)

041 - Piles in AtCoder Farm(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:1545 Byte
  • 実行時間:354 ms
def gcd(a, b):
    while b: a, b = b, a % b
    return abs(a)

def graham_scan(X):
    # Assuming 0 <= x, y <= 10^9
    mi = 10 ** 18
    for x, y in X:
        mi = min(mi, (y << 30) + x)
    x0, y0 = mi % (1 << 30), mi >> 30
    SX = sorted([(x - x0, y - y0) for x, y in X if x != x0 or y != y0], key = lambda a: -a[0] / a[1] if a[1] else -10 ** 10)
    RE = [(0, 0)]
    for x, y in SX:
        RE.append((x, y))
        while len(RE) > 2:
            xx, yy = RE[-1]
            px, py = RE[-2]
            ppx, ppy = RE[-3]
            t = (px - ppx) * (yy - ppy) - (py - ppy) * (xx - ppx)
            # print("RE, t =", RE, t)
            if t < 0:
                la = RE.pop()
                RE[-1] = la
            elif t == 0:
                if len(RE) == 3:
                    if abs(RE[-1][0]) + abs(RE[-1][1]) > abs(RE[-2][0]) + abs(RE[-2][1]):
                        la = RE.pop()
                        RE[-1] = la
                    else:
                        RE.pop()
                else:
                    la = RE.pop()
                    RE[-1] = la
            else:
                break
    # return (x0, y0), RE
    
    s = 0
    cc = 0
    RE.append((0, 0))
    for (px, py), (x, y) in zip(RE, RE[1:]):
        s += (y + py) * (px - x)
        cc += gcd(x - px, y - py)
    return (s + cc) // 2 + 1 - N
    # return s, c, cc
    

N = int(input())
X = []
for _ in range(N):
    x, y = map(int, input().split())
    X.append((x, y))

print(graham_scan(X))

045 - Simple Grouping(★6)

問題

AC コード(コンテスト中は飛べません)

ネタバレ
彩色数の考え方を使っています。 こちらの資料 の17ページ以降ぐらいに詳しい解説があります。

  • 使用言語:PyPy3 (7.3.0)
  • コード長:1087 Byte
  • 実行時間:98 ms
N, K = map(int, input().split())
X = []
for _ in range(N):
    x, y = map(int, input().split())
    X.append((x, y))
B = {1 << i: i for i in range(N)}
PCP = [1] * (1 << N)
for i in range(1, 1 << N):
    a = i & -i
    PCP[i] = - PCP[i^a]

S = {0}
for i in range(N):
    x, y = X[i]
    for j in range(i):
        xx, yy = X[j]
        S.add((x - xx) ** 2 + (y - yy) ** 2)
SS = sorted(S)

def chk(t):
    k = SS[t]
    I = [0] * N
    for i in range(N):
        x, y = X[i]
        for j in range(N):
            if i == j: continue
            xx, yy = X[j]
            if (x - xx) ** 2 + (y - yy) ** 2 < k:
                I[i] ^= 1 << j
    Y = [0] * (1 << N)
    Y[0] = 1
    for i in range(1, 1 << N):
        a = i & -i
        if i == a:
            Y[i] = 2
            continue
        Y[i] = Y[i^a] + Y[i&I[B[a]]]
    re = 0
    for y, p in zip(Y, PCP):
        re += p * y ** K
    return 0 if re else 1

l, r = 0, len(SS)
while r - l > 1:
    m = l + r >> 1
    if chk(m):
        l = m
    else:
        r = m
print(SS[l])


047 - Monochromatic Diagonal(★7)

問題

AC コード(コンテスト中は飛べません)

ネタバレ

  • FFT をします。
  • \mod p 上で a^{2}=bc,\ b^{2}=ac,\ c^{2}=ab (★)となるもの( 0\lt a\lt b\lt c\lt p )を探します。
  • "R"、"G"、"B" を a,\ b,\ c にそれぞれ対応させて、 T の方をひっくり返して FFT します。
  • ある場所が斜めすべて "R" になっている場合、該当する項は 個数× a^{2} になります。
  • "G" 、 "B" も同様です。
  • p を十分大きく取れば、多くの場合、逆も成り立ちます。
  • あとは ★ を満たす a,\ b,\ c,\ p を探せばよいです。
  • p を決めると、 a,\ b,\ c は適当にやると求まりそうです(どうせ埋め込むので O(p) かけても良いし、離散対数などでも高速にできます)。
  • ただし ★ から \mod pa^{3} = b^{3} = c^{3} が成立することが必要なので、 p\equiv 1\ ({\mathrm {mod}}\ 3) を満たす p を選ばないといけないことに注意です。
  • (なお逆に p\equiv 1\ ({\mathrm {mod}}\ 3) を満たす素数 p および 0 \lt a \lt p を決めると b,\ c の組は必ず存在して、順番の入れ替えを除いて一意に定まります。)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:2049 Byte
  • 実行時間:1102 ms
p, g, ig = 1012924417, 5, 405169767
W = [pow(g, (p - 1) >> i, p) for i in range(24)]
iW = [pow(ig, (p - 1) >> i, p) for i in range(24)]

def convolve(a, b):
    def fft(f):
        for l in range(k, 0, -1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * W[l] % p)

            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t]) % p, U[j] * (f[s] - f[t]) % p

    def ifft(f):
        for l in range(1, k + 1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * iW[l] % p)

            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t] * U[j]) % p, (f[s] - f[t] * U[j]) % p

    n0 = len(a) + len(b) - 1
    if len(a) < 50 or len(b) < 50:
        ret = [0] * n0
        if len(a) > len(b): a, b = b, a
        for i, aa in enumerate(a):
            for j, bb in enumerate(b):
                ret[i+j] = (ret[i+j] + aa * bb) % p
        return ret
    
    k = (n0).bit_length()
    n = 1 << k
    a = a + [0] * (n - len(a))
    b = b + [0] * (n - len(b))
    fft(a), fft(b)
    for i in range(n):
        a[i] = a[i] * b[i] % p
    ifft(a)
    invn = pow(n, p - 2, p)
    for i in range(n0):
        a[i] = a[i] * invn % p
    del a[n0:]
    return a

P = 1012924417
s = 12357
t = 494461391
u = 518450669
ss = 152695449
tt = 760903725
uu = 99325243
N = int(input())
S = input()
T = input()[::-1]

ans = 0

A = [s if a == "R" else t if a == "G" else u for a in S]
B = [s if a == "R" else t if a == "G" else u for a in T]
C = convolve(A, B)
for i, c in enumerate(C):
    k = i + 1 if i < N else 2 * N - 1 - i
    if c == k * ss % P or c == k * tt % P or c == k * uu % P:
        ans += 1

print(ans)

051 - Typical Shop(★5)

問題

AC コード(コンテスト中は飛べません)

ネタバレ
想定解ほぼそのまま(半分全列挙)だけど、 log を落として O(2^{N/2}) にしています。

  • 使用言語:PyPy3 (7.3.0)
  • コード長:1107 Byte
  • 実行時間:321 ms
def subset_sum_list(L):
    n = len(L)
    D = [0]
    for a in L:
        i = 0
        nD = []
        for d in D:
            while i < len(D) and D[i] < d + a:
                # ↓if 条件を消すと重複込みで列挙(マイナスがあるときは全列挙のみ)
                # if len(nD) == 0 or D[i] > nD[-1]: nD.append(D[i])
                nD.append(D[i])
                i += 1
            nD.append(d + a)
        D = nD
    return D

N, K, P = map(int, input().split())
M = N // 2
A = [int(a) * 50 + 1 for a in input().split()]
B = sorted(A[:M])
C = sorted(A[M:])

SB = subset_sum_list(B)
SC = subset_sum_list(C)
X = [[] for _ in range(51)]
Y = [[] for _ in range(51)]
for s in SB:
    x = s // 50
    c = s % 50
    X[c].append(x)
for s in SC:
    x = s // 50
    c = s % 50
    Y[c].append(x)

ans = 0
for k1 in range(min(K, 20) + 1):
    k2 = K - k1
    x = X[k1][::-1]
    y = [-10 ** 18] + Y[k2]
    mm = len(y)
    j = 0
    for b in x:
        while j < mm - 1 and y[j+1] + b <= P:
            j += 1
        ans += j
print(ans)

053 - Discrete Dowsing(★7)

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:853 Byte
  • 実行時間:118 ms
from math import sqrt
t = (3 - sqrt(5)) / 2
def f(n):
    if X[n] >= 0: return X[n]
    print("?", n)
    X[n] = int(input())
    return X[n]

T = int(input())
for _ in range(T):
    N = int(input())
    X = [-1] * (N + 1)
    if N < 9:
        print("!", max(f(i) for i in range(1, N + 1)))
        continue
    a, d = 0, N + 1
    b = a + int((d - a) * t + 0.5)
    c = d - int((d - a) * t + 0.5)
    fb = f(b)
    fc = f(c)
    
    while d - a > 4:
        if fb < fc:
            a = b
            b = c
            fb = fc
            c = d - int((d - a) * t + 0.5)
            fc = f(c)
        else:
            d = c
            c = b
            fc = fb
            b = a + int((d - a) * t + 0.5)
            fb = f(b)
    
    print("!", max([f(i) for i in range(a + 1, d) if i != b and i != c] + [fb, fc]))

059 - Many Graph Queries(★7)

問題

AC コード(コンテスト中は飛べません)

ネタバレ

  • BIT 演算で高速化
  • 何 bit ごとにやるかで若干速度変わるかも(18 行目の K の設定)
  • 不要なものを飛ばして工夫するともう少し(定数倍)高速化できるはずだけどめんどいのでやってない

  • 使用言語:PyPy3 (7.3.0)
  • コード長:883 Byte
  • 実行時間:1288 ms
import sys
input = lambda: sys.stdin.readline().rstrip()
N, M, Q = map(int, input().split())
E = [[] for _ in range(N)]
F = [[] for _ in range(N)]
for _ in range(M):
    x, y = map(int, input().split())
    x, y = x-1, y-1
    E[x].append(y)
    F[y].append(x)

X = [[] for _ in range(N)]
for i in range(Q):
    a, b = map(int, input().split())
    a, b = a-1, b-1
    X[a].append((b, i))
ANS = [0] * Q
K = (Q + 3) // 4
for i in range((Q + K - 1) // K):
    Z = [0] * N
    s = 0
    for j in range(K):
        if i * K + j < N:
            Z[i*K+j] = 1 << j
    for k in range(i * K, N):
        for l in F[k]:
            Z[k] |= Z[l]
    for j in range(K):
        a = i * K + j
        if a < N:
            for b, ii in X[a]:
                if Z[b] >> j & 1:
                    ANS[ii] = 1
print("\n".join(["Yes" if a else "No" for a in ANS]))

063 - Monochromatic Subgrid(★4)

問題

AC コード(コンテスト中は飛べません)

ネタバレ
ほぼ想定解。ただし bit 演算で一部高速化している。

  • 使用言語:PyPy3 (7.3.0)
  • コード長:761 Byte
  • 実行時間:143 ms
H, W = map(int, input().split())
X = [[int(a) - 1 for a in input().split()] for _ in range(H)]
Z = [[0] * H for _ in range(W)]
for i, z in enumerate(Z):
    for j in range(H - 1, -1, -1):
        t = X[j][i]
        s = 1 << j
        for k in range(j + 1, H):
            if X[k][i] == t:
                s ^= z[k]
                break
        z[j] = s

D = [0] * (H * W)
ans = 0
for i in range(1, 1 << H):
    pc = bin(i).count("1")
    l = (i & -i).bit_length() - 1
    x = X[l]
    ma = 0
    L = []
    for j, z in enumerate(Z):
        if z[l] & i == i:
            a = x[j]
            D[a] += 1
            ma = max(ma, D[a])
            L.append(a)
    ans = max(ans, ma * pc)
    for a in L:
        D[a] -= 1
print(ans)

065 - RGB Balls 2(★7)

問題

AC コード(コンテスト中は飛べません)

ネタバレ
想定解(FFT

  • 使用言語:PyPy3 (7.3.0)
  • コード長:2189 Byte
  • 実行時間:766 ms
p, g, ig = 998244353, 3, 332748118
W = [pow(g, (p - 1) >> i, p) for i in range(24)]
iW = [pow(ig, (p - 1) >> i, p) for i in range(24)]

def convolve(a, b):
    def fft(f):
        for l in range(k, 0, -1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * W[l] % p)

            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t]) % p, U[j] * (f[s] - f[t]) % p

    def ifft(f):
        for l in range(1, k + 1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * iW[l] % p)

            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t] * U[j]) % p, (f[s] - f[t] * U[j]) % p

    n0 = len(a) + len(b) - 1
    if len(a) < 50 or len(b) < 50:
        ret = [0] * n0
        if len(a) > len(b): a, b = b, a
        for i, aa in enumerate(a):
            for j, bb in enumerate(b):
                ret[i+j] = (ret[i+j] + aa * bb) % p
        return ret
    
    k = (n0).bit_length()
    n = 1 << k
    a = a + [0] * (n - len(a))
    b = b + [0] * (n - len(b))
    fft(a), fft(b)
    for i in range(n):
        a[i] = a[i] * b[i] % p
    ifft(a)
    invn = pow(n, p - 2, p)
    for i in range(n0):
        a[i] = a[i] * invn % p
    del a[n0:]
    return a

def pol(a, m):
    L = [0] * (m + 1)
    for i in range(a, m + 1):
        L[i] = C(m, i)
    return L

P = 998244353
nn = 1001001

fa = [1] * (nn+1)
fainv = [1] * (nn+1)
for i in range(nn):
    fa[i+1] = fa[i] * (i+1) % P
fainv[-1] = pow(fa[-1], P-2, P)
for i in range(nn)[::-1]:
    fainv[i] = fainv[i+1] * (i+1) % P

C = lambda a, b: fa[a] * fainv[b] % P * fainv[a-b] % P if 0 <= b <= a else 0

a, b, c, K = map(int, input().split())
X, Y, Z = map(int, input().split())
x, y, z = K - Y, K - Z, K - X
A, B, C = pol(x, a), pol(y, b), pol(z, c)
print(convolve(convolve(A, B), C)[K])

066 - Various Arrays(★5)

問題

AC コード(コンテスト中は飛べません)

※ 下の解法は 発展版 にも対応しています

解法ネタバレ
大きい方から見て BIT (座圧不要)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:1237 Byte
  • 実行時間:66 ms
class BIT():
    def __init__(self, init):
        if type(init) == int:
            self.n = init + 1
            self.X = [0] * self.n
        else:
            self.n = len(init) + 1
            self.X = [0] + init
            for i in range(1, self.n):
                if i + (i & -i) < self.n:
                    self.X[i + (i & -i)] += self.X[i]
            
    def add(self, i, x=1):
        i += 1
        while i < self.n:
            self.X[i] += x
            i += i & (-i)

    def getsum(self, i):
        ret = 0
        while i != 0:
            ret += self.X[i]
            i -= i&(-i)
        return ret

    def getrange(self, l, r):
        return self.getsum(r) - self.getsum(l)

N = int(input())
X = []
for i in range(N):
    l, r = map(int, input().split())
    l -= 1
    X.append((i, 1 / (r - l), l))
    X.append((i, -1 / (r - l), r))

X.sort(key = lambda x: -x[2])

bit0 = BIT(N)
bit1 = BIT(N)
bit2 = BIT(N)

ans = 0
for i, t, v in X:
    c = bit0.getsum(i)
    s = bit1.getsum(i)
    ss = bit2.getsum(i)
    ans += - t * (ss - (s - c * v) * v - c * v * (v - 1) / 2)
    bit0.add(i, t)
    bit1.add(i, t * v)
    bit2.add(i, t * v * (v - 1) / 2)

print(ans)

085 - Multiplication 085(★4)

問題

AC コード(コンテスト中は飛べません)

ネタバレ
O(K^{1/4})素因数分解 して素因数ごとに計算(基本は指数を e として (e+1)(e+2)/2 通りだけど重複を除く)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:2553 Byte
  • 実行時間:63 ms
def gcd(a, b):
    while b: a, b = b, a % b
    return a
def isPrimeMR(n):
    d = n - 1
    d = d // (d & -d)
    L = [2, 7, 61] if n < 1<<32 else [2, 3, 5, 7, 11, 13, 17] if n < 1<<48 else [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
    for a in L:
        t = d
        y = pow(a, t, n)
        if y == 1: continue
        while y != n - 1:
            y = y * y % n
            if y == 1 or t == n - 1: return 0
            t <<= 1
    return 1
def findFactorRho(n):
    m = 1 << n.bit_length() // 8
    for c in range(1, 99):
        f = lambda x: (x * x + c) % n
        y, r, q, g = 2, 1, 1, 1
        while g == 1:
            x = y
            for i in range(r):
                y = f(y)
            k = 0
            while k < r and g == 1:
                ys = y
                for i in range(min(m, r - k)):
                    y = f(y)
                    q = q * abs(x - y) % n
                g = gcd(q, n)
                k += m
            r <<= 1
        if g == n:
            g = 1
            while g == 1:
                ys = f(ys)
                g = gcd(abs(x - ys), n)
        if g < n:
            if isPrimeMR(g): return g
            elif isPrimeMR(n // g): return n // g
            return findFactorRho(g)
def primeFactor(n):
    i = 2
    ret = {}
    rhoFlg = 0
    while i * i <= n:
        k = 0
        while n % i == 0:
            n //= i
            k += 1
        if k: ret[i] = k
        i += i % 2 + (3 if i % 3 == 1 else 1)
        if i == 101 and n >= 2 ** 20:
            while n > 1:
                if isPrimeMR(n):
                    ret[n], n = 1, 1
                else:
                    rhoFlg = 1
                    j = findFactorRho(n)
                    k = 0
                    while n % j == 0:
                        n //= j
                        k += 1
                    ret[j] = k

    if n > 1: ret[n] = 1
    if rhoFlg: ret = {x: ret[x] for x in sorted(ret)}
    return ret
def divisors(N):
    pf = primeFactor(N)
    ret = [1]
    for p in pf:
        ret_prev = ret
        ret = []
        for i in range(pf[p]+1):
            for r in ret_prev:
                ret.append(r * (p ** i))
    return sorted(ret)

K = int(input())
D = primeFactor(K)
E = list(D.values())
aaa = 1
for e in E:
    if e % 3: aaa = 0
aab = 1
for e in E:
    aab *= (e // 2 + 1)
aab -= aaa
abc = 1
for e in E:
    abc *= (e + 1) * (e + 2) // 2
abc -= aaa
abc -= aab * 3
abc //= 6
print(aaa + aab + abc)

000 - Dummy

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:*** Byte
  • 実行時間:*** ms
# だみー

000 - Dummy

問題

AC コード(コンテスト中は飛べません)

  • 使用言語:PyPy3 (7.3.0)
  • コード長:*** Byte
  • 実行時間:*** ms
# だみー

000 - Dummy

問題

AC コード(コンテスト中は飛べません)

ネタバレ
ここに解法を書く

  • 使用言語:PyPy3 (7.3.0)
  • コード長:*** Byte
  • 実行時間:*** ms
# だみー

あとで何か書くかも

あとで何か書くかも