트리에서 두 노드의 가장 가까운 공통 조상을 보다 효율적으로 구하기 위한 LCA 알고리즘에 대해 관련 문제를 해결하며 알아본다. 

 

1. 가장 가까운 공통 조상(https://www.acmicpc.net/problem/3584)

 가장 기본적인 LCA 문제이다.  노드의 수가 최대 10,000이며 쿼리의 수도 하나 뿐이기 때문에 효율적인 알고리즘을 생각할 필요 없이 간선 정보를 토대로 각 노드의 부모노드를 저장하고 쿼리로 주어진 두 노드로부터 부모노드를 탐색해나가며 각 노드의 모든 조상노드 리스트를 구하고, 두 리스트를 뒤에서부터 비교해나갔을 때 마지막으로 같았던 두 리스트의 값이 가장 가까운 공통 조상 노드가 된다.  이를 구현한 코드는 다음과 같다.

import sys

input = sys.stdin.readline


# 3584 가장 가까운 공통 조상
# 노드의 수가 적고 쿼리도 케이스당 하나 뿐이기 때문에
# 단순한 방식으로도 해결 가능
def sol3584():
    # 케이스별 정답 리스트
    answer = []
    for t in range(int(input())):
        # 노드의 수
        n = int(input())

        # 각 노드의 부모노드
        parent = [0] * (n + 1)

        # 부모노드와 자식노드를 입력받아 부모노드를 채운다
        for _ in range(n - 1):
            a, b = map(int, input().split())
            parent[b] = a

        # 가장 가까운 공통조상을 구할 두 노드
        u, v = map(int, input().split())

        # 두 노드의 조상 노드 리스트를 구한다.
        up, vp = [], []
        while u:
            up.append(u)
            u = parent[u]
        while v:
            vp.append(v)
            v = parent[v]

        # 두 노드의 조상 노드 리스트를 뒤집는다.
        up, vp = up[::-1], vp[::-1]

        # 루트 노드부터 시작하여 두 노드의 조상 노드가 서로 달라질 때 까지 탐색
        # 두 노드의 조상 노드가 마지막으로 같았을 때의 조상 노드가 가장 가까운 공통 조상 노드가 된다.
        res = 0
        for i in range(min(len(up), len(vp))):
            if up[i] == vp[i]:
                res = up[i]
            else:
                break

        # 정답 리스트에 결과 삽입
        answer.append(res)

    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

 

하지만 위의 방법은 당연하게도 노드의 수와 쿼리의 수가 많아지면 쿼리마다 사용할 수 없다.

 

 

이번엔 다른 방식으로 문제를 해결해보자.  절차는 다음과 같다.

 

① 먼저 트리를 탐색하여 모든 노드의 깊이와 부모노드를 구한다.

② 주어진 두 노드의 깊이가 다르다면 보다 깊이 있는 노드가 다른 노드와 같은 높이가 될 때 까지 거슬러 올라온다.

③ 두 노드가 서로 같아질 때 까지 두 노드 모두 부모노드를 타고 거슬러 올라온다.

④ 두 노드가 서로 같아졌을 때, 그 노드가 가장 가까운 조상 노드가 된다.

 

이를 구현한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol3584():
    # 정답 리스트
    answer = []
    
    # 테스트케이스
    for t in range(int(input())):
        # 노드의 수
        n = int(input())
        
        # 그래프 생성, 부모 노드 구하기
        g = [[] for _ in range(n + 1)]
        parent = [0] * (n + 1)
        for _ in range(n-1):
            a, b = map(int, input().split())
            g[a].append(b)
            parent[b] = a
        
        # 루트 노드로부터 bfs탐색으로 각 노드의 깊이를 구한다
        depth = [0] * (n + 1)
        root = [i for i in range(1, n+1) if not parent[i]][0]
        
        q = [root]
        while q:
            nq = []
            for cur in q:
                for c in g[cur]:
                    depth[c] = depth[cur] + 1
                    nq.append(c)
            q = nq
            
        # 주어진 두 노드 u, v
        u, v = map(int, input().split())
        
        # 노드 u의 깊이가 노드 v의 깊이보다 크도록 한다.
        if depth[u] < depth[v]:
            u, v = v, u
            
        # 두 노드의 깊이차가 0이 될 때 까지 깊은 쪽 노드인 u를 끌어올린다
        while depth[u] - depth[v]:
            u = parent[u]
           
        # 두 노드가 서로 같아질 때 까지 두 노드를 끌어올린다
        while u != v:
            u, v = parent[u], parent[v]
          
        # u와 v가 서로 같아지면 u == v == 가장 가까운 공통 조상 노드 가 된다.
        answer.append(u)
       
    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

이 방식 자체로는 보다 많은 노드와 쿼리 수가 주어지는 문제를 해결하기에 딱히 이전의 알고리즘보다 효율적이지 않다.  하지만 이 방식에 sparse table을 적용하면 얘기가 달라진다.  다음 문제에서 이 sparse table에 대해 다룬다.

 

2. 합성함수와 쿼리(https://www.acmicpc.net/problem/17435)

 함수 f(x)의 값들이 주어졌을때, fn(x) 가 f(f(f(...f(x)...))) 와 같이 함수 f를 n중첩한 결과라고 한다.  이 때, 주어진 m개의

n, x 쌍에 대해 fn(x) 를 각각 구하는 문제이다.  LCA 문제는 아니지만 LCA 알고리즘의 핵심인 sparse table을 응용해서

풀 수 있는 문제이다. 절차는 다음과 같다.

 

① 먼저 함수 f의 값을 저장할 2차원 배열을 생성한다.  이 때, f[x][k] 는 함수 f의 k중첩에 x를 넣은 결과이다.  예를들어,

    f[3][4] = f(f(f(f(3)))) 이 된다.

 

② f[x][0] 의 값을 입력받은 f(x)의 값으로 초기화한다.

 

③ f[i][j] 은 f[i][j-1] 을 2^(j-1) 중첩한 값이 되기 때문에, 점화식 f[i][j] = f[f[i][j-1]][j-1] 이 성립한다. 이 점화식에 따라 

    나머지 f[x][1~k] 를 채운다.  여기서 k는 함수 중첩수 최댓값인 n에 로그를 취한 뒤 올림한 값이 된다. 

    => math.ceil(math.log2(n))

 

④ 이제 쿼리로 n, x가 주어지면 ③ 에서 구한 sparse table을 이용하여 빠르게 목표한 값으로 접근한다. 

    n이 0보다 큰 동안, x = f[x][int(log2(n))];  n -= (1 << int(log2(n))) 을 수행하면 가능하다.  n 중첩을

    1씩 좁혀나가는 것이 아니라 2의 승수단위로 좁혀나가는 것으로 탐색시간을 극적으로 줄이는 방식이다.

 

⑤ ④의 작업이 모두 끝났을 때, x의 값이 해당 쿼리에 대한 정답이 된다.

import sys
from math import log2, ceil

input = sys.stdin.readline



def sol17435():
    # 주어진 함수값의 수
    m = int(input())
    
    # 500000 <= 2^k
    k = ceil(log2(500000))
    
    # sparse table :  f[x][k] = fk(x)
    f = [[-1] * k for _ in range(m + 1)]
    fn = list(map(int, input().split()))
    for i in range(m):
        f[i + 1][0] = fn[i]

    # 점화식 f[i][j] = f[f[i][j-1]][j-1] 에 따라 sparse table 나머지 채우기
    for j in range(1, k):
        for i in range(1, m + 1):
            f[i][j] = f[f[i][j - 1]][j - 1]

    # 쿼리의 수
    q = int(input())
    
    # 각 쿼리의 정답 구하기
    answer = []
    for _ in range(q):
        # 함수 중첩수 n, 인자값 x
        n, x = map(int, input().split())
        
        # 중첩수가 0이 되면 종료
        while n:
            # f[x][d] == f2^d(x)
            d = int(log2(n))
            x = f[x][d]
            
            # 2^d 중첩만큼 처리했기 때문에 남은 중첩수에서 감산
            n -= (1 << d)
            
        # 정답리스트에 쿼리의 정답 저장
        answer.append(x)
       
    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

 sparse table은 말그대로 어떤 규칙에 따른 값을 지수단위로 듬성듬성하게 구해두고 원하는 값까지 로그시간으로

빠르게 접근하기 위한 자료구조이다.  이 문제의 경우 sparse table 은 구성하는데에 O(MlogN)의 시간복잡도를,

쿼리 하나를 처리하는데 O(logN) 의 시간복잡도를 보인다.(경우에 따라서는 O(1) 수준의 성능을 보일때도 있다고 한다.)

sparse table은 정적인 데이터에서의 쿼리만을 처리할 수 있기 때문에 동적 데이터(데이터가 삽입, 삭제, 변경되는 경우)

에는 대응이 불가능하다.  만약 정적인 데이터에 대해 대량의 쿼리를 처리해야하는 경우라면 sparse table을 사용하는

것을 고려해볼만 할 것 같다.   

 

 

3. LCA 2 (https://www.acmicpc.net/problem/11438)

 이제 1번 문제에서 두 번째로 사용했던 알고리즘에 sparse table을 적용할 차례다.  노드의 수가 최대 100,000개로

늘어나고 쿼리의 수도 최대 100,000개이기 때문에 O(NM)의 시간복잡도를 보이는 기존의 방식으로는 해결할 수 없다.

두 노드의 깊이를 동일하게 맞추는 작업과 동일한 깊이에 있는 두 노드의 가장 가까운 조상 노드를 찾아가는 작업을 

sparse table을 사용하여 O(logN) 수준으로 최적화한다.  이를 구현한 코드는 다음과 같다.

 

import sys
import math

input = sys.stdin.readline


# sparse table(희소 테이블)을 활용하여 문제를 더 효율적으로 해결할 수 있다. 절차는 다음과 같다.
# 1. 각 노드의 깊이를 나타낼 depth 리스트와 부모노드를 나타낼 parent 리스트를 생성한다
#    이 때, parent 리스트는 2차원 리스트이며 parent[i][k] 는 노드 i의 2^k 번 거슬러올라간 조상노드를 의미한다.
#
# 2. dfs 혹은 bfs 로 트리를 순회하며 각 노드의 1번 거슬러올라간 조상노드(parent[i][0])와 깊이를 구한다.
#    이 때, 점화식 parent[i][j] = parent[parent[i][j-1]][j-1] 를 사용해서 O(NK) 의 시간 복잡도로 작업을 수행한다.
#    K는 log(2, N) 을 올림한 값과 같다.
#
# 3. 각 질의(LCA 를 구할 두 노드 u, v) 에 대해 두 노드의 깊이가 같아질 때 까지
#    보다 깊이가 큰 노드를 끌어올린다.
#
# 4. 두 노드가 서로 같다면 그대로 어느 한쪽 노드를 출력
#
# 5. 그렇지 않다면 두 노드의 가장 먼 공통 조상으로부터 처음으로 두 노드의 조상 노드가 달라질 때 까지 탐색
#    탐색 종료 후 노드 u 또는 v의 부모노드를 출력
def sol11438():
    # 각 노드의 깊이와 첫 번째 부모를 구하기 위한 탐색 함수
    def bfs(root):
        q = [root]
        while q:
            nq = []
            for cur in q:
                for c in g[cur]:
                    if depth[c] < 0:
                        depth[c] = depth[cur] + 1
                        parent[c][0] = cur
                        nq.append(c)
            q = nq

    # 노드의 수
    n = int(input())

    # sparse table 의 크기
    k = math.ceil(math.log2(n))

    # 그래프 생성
    g = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b = map(int, input().split())
        g[a].append(b)
        g[b].append(a)

    # 각 노드의 깊이 리스트
    depth = [-1] * (n + 1)

    # 각 노드의 조상 노드 리스트 (sparse table)
    parent = [[-1] * k for _ in range(n + 1)]

    # 루트 노드인 1의 깊이를 0으로 하고 탐색 시작
    depth[1] = 0
    bfs(1)

    # sparse table 채우기
    for j in range(1, k):
        for i in range(1, n + 1):
            parent[i][j] = parent[parent[i][j - 1]][j - 1]

    # 각 질의에 대해 정답 구하기
    answer = []
    for _ in range(int(input())):
        u, v = map(int, input().split())

        # 노드 u가 노드 v 보다 깊이가 크도록 한다
        if depth[u] < depth[v]:
            u, v = v, u

        # 두 노드의 깊이가 같아질 때 까지 노드 u를 끌어올린다
        while depth[u] - depth[v]:
            u = parent[u][int(math.log2(depth[u] - depth[v]))]

        # 두 노드가 같지 않다면
        if u != v:
            # 가장 먼 공통 조상 노드로부터 처음으로 두 노드의 조상 노드가 달라질 때 까지 탐색
            for j in range(math.ceil(math.log2(depth[u])), -1, -1):
                if parent[u][j] != parent[v][j]:
                    u = parent[u][j]
                    v = parent[v][j]

            # 현재 u와 v는 공통 조상노드의 자식 노드인 상태이므로 둘 중 하나는 부모노드를 방문
            u = parent[u][0]

        # 노드 u 혹은 v를 반환
        answer.append(u)

    # 출력 형식에 맞춰 정답리스트 반환
    return '\n'.join(map(str, answer))

 

 지난 글에서 위상 정렬을 구현하는 방법으로 진입 차수 테이블과 큐를 사용한 방법을 알아보았다.  이번에는

위상 정렬을 구현하는 다른 방법인 dfs(깊이우선탐색)를 활용한 방법을 알아보고 몇 개의 문제를 해결해본다.

 

1. 줄 세우기(https://www.acmicpc.net/problem/2252) - 구현

 진입 차수 테이블과 큐를 사용한 방법의 경우, 진입 차수가 0인 노드를 모드 큐에 넣으며 탐색해나가는

bfs(너비우선탐색)의 형태로 구현된다. 그리고 그 방식에서는 큐에서 뽑혀 탐색된 노드 순서가 곧 정렬 순서가 되었다. dfs를 활용한 방식에서는 정확히 그 반대라고 생각하면 편하다. 즉, 탐색이 가장 먼저 종료된 순서의 역순이 곧

정렬 순서가 된다.  이 방식으로 줄 세우기 문제를 해결한 코드는 다음과 같다.

 

import sys
sys.setrecursionlimit(100000)

input = sys.stdin.readline


def sol2252():
    # 학생 수 n, 키 비교 결과 수 m
    n, m = map(int, input().split())
    
    # 그래프, 진입 차수 테이블 생성
    # bfs 형태의 구현과 달리 진입 차수 테이블은 생성이후 수정할 일이 없음
    g = [[] for _ in range(n + 1)]
    degree = [0] * (n + 1)
    for _ in range(m):
        a, b = map(int, input().split())
        g[a].append(b)
        degree[b] += 1
    
    # dfs를 활용한 위상 정렬 구현
    def dfs(d):
        nonlocal answer
        # 현재 노드 방문처리
        visit[d] = True
        
        # 아직 방문하지 않은 자식노드 방문
        for child in g[d]:
            if not visit[child]:
                dfs(child)
               
        # 탐색이 종료된 노드를 answer에 삽입
        answer.append(d)
        
    # 방문여부 리스트
    visit = [False] * (n + 1)
    
    # 정렬 결과 리스트
    answer = []
    
    # 처음부터 진입 차수가 0인 노드로부터 탐색 시작 
    for i in range(1, n + 1):
        if not degree[i]:
            dfs(i)
            
    # 정렬 결과 리스트를 뒤집어 출력 형식에 맞춰 반환
    return ' '.join(map(str, answer[::-1]))

 

 

2. ACM Craft (https://www.acmicpc.net/problem/1005)

 건설 규칙과 건물을 완성하는데 걸리는 시간이 주어졌을 때, 특정 건물을 건설하는데 걸리는 최소시간을 구하는 문제.

건물을 노드, 건설 규칙을 간선으로 생각하면 위상 정렬로 쉽게 해결할 수 있는 문제이다.  다만 각 노드(건물)까지

소요되는 최소 시간을 매번 다시 계산할 순 없기 때문에 간단한 동적계획법을 적용할 필요가 있다. 이를 해결한 코드는

다음과 같다.

 

import sys

sys.setrecursionlimit(100000)
input = sys.stdin.readline


def sol1005():
    # dfs(d) = dp[d] 는 d 번 건물을 건설하기 위해 필요한 최소 비용
    def dfs(d):
        # 아직 건물 d를 건설하기 위한 최소비용이 계산되지 않았다면 계산에 들어간다.
        # 건물 d를 건설하기 위한 최소비용은 반드시 먼저 건설되어야할 건물의 최소 비용 중 
        # 가장 큰 비용에 건물 d를 건설하기 위해 필요한 비용을 합친 값이 된다.
        if dp[d] == -1:
            dp[d] = max([dfs(p) for p in g[d]], default=0) + cost[d]
            
        # 건물 d를 건설하기 위한 최소 비용을 반환
        return dp[d]

    # 케이스별 정답 리스트
    answer = []
    for t in range(int(input())):
        # 건물 수 n, 건설 규칙 수 k
        n, k = map(int, input().split())
        
        # 각 건물을 완성하는데 필요한 시간
        cost = [0, *map(int, input().split())]

        # 그래프를 생성
        g = [[] for _ in range(n + 1)]
        for _ in range(k):
            x, y = map(int, input().split())
            
            # 이 문제의 경우 정확하게 건물 w를 짓는데 드는 최소 비용만을 구하는 문제
            # 노드 전체를 위상정렬하는 문제와 다르게 건물 w를 짓기 위해 지어야 하는 
            # 건물들의 최소 비용만을 탐색하는 것이 효율적이기 때문에
            # 역탐색을 위해 그래프의 인접리스트에 갈 수 있는 노드가 아닌 
            # 현재 노드로 올 수 있는 노드를 저장한다.
            g[y].append(x)

        # dfs(w)를 구하여 정답 리스트에 저장
        dp = [-1] * (n + 1)
        w = int(input())
        answer.append(dfs(w))

	# 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

 

 

 위상 정렬의 개념과 이를 활용하여 해결할 수 있는 문제를 알아본다.

 

1. 위상 정렬(Topological Sorting)

 위상 정렬은 간단히 말하자면 순서를 가진 항목들을 줄 세우는 정렬방식이다.  예를 들어 1은 반드시 3보다 앞에 오고

2는 반드시 4보다 앞에 온다고 해보자,  이 때, [1, 2, 3, 4]를 위상 정렬한 결과는 [1, 3, 2, 4], [1, 2, 3, 4], [2, 4, 1, 3],

[2, 1, 4, 3] 모두가 될 수 있다.  즉, 주어진 항목간의 순서를 유지하는 것을 전제로 정렬을 수행하는 것이다.  주로 어떤

작업들을 수행되어야 할 순서에 따라 배치하기 위해 사용되는 알고리즘이다.  정렬의 대상은 기준이 될 명확한 순서가

존재해야 하기 때문에 당연히 방향성을 가졌으며 사이클이 존재하지 않는 그래프의 형태로 나타낼 수 있어야 한다.

 

2. 줄 세우기(https://www.acmicpc.net/problem/2252) - 위상 정렬의 구현

 학생의 수(노드의 갯수)와 학생들의 키를 비교한 결과(간선)가 주어졌을 때 학생들을 키순서로 줄 세우는 문제

위상 정렬의 구현을 연습할 수 있는 문제이다. 구현의 절차는 다음과 같다.

 

① 우선 위상 정렬을 위해 필요한 것은 그래프진입 차수 테이블이다.  주어진 노드의 갯수와 간선을 이용하여

    이 두 가지를 먼저 구해야한다. 노드 X의 진입 차수는 X로 직접 올 수 있는 길을 가진 노드의 갯수를 의미한다.

 

② 진입 차수가 0인 노드를 순서에 상관없이 모두 에 집어넣는다.

 

③ 큐에서 노드를 꺼내서 해당 노드로부터 뻗은 간선을 모두 제거하고 진입 차수 테이블을 갱신한다.

 

④ ②, ③ 작업을 큐가 빌 때 까지 반복한다.

 

⑤ 작업이 끝났을 때, 큐에서 노드를 꺼낸 순서가 곧 위상 정렬된 순서가 된다.

 

이를 구현한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol2252():
    # 학생의 수 n, 비교 결과의 수 m
    n, m = map(int, input().split())
    
    # 학생을 노드로, 비교 결과를 간선으로 하여 그래프로 표현
    # 각 노드(학생)의 진입 차수 테이블을 생성
    g = [[] for _ in range(n + 1)]
    degree = [0] * (n + 1)
    for _ in range(m):
        a, b = map(int, input().split())
        degree[b] += 1
        g[a].append(b)
        
    # 진입 차수가 0인 모든 노드를 큐에 삽입
    q = [i for i in range(1, n + 1) if not degree[i]]
 
 	# 정렬 결과 리스트
    answer = []
    
    # 큐가 빌 때 까지
    while q:
        nq = []
        # 큐에서 노드를 하나씩 꺼내 정렬 결과 리스트 answer 에 삽입
        # 해당 노드로부터 뻗은 간선을 제거하고 진입 차수를 갱신
        # 그 결과 진입 차수가 0이 된 노드를 모두 큐에 삽입
        for num in q:
            answer.append(num)
            for child in g[num]:
                degree[child] -= 1
                if degree[child] == 0:
                    nq.append(child)
        q = nq
        
    # 출력 형식에 맞춰 정렬 결과를 반환
    return ' '.join(map(str, answer))

 

 

3. 최종 순위(https://www.acmicpc.net/problem/3665)

 작년의 팀들의 순위, 올해 상대적 순위가 변경된 팀의 순서쌍이 주어졌을 때, 올해 팀들의 순위를 구하는 문제.

역시 위상 정렬을 사용하여 해결 가능하지만 작년의 순위와 변경된 순서쌍만으로 순위를 추측해야한다. 주의할 점은 

주어진 순서쌍의 순서는 실제로 변경된 순위와는 관계가 없다는 것이다.  예를들어, 작년의 순위가 1, 2 였을 때,

순서쌍이 1, 2로 주어지든 2, 1로 주어지든 최종 순위는 2, 1이 되어야 한다.

 

 문제를 해결하는 절차는 다음과 같다.

 

① 기존 순위를 토대로 간선을 생성한다.  예를들어 기존 순위가 5 4 3 2 1 일 경우, 간선은

    (5, 4), (5, 3), (5, 2), (5, 1), (4, 3), (4, 2), (4, 1), (3, 2), (3, 1), (2, 1) 이 된다.

 

② 생성한 간선으로 그래프와 진입 차수 테이블을 생성한다.

 

③ 순위가 변경된 팀의 순서쌍에서 더 낮은 순위였던 쪽을 a, 더 높은 순위였던 쪽을 b로 하여 그래프와

    진입 차수 테이블을 수정한다.

 

④ 그래프에 대해 위상 정렬을 수행한다.

 

⑤ 위상정렬의 수행한 결과 리스트의 길이가 총 노드의 갯수에 미치지 못할 경우 그래프에 사이클이 존재한다는

    의미이기 때문에 데이터에 일관성이 없어 순위를 정할 수 없는 경우에 속하므로 IMPOSSIBLE을 출력한다

 

⑥ 그렇지 않다면 정상적으로 위상 정렬이 완료된 것이기 때문에 정렬 결과를 출력한다.

 

이를 구현한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol3665():
    # 케이스별 정렬 결과 리스트
    answer = []
    
    for t in range(int(input())):
        # 팀의 수(노드의 수)
        n = int(input())
        
        # 작년 순위에 따라 그래프와 차수 테이블 생성
        g = [set() for _ in range(n + 1)]
        degree = [0] * (n + 1)
        teams = list(map(int, input().split()))
        for i in range(n):
            a = teams[i]
            for j in range(i + 1, n):
                b = teams[j]
                g[a].add(b)
                degree[b] += 1

        # 변경된 순위에 따라 그래프와 차수 테이블 수정
        m = int(input())
        for _ in range(m):
            a, b = map(int, input().split())
            
            # 보다 낮은 순위였던 쪽이 a가 되게 한다
            if a not in g[b]:
                a, b = b, a
               
            g[b].discard(a)
            g[a].add(b)
            degree[a] -= 1
            degree[b] += 1
            
        # 정렬 결과
        res = []
        
        # 위상 정렬 수행
        q = [i for i in range(1, n + 1) if not degree[i]]
        while q:
            nq = []
            for num in q:
                res.append(num)
                for child in g[num]:
                    degree[child] -= 1
                    if not degree[child]:
                        nq.append(child)
            q = nq

		# 사이클이 발견된 경우 -> 일관성이 없어 정렬이 불가능한 경우
        if len(res) != n:
            answer.append('IMPOSSIBLE')
            
        # 정렬이 정상적으로 수행된 경우
        else:
            answer.append(' '.join(map(str, res)))

	# 출력 형식에 맞춰 정렬 결과 반환
    return '\n'.join(answer)

 

 

4. 문제집(https://www.acmicpc.net/problem/1766)

 문제집의 1~N까지의 문제를 풀 순서를 다음 세가지 조건에 따라 정하는 문제

1. N개의 문제는 모두 풀어야한다

2. 먼저 푸는 것이 좋은 문제가 있는 문제는 먼저 푸는 것이 좋은 문제를 반드시 먼저 풀어야 한다.

3. 가능하면 쉬운 문제부터 풀어야 한다. (문제 번호가 낮을수록 난이도가 낮다)

 

순서를 지켜야한다는 점에서 알 수 있듯이 위상정렬을 사용하여 해결 가능한 문제이다.  다만 같은 단계에서

진입 차수가 0이라면 순서가 상관없었던 기존의 위상정렬로는 3번 조건을 만족시킬 수 없다.  여기서는 

기존의 큐를 heapq 로 대체하는 것으로 문제를 해결할 수 있다. 구현은 다음과 같다.

 

import sys
from heapq import heappush, heappop, heapify

input = sys.stdin.readline


def sol1766():
    # 문제의 갯수 n, 먼저 푸는게 좋은 문제의 정보의 갯수 m
    n, m = map(int, input().split())
    
    # 그래프, 진입 차수 테이블 생성
    g = [[] for _ in range(n + 1)]
    degree = [0] * (n + 1)
    for _ in range(m):
        a, b = map(int, input().split())
        g[a].append(b)
        degree[b] += 1

    # 최초에 진입 차수가 0인 노드(문제)를 큐에 담는다
    q = [i for i in range(1, n + 1) if not degree[i]]
    
    # 큐를 heapq로 변환
    # Python의 heapq는 기본적으로 최소 힙이기 때문에
    # 더 작은 숫자, 즉 더 쉬운 문제부터 뽑게된다.
    heapify(q)
    
    # 정렬 결과 리스트
    answer = []
    
    # 위상 정렬 실행
    # 기존의 방식과 같지만 큐에서 값을 꺼내고 다시 넣는 동작을
    # heappop 과 heappush를 사용하여 수행한다.
    # 이것으로 매번 가능하면 가장 쉬운 문제를 해결하는 3번 조건을 만족시킨다.
    while q:
        cur = heappop(q)
        answer.append(cur)
        for c in g[cur]:
            degree[c] -= 1
            if not degree[c]:
                heappush(q, c)

    # 출력 조건에 따라 정답 리스트 반환
    return ' '.join(map(str, answer))

 heap을 사용하면 항상 정렬된 것과 마찬가지인 상태를 유지할 수 있다는 점만 떠올리면 쉽게 해결할 수 있는 문제였다.

 

 

위상정렬의 구현은 다음에 알아볼 dfs를 사용한 구현이 보다 간단하지만 큐를 사용한 구현이 풀이에 적합한 경우도 있기 때문에 두 방법 모두 숙지해둘 필요가 있다.

 누적 합은 2차원 배열에서도 1차원에서와 거의 동일하게 활용 가능하다.  이번에는 2차원 배열에서의

누적 합의 성질과 이를 활용하여 풀 수 있는 문제들을 알아본다.

 

1. 2차원 배열의 누적합 구하기

 1차원 배열에서는 단순히 이전 인덱스의 값을 더해주면 끝이었지만 2차원 배열에서 누적합을 구하려면 각 행의

누적 합과 각 열의 누적 합을 구하는 두 단계를 거쳐야한다.  예시로 2차원 배열 하나의 누적 합을 구해보자.

 

① 먼저 각 행의 누적합을 구한다.

 

② 다음으로 각 열에 대한 누적 합을 구한다.

 

③ 2차원 행렬에서의 누적합 배열이 완성된다.  ①과 ②의 순선느 서로 뒤바뀌더라도 문제가 없다.

 

 

2. 2차원 누적 합의 성질

 2차원 배열에서의 부분 합을 구하거나 여러 구간에 대한 증감 적용을 한꺼번에 처리하는 데

활용할 수 있는 성질들을 가지고있다. 

 

누적합 S(r, c)는 2차원 배열 A에 대해 다음과 같은 성질을 가진다.

① S(r, c) = sum([A[i][j] for i in range(r) for j in range(c)]) 

② S(0, 0) = A[0][0]

③ S(r, c) = S(r-1, c) + A[r][c]

④ sum([A[i][j] for i in range(r1, r2+1) for j in range(c1, c2+1)]) = S(r2, c2) + S(r1-1, c1-1) - S(r1, c2-1) - S(c1-1, r2)

② 번 성질이 성립하는 이유

1차원 배열에서의 누적 합과 조금 다르지만 거의 유사한 성질을 가진 것을 알 수 있다.

 

 

3. 2차원 배열의 합

N*M 크기의 2차원 배열이 주어졌을때 K 개의 특정 구간에 대한 부분 합을 구하는 문제.  단순히 매번 반복문을 돌려 합을 구할 경우 O(NMK)의 시간복잡도를 보인다.  하지만 누적합의 4번 성질을 이용하면 O(NM+K)만에 해결할 수 있다.

문제를 해결한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol2167():
    # 행렬의 크기 n, m
    n, m = map(int, input().split())
    
    # 주어진 행렬 seq
    # 누적 합의 성질 4번을 사용할 때, 0번 인덱스의 이전 인덱스를 참조할 경우를 대비해 
    # 행/열의 끝에 0을 추가
    seq = [[*map(int, input().split()), 0] for _ in range(n)]
    seq.append([0] * m)
    seq.append([0] * m)
    
    # 2차원 배열의 누적합 배열을 구한다
    for i in range(n):
        for j in range(m):
            # 현재 위치의 누적합을 구하기 전에 같은 열의 다음 행의 값에 
            # 현재위치의 값을 더해주는 것으로 열에 대한 누적합을 동시에 계산해나갈 수 있다.
            seq[i + 1][j] += seq[i][j]
            seq[i][j] += seq[i][j - 1]

	# 정답 리스트
    answer = []
    
    # 누적 합의 성질 4번을 사용하여 부분 합을 구하고 정답 리스트에 저장
    for _ in range(int(input())):
        i, j, x, y = map(int, input().split())
        answer.append(seq[x-1][y-1] + seq[i - 2][j - 2] - seq[x-1][j - 2] - seq[i - 2][y-1])

    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

 

 

4. 

 누적 합(prefix sum)은 그 특성상 배열의 부분합을 빠르게 구하는데 굉장히 유용하다. 이번에는 1차원 배열에서의 누적 합의 성질과 이를 이용하여 해결 가능한 문제들을 다뤄본다.

 

1. 누적 합의 성질

 어떤 1차원 배열 A = [a(1), a(2), ... , a(n)] 에 대해, 누적합 S(k)는 다음과 같은 성질을 가진다.

 ① S(k) = a(1) + a(2) + ... + a(k)

 ② S(1) = a(1)

 ③ S(k) = S(k-1) + a(k)

 ④ a(i) + a(i+1) + ... + a(j-1) + a(j) = S(j) - S(i-1) 

 

 

2. 구간 합 구하기 4(https://www.acmicpc.net/problem/11659)

 매우 간단한 누적합 활용문제이다.

import sys

input = sys.stdin.readline


def sol11659():
    # 수열의 크기 n, 질의의 갯수 m
    n, m = map(int, input().split())

    # n개의 수로 이루어진 수열
    seq = [0, *map(int, input().split())]

    # 수열의 누적합 전처리
    for i in range(1, n + 1):
        seq[i] += seq[i - 1]

    # 각 질의에 대해 정답을 계산하여 answer 에 삽입
    answer = []
    for _ in range(m):
        i, j = map(int, input().split())
        answer.append(seq[j] - seq[i - 1])

    # 출력형식에 맞춰 정답리스트 반환환
    return '\n'.join(map(str, answer))

a(i) + a(i+1) + ... + a(j-1) + a(j) = S(j) - S(i-1) 임을 이용하여 문제를 해결하였다.

 

 

3. 광고 삽입(https://programmers.co.kr/learn/courses/30/lessons/72414)

 영상 전체의 길이와 시청자들이 어느 구간을 재생했는지의 로그들이 주어질 때, 광고를 삽입할 최적의 위치를 찾기위해 가장 재생횟수가 많은 구간의 시작시간을 구하는 문제이다.  만약 재생횟수가 가장 많은 구간이 여러개라면 그 중 가장 빠른 시작시간을 반환한다. 위의 문제보다는 복잡하지만 역시 누적 합을 활용하여 해결할 수 있는 문제이다. 

 

주어지는 데이터는 동영상 전체의 재생시간 play_time 과 광고의 재생시간 adv_time, 영상의 재생구간 로그들이 담긴 리스트 logs이다.  문제를 해결하는 과정은 다음과 같다.

 

① 먼저, 제한사항에 광고의 재생시간은 동영상의 재생시간보다 짧거나 같게 주어진다는 조건이 있다. 

    광고의 재생시간이 영상의 재생시간보다 짧을 경우에는 평범하게 최적의 광고시간을 탐색하면 그만이지만

    같을 경우에는 애초에 탐색할 필요가없다.  영상의 시작시간이 곧 광고의 시작시간이 될 수 밖에 없기 때문이다.

 

② play_time 과 adv_time은 'HH:MM:SS' 의 문자열 형태이기 때문에 다루기 쉽도록 정수형태로 파싱해줄 필요가 있다.

    가장 작은 초단위로 시간을 변환한다. HH * 3600 + MM * 60 + SS 로 간단하게 변환 가능하다.

 

③ 다음은 광고가 삽입될 최적의 위치를 찾기위해 로그를 분석해야한다.  가장 먼저 떠오르는 방식은 1초에 한칸씩

    0으로 초기화된 배열을 할당하여 재생된 부분마다 1씩 더해주는 것으로 가장 재생수가 많은 구간을 탐색하는

    것이다.  다행히도 play_time의 범위가 00:00:00 부터 99:59:59,  즉, 0초부터 359999초 까지라는 조건이 있기 때문에

    배열의 형태로 충분히 나타낼 수 있다.

 

④ 예를 들어 로그가 '00:00:00-00:00:04',  '00:00:02-00:00:07', '00:00:06-00:00:10' 의 세 개라고 생각해보자.

    문제의 제한사항에 "예를 들어, 00시 00분 01초부터 00시 00분 10초까지 동영상이 재생되었다면, 동영상

    재생시간은 9초 입니다." 라는 조건이 있다. 즉, 종료시간은 재생구간에 들어가지 않는다는 것이다.

    이에 따라 위 로그를 배열상에 표현하면 아래와 같은 형태가 된다.

    문제는 여기서 발생한다. 로그 한번마다 구간내의 값을 모두 1 씩 증가시키려면 재생시간의 길이가 N, 로그의 갯수가

    M일 때 최악의 경우(모든 로그의 범위가 영상의 재생시간과 같을 경우) O(NM) 의 시간복잡도를 보인다.

    이 문제에서 N은 최대 360000, M은 최대 30만의 크기를 가지기에 이 알고리즘은 매우 비효율적이다.

   

⑤ 여기서 우리는 누적 합의 성질 중 하나인 S(k) = S(k-1) + a(k) 를 활용할 수 있다.  이번엔 로그의 시작 시간과

    종료 시간에만 시청횟수의 증감을 더해보자.  그러면 다음과 같은 형태가 된다.

    각 로그의 시작 시간인 0, 2, 6초에는 +1을, 종료 시간인 4, 7, 10초에는 -1을 더하였다. 이제 위 테이블에서

    누적합의 성질에 따라 누적합을 구해보자. 그러면 다음과 같은 형태가 된다.

    한눈에 봐도 ④ 에서 구했던 로그마다 전 구간의 값을 1씩 증가시키는것으로 얻었던 배열과 같음을 알 수 있다.

    게다가 이 방식은 O(N+M)의 시간복잡도로 위의 O(NM)의 알고리즘과 같은 결과를 얻을 수 있다.  이와 같이

    누적 합을 활용하면 어떤 배열에 여러 구간에 대한 증/감을 반영하는 작업을 굉장히 효율적으로 처리할 수 있다.

 

⑥ 여기까지 구하고나면 다음은 어렵지 않다.  위에서 구한 배열에서 영상의 시작시간부터 시작하여 광고 시간만큼의

    구간의 합이 최대치가 되는 구간의 시작시간을 구하면 된다. 최댓값이 여러개라면 가장 빠른 시간을 구하라는 조건도

    처음 찾은 최댓값이 더 큰 값이 오기 전에는 갱신되지 않도록 하면 자동으로 충족된다.  여기서 구간합이 최대가 되는

    경우를 찾는 것도 두 가지 방법이 있다.

 

⑦ 하나는 마찬가지로 누적 합을 이용한 방법이다.  a(i) + a(i+1) + ... + a(j-1) + a(j) = S(j) - S(i-1) 임을 이용하여,

    위 배열의 누적합을 다시한번 구한 뒤 구간마다 구간합을 빠르게 구할 수 있다.

 

⑧ 다른 하나는 투 포인터를 활용하는 방법이다.  첫 번째 구간의 합만을 구한 뒤, 다음 구간으로 넘어갈때마다

    첫 번째 요소를 빼고 새로운 요소를 더하는 식으로 구간합을 갱신해나갈 수 있다.

 

⑨ 마지막으로 구한 시작시간을 다시 'HH:MM:SS' 의 형태의 문자열로 변환하여 반환하면 된다.

 

위 절차에 따라 문제를 해결한 코드는 다음과 같다.

 

1) 누적 합의 성질을 활용하여 최대 구간합을 구한 방식 (⑦)

def solution(play_time, adv_time, logs):
    # ① 광고의 재생시간이 영상의 재생시간과 같다면 영상의 첫 부분에 삽입하면 된다.
    if play_time == adv_time:
        return '00:00:00'
    
    # ② 영상 재생시간과 광고 재생시간을 초단위의 정수로 변환
    pt, at = stoh(play_time), stoh(adv_time)
    
    # ③ 시간별 재생횟수를 나타내기위한 타임라인 배열을 생성
    timeline = [0] * 360001
    
    # ④, ⑤ 로그를 분석하여 재생 시작시간과 종료시간을 구한 뒤
    # 누적 합을 활용하여 타임라인에 재생 횟수를 표기
    for log in logs:
        s, e = map(stoh, log.split('-'))
        timeline[s] += 1
        timeline[e] -= 1
    for i in range(1, 360001):
        timeline[i] += timeline[i - 1]
        
    # ⑦, ⑧ 영상의 시작시간부터 timeline의 at(광고 재생시간)크기의 구간합이 최대가 되는 경우를 탐색
    # 다시 한번 누적 합을 구하는 전처리작업을 수행하여 구간 합을 더 쉽게 구할 수 있도록 한다.
    for i in range(1, 360001):
        timeline[i] += timeline[i - 1]
        
    # 최대 구간 합 maxw, 정답으로 반환할 시작시간 answer
    maxw, answer = timeline[at-1], 0
    
    # 영상의 시작시간부터 탐색 시작
    for s in range(1, pt - at+1):
        # 누적 합의 성질을 활용하여 구간 합을 구한다
        w = timeline[s+at-1] - timeline[s-1]
        
        # 만약 새로운 구간합이 최대 구간합 maxw보다 크다면 maxw와 answer 갱신
        if w > maxw:
            maxw, answer = w, s
            
    # 광고를 삽입할 최적의 시작시간을 문자열로 변환하여 반환
    answer = htos(answer)
    return answer


# 문자열 형태의 시간을 초단위의 정수로 변환
def stoh(times):
    h, m, s = map(int, times.split(':'))
    return h * 3600 + m * 60 + s


# 초단위의 정수를 문자열 형태의 시간으로 변환
def htos(sec):
    h = sec // 3600
    sec %= 3600
    m = sec // 60
    s = sec % 60
    return '%02d:%02d:%02d' % (h, m, s)

 

2) 투 포인터를 활용하여 최대 구간합을 구한 방식 (⑧)

def solution(play_time, adv_time, logs):
    # ① 광고의 재생시간이 영상의 재생시간과 같다면 영상의 첫 부분에 삽입하면 된다.
    if play_time == adv_time:
        return '00:00:00'
    
    # ② 영상 재생시간과 광고 재생시간을 초단위의 정수로 변환
    pt, at = stoh(play_time), stoh(adv_time)
    
    # ③ 시간별 재생횟수를 나타내기위한 타임라인 배열을 생성
    timeline = [0] * 360001
    
    # ④, ⑤ 로그를 분석하여 재생 시작시간과 종료시간을 구한 뒤
    # 누적 합을 활용하여 타임라인에 재생 횟수를 표기
    for log in logs:
        s, e = map(stoh, log.split('-'))
        timeline[s] += 1
        timeline[e] -= 1
    for i in range(1, 360001):
        timeline[i] += timeline[i - 1]

    # ⑦, ⑧ 영상의 시작시간부터 timeline의 at(광고 재생시간)크기의 구간합이 최대가 되는 경우를 탐색
    
    # 최대구간합 maxw, 정답으로 반환할 시작시간 answer
    maxw, answer = sum(timeline[:at]), 0
    
    # 초기 구간합 w = maxw
    w = maxw
    
    # 영상의 시작시간부터 탐색 시작
    for s in range(1, pt - at+1):
        # 투 포인터를 활용하여 구간합을 갱신해나가는 방식을 사용
        w = w - timeline[s-1] + timeline[s+at-1]
        
        # 만약 새로운 구간합이 최대 구간합 maxw보다 크다면 maxw와 answer 갱신
        if w > maxw:
            maxw, answer = w, s
            
    # 광고를 삽입할 최적의 시작시간을 문자열로 변환하여 반환
    answer = htos(answer)
    return answer


# 문자열 형태의 시간을 초단위의 정수로 변환
def stoh(times):
    h, m, s = map(int, times.split(':'))
    return h * 3600 + m * 60 + s


# 초단위의 정수를 문자열 형태의 시간으로 변환
def htos(sec):
    h = sec // 3600
    sec %= 3600
    m = sec // 60
    s = sec % 60
    return '%02d:%02d:%02d' % (h, m, s)

 

구간 합을 활용하여 효율적으로 문제를 해결해야하는 경우는 코딩테스트에서 꽤 자주 보인다.  잘 기억해두자.

 이전 글에서 다루었던 트라이 자료구조를 활용하여 해결할 수 있는 문제를 살펴본다.

 

1. 개미굴(https://www.acmicpc.net/problem/14725)

 탐사용 개미로봇은 트리 형태를 이루는 개미굴을 입구(루트)부분에서 시작하여 아래로 내려가며 탐색을 진행하고 더이상 내려갈 수 없어지면 지금까지 거쳐온 방의 정보를 전송하고 작동을 중지한다.  개미로봇들이 보낸 정보를 토대로 개미굴의 형태를 그려내는 문제.

 

 단순히 개미로봇이 보낸 정보들을 트라이에 삽입한 후 트라이의를 출력형식에 맞춰 문자열화하는 것으로 간단하게 해결 가능한 문제이다.  문제를 해결한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


# 트라이의 노드 클래스
# data(키 값), child(자식노드 목록), finished(데이터의 마지막 노드인지 여부)의 세 필드값을 가진다.
class Node:
    # 생성자
    def __init__(self, data=None, finished=False):
        self.data = data
        self.finished = finished
        self.child = {}

    # 노드를 문자열로 표현
    def __str__(self, depth=0):
        res = [str(self.data)]
        for c in self.child.values():
            res.append('--' * depth + c.__str__(depth + 1))
        return '\n'.join(res)


# 트라이 클래스
class Trie:
    # 생성시 빈 노드를 생성하여 루트노드로 설정
    def __init__(self):
        self.root = Node()

    # 삽입 메소드
    def add(self, e):
        # 루트 노드에서 탐색 시작
        cur = self.root

        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for i in e:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 현재 노드의 자식노드 리스트에 삽입한다.
            if i not in cur.child:
                cur.child[i] = Node(i)
            cur = cur.child[i]
        # 현재 노드가 데이터의 마지막 요소를 키값으로 하는 노드임을 표시
        cur.finished = True

    # 탐색 메소드
    def search(self, key):
        # 루트 노드에서 탐색 시작
        cur = self.root

        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for c in key:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 찾으려는 데이터가 존재하지 않음을 결과로 반환한다.
            if c in cur.child:
                cur = cur.child[c]
            else:
                return None
        # 탐색이 종료됐을 때, 마지막 노드가 데이터의 마지막 요소를 키 값으로 하는 노드로
        # 표시되어있다면 탐색 성공.  그렇지 않다면 데이터가 존재하지 않음을 결과로 반환한다.
        return cur.data if cur.finished else None

    # 트라이를 문자열로 표현
    def __str__(self):
        return '\n'.join([c.__str__(1) for c in self.root.child.values()])


def sol14725():
    # 로봇개미의 수
    n = int(input())

    # 트라이 인스턴스
    t = Trie()

    # 로봇개미가 가져온 정보를 트라이에 삽입
    for e in sorted([input().split()[1:] for _ in range(n)]):
        t.add(e)

    # 트라이 인스턴스를 문자열화 하여 반환
    return str(t)

 

 

2. 문자열 집합(https://www.acmicpc.net/problem/14425)

 N개의 문자열로 이루어진 집합 S 와 M개의 문자열이 주어졌을 때, M개중 몇 개의 문자열이 S에 속해있는지 구하는 문제.  사실 Python의 경우 set 등을 활용하여 더 빠르고 쉽게 해결할 수 있지만 연습을 위해 트라이를 사용하여 풀어보았다.  코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


# 트라이의 노드 클래스
# data(키 값), child(자식노드 목록), finished(데이터의 마지막 노드인지 여부)의 세 필드값을 가진다.
class Node:
    def __init__(self, data=None, finished=False):
        self.data = data
        self.finished = finished
        self.child = {}


# 트라이 클래스
class Trie:
    # 생성시 빈 노드를 생성하여 루트노드로 설정
    def __init__(self):
        self.root = Node()

    # 삽입 메소드
    def add(self, e):
        # 루트 노드에서 탐색 시작
        cur = self.root

        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for i in e:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 현재 노드의 자식노드 리스트에 삽입한다.
            if i not in cur.child:
                cur.child[i] = Node(i)
            cur = cur.child[i]
        # 현재 노드가 데이터의 마지막 요소를 키값으로 하는 노드임을 표시
        cur.finished = True

    # 탐색 메소드
    def search(self, key):
        # 루트 노드에서 탐색 시작
        cur = self.root

        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for c in key:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 찾으려는 데이터가 존재하지 않음을 결과로 반환한다.
            if c in cur.child:
                cur = cur.child[c]
            else:
                return None
        # 탐색이 종료됐을 때, 마지막 노드가 데이터의 마지막 요소를 키 값으로 하는 노드로
        # 표시되어있다면 탐색 성공.  그렇지 않다면 데이터가 존재하지 않음을 결과로 반환한다.
        return cur.data if cur.finished else None


def sol14425():
    # 집합에 속한 문자열의 갯수 N, 검사할 문자열의 갯수 M
    n, m = map(int, input().split())

    # 입력된 모든 문자열
    strings = sys.stdin.read().split()

    # N개의 문자열을 트라이에 삽입
    t = Trie()
    for i in range(n):
        t.add(strings[i])

    # M개의 문자열을 모두 탐색하여 몇 개의 문자열이 트라이에 존재하는지 센다
    cnt = 0
    for i in range(n, n + m):
        cnt += (1 if t.search(strings[i]) else 0)

    # 결과반환
    return cnt

단순히 N개의 문자열을 모두 트라이에 넣고 M개의 문자열에 대해 탐색을 진행하는 것으로 해결한 풀이이다.  시간복잡도는 1<=L(문자열 최대길이)<=500,  1<=N<=10,000, 1<=M<=10,000 일 때  O(L(N+M)) 의 시간복잡도를 보이기에 PyPy3으로만 AC를 받을 수 있었다.  Python3으로는 추후에 좀더 최적화를 거쳐 다시 제출해보기로 한다.

 

 

3. 휴대폰 자판(https://www.acmicpc.net/problem/5670)

 자동완성 기능을 가진 휴대폰 자판으로 단어들을 타이핑할 때 필요한 타이핑 횟수의 평균값을 구하는 문제. 사전에 속해있는 단어의 갯수 N이 최대 10만이며 각 케이스에 주어지는 단어 길이의 총합도 100만으로 상당히 많은 양의 데이터를 처리해야하기 때문에 모든 단어마다 실제로 시뮬레이션을 돌려보는것은 시간초과가 발생할 것이다. 

 

각 단어가 삽입될 때, 거쳐간 노드의 카운트를 1씩 증가시키는 것으로 해당 노드가 몇개의 단어에 포함되어있는지 파악하고, dfs로 모든노드를 탐색하며 눌리지 않을 버튼을 키값으로 하는 노드를 제외한 모든 노드의 카운트값을 더하는 것으로 눌러야할 버튼 수 의 총합을 구할 수 있다.  문제를 해결한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


# 트라이의 노드 클래스
# data(키 값), child(자식노드 목록), finished(데이터의 마지막 노드인지 여부), cnt(거쳐간 단어의 수)의 네 개의 필드값을 가진다.
class Node:
    def __init__(self, data=None, finished=False):
        self.data = data
        self.finished = finished
        self.child = {}
        self.cnt = 0


# 트라이 클래스
class Trie:
    # 생성시 빈 노드를 생성하여 루트노드로 설정
    def __init__(self):
        self.root = Node(finished=True)

    # 삽입 메소드
    def add(self, e):
        # 루트 노드에서 탐색 시작
        cur = self.root

        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for i in e:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 현재 노드의 자식노드 리스트에 삽입한다.
            if i not in cur.child:
                cur.child[i] = Node(i)

            # 해당 노드를 거쳐간 문자열의 수를 1 증가시킨다.
            cur.child[i].cnt += 1

            cur = cur.child[i]
        # 현재 노드가 데이터의 마지막 노드임을 표시
        cur.finished = True

    # 트라이의 루트노드를 반환
    def get_root(self):
        return self.root


def sol5670():
    # 케이스의 정답 목록을 저장할 리스트
    answer = []

    for n in sys.stdin:
        # 케이스별 영단어의 갯수
        n = int(n)

        # 단어 리스트
        words = [input().strip() for _ in range(n)]

        # 트라이에 모든 단어를 삽입
        t = Trie()
        for word in words:
            t.add(word)

        # 루트노드의 모든 자식노드들에 대해 dfs 를 실행, 그 결과를 모두 더한다
        # res 값은 모든 영단어를 타이핑하기 위해 버튼을 눌러야할 횟수의 총합
        res = 0
        root = t.get_root()
        for cur in root.child.values():
            res += dfs(root, cur)

        # 평균값을 소숫점 셋째자리에서 반올림하여 answer 리스트에 삽입
        answer.append('%.2f' % (res / n))

    # 출력형식에 맞춰 정답리스트 반환
    return '\n'.join(answer)


# 탐색 함수
def dfs(parent, node):
    # 버튼을 눌러야 할 횟수
    res = 0

    # 현재 노드가 부모노드의 유일한 자식노드이거나
    # 부모노드가 데이터의 마지막 노드로 표시되어있을 경우
    # 자동완성이 되지 않기 때문에 버튼을 눌러야함
    if parent.finished or len(parent.child) > 1:
        res += node.cnt

    # 자식노드에 대해 탐색함수 재귀호출하여 결과값을 res 에 가산
    for c in node.child.values():
        res += dfs(node, c)

    # 버튼을 눌러야 할 횟수 반환
    return res

약간의 최적화가 필요하긴 했지만 아슬아슬하게 Python3 으로도 AC를 받을 수 있었다.   

 

 문자열을 보다 효율적으로 탐색할 수 있도록 저장하는 자료구조인 트라이(Trie)에 대해 알아본다.

 

 

1. 개요

 데이터의 빠른 탐색을 위한 자료구조라고 하면 보통 이진탐색트리나 B트리, 해시테이블 등을 떠올릴 것이다.  이러한 자료구조들의 삽입/탐색 연산은 데이터의 갯수(N)에 영향을 받는다.  이번에 알아볼 트라이(Trie)는 문자열을 저장하고 효율적으로 탐색하기 위한 자료구조로, 특이하게도 데이터의 갯수가 아닌 삽입하거나 탐색할 데이터(문자열)의 크기에 영향을 받는다. 

 

2. 삽입 연산

 예시로 트라이에 'hello' 라는 데이터를 삽입해보자.

 

트라이의 초기상태

먼저 루트노드의 자식노드중 키 값이 'h' 인 노드를 찾는다 현재 트라이는 비어있기 때문에 루트노드의 자식 중에는 'h'를 키값으로 가진 노드는 존재하지 않는다.  그러면 'h'를 키 값으로 하는 노드를 생성하여 루트노드의 자식노드로 집어넣는다.

 

'hello' 삽입과정

이제 새로 집어넣은 'h'를 키 값으로 하는 노드로 위치를 옮겨, 그 노드의 자식노드중 다음 문자인 'e'를 키값으로 하는 노드를 찾는다.  이번에도 'h' 노드의 자식노드는 비어있기 때문에 노드를 새로 만들어서 추가한다. 이것을 다섯글자 모두 삽입할 때 까지 반복한다.

 

'hello' 삽입완료

이것으로 'hello' 의 삽입이 완료되었다.  다음은 트라이에 'head' 를 집어넣어보자. 루트노드에서 시작하여 글자를 하나하나 비교해나간다.  이번에는 첫 글자 h가 루트노드의 자식노드에 존재한다. 해당노드로 위치를 옮긴다. 

 

다음은 'e'를 찾는다. 이번에도 'h' 노드의 자식노드에 'e'가 존재한다. 또다시 위치를 옮긴다.

 

이번에는 'e' 노드의 자식노드에 'a' 가 존재하지 않는다.  'a' 노드를 생성하여 자식노드에 삽입하고 삽입한 노드로 위치를 이동한다.

 

마지막으로 'd'가 'a' 노드의 자식노드에 존재하지 않기 때문에 노드를 생성하여 집어넣어준다.

 

 이것으로 'head'의 삽입이 완료되었다.  

 

그림으로만 살펴봐도 알겠지만 삽입과정이 굉장히 단순하다.  루트노드에서 시작하여 자식노드에서 데이터를 이루는 요소를 순차적으로 탐색해나가다가 존재하지 않는다면 생성하여 삽입해주는 방식이다.  물론 삽입 연산의 시간복잡도는 길이 L 인 문자열에 대해 O(L) 이다. 

 

 

3. 탐색

 탐색 알고리즘도 삽입과 거의 동일한 방식으로 이뤄진다.  다만 삽입 연산에서는 자식노드에 원하는 문자를 키 값으로 가진 노드가 없다면 만들어서 삽입하였지만 탐색의 경우 탐색에 실패했다는 결과가 바로 반환된다. 

 삽입 연산에서와 달리 일부 노드에 *표시가 있는 것을 확인할 수 있다. 이는 한 데이터의 끝을 표현한다.  데이터를 삽입할 때, 마지막 요소를 삽입할 경우 이 노드에서 끝나는 데이터가 존재함을 표시해주는 것으로 삽입하지 않은 데이터를 존재하는 것으로 오인할 위험을 미연에 방지하는 것이다.  예를들어 위 트라이에서 'hell'을 탐색해보자. 

 

이와 같이 'h', 'e', 'l', 'l' 네 글자 모두 탐색에 성공하였다.  하지만 'hell' 이라는 문자열은 삽입된 적이 없으며 단지 'hello'가 삽입되었기 때문에 탐색이 가능했을 뿐이다.  만약 데이터의 끝을 표시하지 않았다면 이 상황에 hell이 정말 들어있는 것인지 아닌지 구분할 방법이 없지만 *로 데이터의 마지막 노드에는 표시를 해두었기 때문에 'hell'이라는 문자열이 존재하지 않는다는 사실을 알 수 있다.

 

 

4. 구현

Python으로 트라이를 간단하게 구현한 코드는 다음과 같다. 

# 트라이의 노드 클래스
# data(키 값), child(자식노드 목록), finished(데이터의 마지막 노드인지 여부)의 세 필드값을 가진다.
class Node:
    def __init__(self, data=None, finished=False):
        self.data = data
        self.finished = finished
        self.child = {}


# 트라이 클래스
class Trie:
    # 생성시 빈 노드를 생성하여 루트노드로 설정
    def __init__(self):
        self.root = Node()

    # 삽입 메소드
    def add(self, e):
        # 루트 노드에서 탐색 시작
        cur = self.root
        
        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for i in e:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 현재 노드의 자식노드 리스트에 삽입한다.
            if i not in cur.child:
                cur.child[i] = Node(i)
            cur = cur.child[i]
        # 현재 노드가 데이터의 마지막 요소를 키값으로 하는 노드임을 표시
        cur.finished = True

    # 탐색 메소드
    def search(self, key):
        # 루트 노드에서 탐색 시작
        cur = self.root
        
        # 데이터의 각 요소를 키 값으로 하는 노드를 타고내려간다.
        for c in key:
            # 현재 노드의 자식노드 중에
            # 다음 요소를 키 값으로 하는 노드가 존재하지 않는다면 노드를 생성
            # 찾으려는 데이터가 존재하지 않음을 결과로 반환한다.
            if c in cur.child:
                cur = cur.child[c]
            else:
                return None
        # 탐색이 종료됐을 때, 마지막 노드가 데이터의 마지막 요소를 키 값으로 하는 노드로
        # 표시되어있다면 탐색 성공.  그렇지 않다면 데이터가 존재하지 않음을 결과로 반환한다.
        return cur.data if cur.finished else None

 

 저번 글에서 알아봤던 KMP 알고리즘을 활용하여 풀 수 있는 문제들에 대해 알아본다.

 

1. 문자열 제곱(https://www.acmicpc.net/problem/4354)

 문자열 a의 n 제곱을 문자열 a를 n번 이어붙이는 연산으로 정의하고(ex)  'abc'^3 == 'abcabcabc') 문자열 s가 주어졌을 때,  s 를 a^n 을 만족하는 가장 큰 n을 구하는 문제.  

 

 KMP 알고리즘 전체가 아니라 KMP 알고리즘에서 사용했던 실패함수, 즉 LPS를 구하는 알고리즘을 활용하여 풀 수 있는 문제이다.  문제로 주어진 각 케이스들은 문자열 s 의 lps값에 따라 다음의 세 가지 경우로 나눌 수 있다.

 

 

# case 1) len(lps)*2 < len(s)

len(lps) * 2(양 끝의 lps 길이를 더한 값)가 문자열의 길이보다 짧을 경우, 해당 문자열은 어떤 문자열의 반복으로 나타낼 수 없다. lps에 속하지 않은 문자열을 표현할 방법이 없기 때문이다.  이 경우 s = s ^ 1 로 밖에 나타낼 수 없기 때문에 답은 1이 된다.

 

# case 2) len(lps)*2 == len(s)

len(lps) * 2 가 문자열의 길이와 같을 경우, 해당 문자열은 lps ^ 2 로만 나타낼 수 있다.  n이 더 커지려면 lps 자체도 문자열의 반복으로 이루어져야 하는데, 그렇게 될 경우 당연히 lps는 더 길어지기 때문에 그런 경우는 있을 수 없다. 답은 2가 된다.

 

# case 3) len(lps)*2 > len(s) 

이 경우는 또다시 두 가지 경우로 나뉘어진다.  두 LPS가 겹친 부분의 문자열을 v, lps에서 공통부분을 뺀 문자열을 u 라고 하자.  이때 len(v)가 len(u)의 배수가 된다면 s = u ^ ((v // u) + 2) 로 나타낼 수 있게된다. 이 경우에 답은 v // u + 2 가 된다.

 

만약 len(v) 가 len(u) 의 배수가 아니라면 문자열 s는 반복되는 문자열로 나타낼 수 없다.  s = s ^ 1 로 밖에 나타낼 수 없기 때문에 이 경우에 답은 1이 된다.

 

 

이 사실을 토대로 문제를 해결한 코드는 다음과 같다.

 

import sys

input = sys.stdin.read


def sol4354():
    # 각 케이스의 정답을 저장할 리스트
    answer = []
    
    # 케이스의 문자열 s
    for s in input().splitlines():
        # s가 '.'이라면 종료
        if s == '.':
            break
            
        # lps 테이블 전처리
        m = len(s)
        dp = [0]*m
        i = 0
        for j in range(1, m):
            while i > 0 and s[i] != s[j]:
                i = dp[i-1]
            if s[i] == s[j]:
                i += 1
                dp[j] = i
        
        # 공통부분 v
        v = dp[-1] * 2 - m
        
        # 공통부분의 길이가 0보다 작은 경우
        # lps * 2 < m 이기 때문에 답은 1
        if v < 0:
            answer.append(1)
            
        # 공통부분의 길이가 0인 경우
        # lps * 2 == m 이기 때문에 답은 2
        elif v == 0:
            answer.append(2)
            
        # 공통부분의 길이가 1 이상인 경우
        else:
            # lps에서 공통부분을 뺀 길이
            u = dp[-1] - v
            
            # v가 u의 배수가 아니라면 
            # 반복되는 문자열로 나타낼 수 없기 때문에 답은 1
            if v % u:
                answer.append(1)
                
            # v가 u의 배수라면 
            # 양 끝의 u와 중간의 v//u 개의 u를 이어붙인 형태이기 때문에
            # u ^ (v//u + 2) 로 나타낼 수 있음.
            # 답은 v//u + 2
            else:
                answer.append(v//u+2)
                
    # 정답 리스트를 출력형식에 맞춰 반환
    return '\n'.join(map(str, answer))

 

 

2. 광고(https://www.acmicpc.net/problem/1305)

 길이 L의 전광판에 길이 N의 광고문구를 무한히 이어붙인 문자열이 표시될 때,  광고문구가 될수있는 가장 짧은 문자열의 길이를 구하는 문제.  예를들어 길이가 6인 광고판에 aabaaa가 표시된 경우 원래 광고로 추정 가능한 가장 짧은 것은 'aaba', 'aaab', 'baaa', 'abaa' 등이다.

 

 이 문제 또한 KMP 알고리즘의 실패함수를 활용하여 해결할 수 있다. 전광판에 표시된 문자열이 광고를 무한히 이어붙인 형태라면 문자열 내에는 광고의 시작부분이 반복되는 구간이 존재한다.  이를 문자열의 lps라고 생각하면 나머지는 간단하다.  예시로 든 'aabaaa' 의 경우, lps 는 'aa' 가 된다.  그러면 'aabaaa' 에서 suffix에 해당하는 'aa' 는 광고가 이어붙여진 부분이라고 생각할 수 있다. 이렇게 생각할 경우 이 광고의 원래 형태는 'aaba' + 'aaba' + 'aaba' ... 가 된다.  그리고 이것이 가능한 광고의 가장 짧은 형태이다.   즉, 전광판의 길이 L과 문자열 s가 주어졌을 때,  가능한 광고 길이의 최솟값은 L - len(lps) 가 된다.  이 사실을 토대로 문제를 해결한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol1305():
    # 전광판의 크기
    L = int(input())
    
    # 전광판에 표시된 문자열
    adv = input().rstrip()
    
    # LPS 테이블 전처리
    lps = [0] * L
    i = 0
    for j in range(1, L):
        while i > 0 and adv[i] != adv[j]:
            i = lps[i-1]
        if adv[i] == adv[j]:
            i += 1
            lps[j] = i
           
    # L - len(lps) 반환
    return L-lps[-1]

 

 

3. 시계 사진들(https://www.acmicpc.net/problem/10266)

 동일한 길이의 바늘 n 개를 가진 두 시계의 각 바늘이 9시방향과 이루는 각도가 순서없이 주어졌을 때, 두 시계를 회전시켜 같은 시간을 가리키도록 할 수 있는지 구하는 문제.  각도의 단위는 1/1000도(degree)이다.

 

 문제해결의 가장 중요한 아이디어는 바늘을 건드리지않고 두 시계를 회전시켜 같은 시간을 가리키도록 할 수 있으려면 두 시계의 인접한 바늘들이 서로 같은 각도만큼 떨어져있어야 한다는 것이다. 문제를 해결한 절차는 다음과 같다.

 

① 두 시계의 시계바늘의 각도들을 오름차순으로 정렬한다.

 

② 각 시계의 인접한 바늘끼리의 각도의 차를 구한다.(A, B)  이 때, 마지막 각도와 첫 번째 각도의 차도 구해야함에

    유의해야 하며, 마지막 바늘에서 시작하여 첫 바늘로 가는 각도이기 때문에 An = D1 - (Dn - 360000) 와 같이

    계산해야한다.

 

③ 시계를 회전시키는 것으로 같은 시각을 가리키게 할 수 있으려면 인접한 바늘의 각도차 리스트 A, B중 한쪽을

    회전시켜가며 비교해야한다. 

 위 그림의 두 시계를 회전하여 같은 시각을 가리키도록 만들 수 있는지 확인하기 위해 인접한 바늘들의 각도차 리스트

A = [A1, A2, A3, A4, A5, A6] 와 B = [B1, B2, B3, B4, B5, B6] 를 비교해야한다. 그런데 A나 B 한쪽을 회전시켜가며 비교하는 것은 O(N^2) 의 시간복잡도를 가지기 때문에 N의 최대 크기가 20만인 이 문제를 1초안에 해결할 수 없다. 하지만 KMP 알고리즘을 활용하면 이 문제를 간단히 해결할 수 있다.

 

④ 리스트 A를 회전시킨 결과 B와 일치하는 경우가 존재한다면 A를 두번 이어붙인 리스트에 B와 일치하는 구간이

    존재한다고도 생각할 수 있다. 즉, 리스트 A를 두번 이어붙인 리스트에서 리스트 B가 등장하는지 KMP 알고리즘으로

    검사한다면 O(N)의 시간복잡도로 이 작업을 해결할 수 있다.

 

위 사실을 토대로 문제를 해결한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol10266():
    # 시계의 바늘 갯수
    n = int(input())
    
    # 두 시계의 바늘의 각도를 오름차순 정렬
    x = sorted(map(int, input().split()))
    y = sorted(map(int, input().split()))

    # 인접한 바늘끼리의 각도차 리스트를 구함
    a = ([x[(i + 1)] - x[i] for i in range(n - 1)] + [x[0] - (x[-1] - 360000)]) * 2
    b = [y[(i + 1)] - y[i] for i in range(n - 1)] + [y[0] - (y[-1] - 360000)]
    n, m = 2 * n, n

    # LPS 테이블 전처리
    lps = [0] * (m + 1)
    i = 0
    for j in range(1, m):
        while i > 0 and b[i] != b[j]:
            i = lps[i - 1]
        if b[i] == b[j]:
            i += 1
            lps[j] = i

    # KMP 알고리즘으로 문자열 탐색 시작
    i, j = 0, 0
    while i < n:
        if a[i] == b[j]:
            i += 1
            j += 1
            # 탐색에 성공했다면 같은 시각을 나타내도록 할 수 있음
            # 'possibe' 반환
            if j == m:
                return 'possible'
        else:
            if j == 0:
                i += 1
            j = lps[j - 1]
    
    # 모든 작업을 마칠동안 탐색에 성공하지 못했다면
    # 같은시각을 나타내도록 할 수 없음 
    # 'impossible' 반환
    return 'impossible'

 

KMP 알고리즘을 통한 문자열 검색 자체도 활용도가 높지만 실패함수도 활용도가 상당히 높다는걸 알 수 있었다.

필요하다면 언제든지 구현할 수 있도록 기억해두는 것이 좋겠다.

 문서 편집기나 뷰어 등에서 탐색 기능을 사용하면 문서내의 모든 문자열 속에서 찾으려는 문자열의 갯수와 위치를 보여준다.  전체 문자열의 길이가 N, 찾으려는 문자열의 길이가 M 일때, 단순히 시작위치를 0부터 1씩 옮겨가며 M개의 문자를 비교한다면 시간복잡도는 (N-M+1) * M, 즉 O(NM) 이 된다.   하지만 사실 이보다 더 빠르게 문자열을 탐색할 수 있는 방법이 있다.  이번에는 문자열 속에서 다른 문자열 P의 갯수와 위치를 찾는 KMP 알고리즘을 알아본다.

 

1. KMP 알고리즘의 개요

 KMP 알고리즘의 핵심은 쓸모없는 중간과정을 건너뛰는 것으로 연산의 횟수를 획기적으로 줄이는 것이다.

예를들어 T = 'ABCDABCDABDE' 이고 P = 'ABCDABD일 때, T에 P가 나타나는 횟수와 그 위치들을 찾아보자.

 

먼저 시작 위치를 0 번째 위치로 하고 T와 P를 순차적으로 비교하면 다음과 같은 결과를 얻을 수 있다.

T A B C D A B C D A B D E
P A B C D A B D          
  O O O O O O X          

 5 번째 문자까지는 일치했지만 6 번째 문자에서 일치하지 않은 것을 확인할 수 있다. 단순한 방법이라면 여기서 탐색 위치를 0으로 옮겨 또다시 M번의 탐색을 해야한다.  하지만 우리는 사실 이 결과에서 6 번째 문자가 일치하지 않았다는 점 보다도 5 번째 문자까지는 일치했다는 점을 활용해야한다

 

 일치한 부분까지의 문자열은 ABCDAB 이다.  여기서 양 끝에서 반복적으로 나타나는 가장 긴 문자열 AB에 주목하자. 이러한 문자열을 문자열의 접두사/접미사(prefix/suffix)라 한다. AB는 문자열 P의 시작부분이며 동시에 T와 일치했던 부분의 끝부분이기도 하다.  이 구간에서 AB가 다시한번 나타나는 4번째 위치 이전까지는 시작위치로 잡고 탐색하더라도 일치할 가능성이 없다.  즉, 다음 탐색의 시작위치를 곧바로 AB가 다시한번 나타나는 구간으로 건너뛸 수 있다는 것이다.  건너뛴 다음의 모습은 다음과 같다.

 

T A B C D A B C D A B D E
P         A B C D A B D  
          O O O O O O O  

물론 여기서도 P[0] 부터 다시 비교할이유는 없다.  우린 이미 AB까지는 일치함을 알고있기 때문이다. 즉 다음에 탐색해야할 P의 요소는 P[2] 가 되며 T의 요소는 이전에 불일치가 발생한 부분인 T[6]이 된다.

 

이를 일반화한 절차는 다음과 같다.

 

① 길이가 각각 N, M 인 문자열 T와 P에 대해 탐색을 진행, 리스트 L에 문자열 P가 등장하는 모든 위치를 담는다.

 

② T와 P의 초기 인덱스 i, j는 각각 0으로 초기화

 

③ T[i]와 P[j] 가 같다면 두 인덱스를 각각 +1 

    => 만약 그 결과 j가 M이 됐다면(모든 문자가 일치했다면) 찾아낸 구간의 시작위치(i - M)를 리스트 L에 append

    => j를 P[0~j-1] 구간의 접두사/접미사의 길이로 변경. 반복되는 부분을 생략하고

        그 다음 요소부터 검사하기 위함

 

④ T[i]와 P[j] 가 다르다면

    => 만약 j가 0이라면(P문자열의 시작부터 일치하지 않았다면) T[i] 로 시작하는 문자열과는 일치할 수 없기 때문에

         i += 1

    => j를 P[0~j-1] 구간의 양 끝의 가장 긴 반복문자열(위 예시 기준 AB) 의 길이로 설정. 반복되는 부분을 생략하고

        그 다음 요소부터 검사하기 위함

 

⑤ i의 인덱스가 N 이상일 경우 더이상 탐색할 수 없기 때문에 작업 종료

 

⑥ 리스트 L에 담긴 숫자들이 문자열 P가 등장한 모든 인덱스를 나타내며 L의 길이가 P가 나타난 횟수가 된다.

 

 

2. LPS(Longest Proper prefix and Suffix) 테이블

 KMP 알고리즘의 각 단계는 기본적으로 i가 움직이거나 문자열 P의 비교 시작 위치가 움직이기 때문에 O(N+M)의 시간복잡도를 보인다.  다만 여기에는 하나의 전제가 필요하다.  바로 j를 P[1~j-1] 의 접두사/접미사 길이로 변경해주는 작업이 O(1)이어야 한다는 것이다.  물론 정말로 O(1)로 이 작업을 수행할 수는 없지만, 알고리즘을 수행하는 반복문 밖에서 O(M)의 시간복잡도로 LPS테이블을 구해두는 전처리를 하는 것으로 반복문 내에서는 O(1)의 시간복잡도를 보일 수 있다. LPS 테이블을 구하는 절차는 다음과 같다.

 

① 먼저 dp = [0] * (m+1) 으로  LPS 테이블 초기화

② i를 P 문자열의 첫 위치 0으로 초기화

③ P문자열의 j 번째 요소(j=1~m-1) 비교 (④~⑤ 반복)

④ P[i] == P[j] 라면 i를 1 증가시키고 dp[j] = i  (반복구간의 길이가 증가)

⑤ P[i] != P[j] 일 경우

    => 1) i가 0이라면 애초에 반복구간이 아직 존재하지 않기 때문에 단순히 pass

    => 2) i가 0보다 크다면 i를 0이되거나 P[i] == P[j] 가 될 때까지 이전에 P[j]와 일치했던 구간으로 이동 (i = dp[i-1])

⑥ 작업이 모두 끝나면 dp[k] 는 P[1~k] 구간의 LPS값이 된다. 

 

 

3. 구현

 지금까지의 내용을 토대로 실제로 KMP 알고리즘으로 문자열 탐색 문제(https://www.acmicpc.net/problem/1786)를 해결한 코드는 다음과 같다.  

 

import sys

input = sys.stdin.readline


def sol1786():
    # 전체 문자열 T
    T = input().replace('\n', '')
    
    # 찾을 문자열 P
    P = input().replace('\n', '')
    
    # 문자열 T, P의 길이 n, m
    n, m = len(T), len(P)

    # LPS 테이블
    dp = [0] * (m + 1)
    
    # P의 첫 위치이자 suffix의 길이
    i = 0
    
    # dp[j] 는 P[:j+1]의 LPS값
    for j in range(1, m):
        # 만약 P[i]와 P[j] 가 다르고 0보다 크다면
        # P[i]==P[j]가 될 때까지 i를 이전에 suffix가 이어졌던 부분까지 이동
        while i > 0 and P[i] != P[j]:
            i = dp[i-1]
            
        # P[i]와 P[j]가 같다면 suffix의 길이를 1 증가
        # dp[j] 에 suffix의 길이를 저장
        if P[i] == P[j]:
            i += 1
            dp[j] = i

    # 문자열 P가 발견된 위치를 모두 저장할 리스트
    answer = []
    
    # 문자열 T, P의 탐색 위치
    i, j = 0, 0
    
    # i가 문자열 T의 길이에 도달하기 전 까지 반복
    while i < n:
        # T[i]와 P[j] 가 같다면 두 인덱스 모두 1씩 증가
        if T[i] == P[j]:
            i += 1
            j += 1
            
            # 인덱스가 증가한 결과 j가 m에 도달했다면
            # P와 일치하는 문자열을 발견한 것이기 때문에 answer에 append
            # j는 LPS 테이블을 참조하여 다음 탐색 위치(dp[j-1])로 이동
            # 다음 탐색에서는 이미 일치할 것을 알고있는 부분을 생략하고 탐색
            if j == m:
                answer.append(i-m+1)
                j = dp[j-1]
                
        # T[i] 와 P[j] 가 다르다면
        else:
            # 만약 P[0] 부터 일치하지 않았다면 T[i]와의 비교는 무의미하기 때문에 i += 1
            if j == 0:
                i += 1
                
            # j는 LPS 테이블을 참조하여 다음 탐색 위치로 이동
            j = dp[j - 1]
            
    # 문자열 P가 등장한 횟수와 위치를 출력형식에 맞춰 반환
    return '\n'.join(map(str, [len(answer), ' '.join(map(str, answer))]))

 

 

※ LPS 테이블 주의사항

 LPS 테이블을 생성할 때, 처음에는 P[i] != P[j] 일 경우 i를 0으로 초기화후 다시 비교하는 방식을 사용했지만 오답이 나왔다. 이 방식의 문제는 aacaaa 와 같은 문자열에 대해 LPS 테이블을 생성할 때 발생한다. 위 코드처럼 정상적인 방식으로 LPS 테이블을 생성할 경우 0 1 0 1 2 2 가 된다.  하지만 P[i] != P[j] 일 때 i를 0으로 초기화하고 다시 비교할 경우, aacaa 까지는 0 1 0 1 2 로 동일하지만, i가 2인 상태에서 P[2] != P[5] 로 i가 0으로 초기화되어 다시 비교하면 P[0] == P[5]로 dp[5]가 1이 되어버린다.  잘못 구현한 LPS 테이블을 구하는 코드는 다음과 같다.

dp = [0] * (m + 1)
i = 0
for j in range(1, m):
    if P[i] == P[j]:
        dp[j] = dp[j-1] + 1
        i += 1
    else:
        if P[j] == P[0]:
            dp[j] = i = 1
        else:
            i = 0

 

 동적 계획법을 적용하여 경우의 수를 구하는 문제를 알아본다.

 

1. RGB거리 2(https://www.acmicpc.net/problem/17404)

 n 개의 집이 직선상에 번호순서대로 있고 각 집을 빨강, 초록, 파랑의 색으로 칠하기 위해 필요한 비용이 주어졌을 때, 인접한 집 끼리는 색이 같지 않도록 칠하는 경우의 수를 구하는 문제.  단, 1번 집과 n번 집의 색도 서로 달라야한다.

 

동적계획법을 적용한 완전탐색 문제이다. 집을 순차적으로 방문하여 이전 집과 겹치지 않는 색상 중 어느 색상으로 칠해야 비용을 최소화 할 수 있는지 재귀적으로 구해나가면 간단하게 해결할 수 있다. 주의할 점은 마지막 집이 첫 번째 집과도 색이 겹쳐선 안된다는 조건이 있기에 첫 번째 집의 색상을 저장해두고 마지막 집의 색상을 정할 때 반영해야한다는 것이다.  문제를 해결한 코드는 다음과 같다.

 

import sys

sys.setrecursionlimit(100000)
input = sys.stdin.readline
INF = 10**9


def sol17404():
    # 집의 갯수
    n = int(input())
    
    # 각 집을 빨강, 초록, 파랑으로 칠하기 위한 비용
    cost = [list(map(int, input().split())) for _ in range(n)]
    
    # dp[i][j] 는 i 번째 집의 색이 j 일 때 최소 누적비용
    dp = [[0] * 3 for _ in range(n)]
    
    # 완전탐색 함수 - 이전집의 색과 현재 집의 번호를 인자로 받는다
    def dfs(pre, cur):
        # 첫 번째 집의 색
        nonlocal f
        
        # 마지막 집까지 모두 칠했다면 더이상 비용은 들지않는다.
        if cur == n:
            return 0
            
        # 현재 집을 이전 집과(마지막 집의 경우 첫 번째 집과도) 색이 겹치지 않도록 칠할 때
        # 그 비용을 최소화시킬 수 있는 경우를 구한다
        res = INF
        for c in range(3):
            if (c == pre) or (cur == n-1 and c == f):
                continue
            if not dp[cur][c]:
                dp[cur][c] = dfs(c, cur+1) + cost[cur][c]
            res = min(res, dp[cur][c])
        
        # 현재 집의 색에 따른 총 비용 중 최솟값
        return res
    
    # 첫 번째 집의 색에 따라 세 번 탐색하여 최소비용을 구함
    answer = INF
    for f in range(3):
        dp = [[0] * 3 for _ in range(n)]
        answer = min(answer, dfs(f, 1)+cost[0][f])
        
    # 최소비용 반환
    return answer

 

 

2. 색상환(https://www.acmicpc.net/problem/2482)

 n 개의 색상이 있는 색상환에서 k개의 색을 서로 인접하지 않도록 고르는 경우의 수를 구하는 문제.  동적계획법을 활용한 풀이와 조합을 활용한 풀이 두 가지 방법이 있다. 

 

1) 동적 계획법을 활용한 풀이

 양 끝이 인접해있는 원형임을 일단 생각하지 않고 선형의 리스트에서 인접하지 않도록 k개를 뽑는 경우를 생각해보자.

dp[i][j] 가 i 번째 까지의 색상리스트에서 j 개의 색을 인접하지 않도록 뽑는 경우의 수라고 할 때, 다음과 같은 점화식이 성립한다.

dp[i][j] = dp[i-1][j] + dp[i-2][j-1]

위와 같은 점화식이 성립하는 이유는 현재 보고있는 i 번째 색을 선택할지, 선택하지 않을지 여부에 따라 두 가지 경우의 수로 나뉘기 때문이다.   

 

① i 번째 색을 선택할 경우 이전의 색은 선택될 수 없다. 또한 이번 색을 선택함으로서 j개의 선택이 완료되어야

   하기 때문에 그 전까지 j-1 개의 선택이 완료된 상태여야 한다. (dp[i-2][j-1])

 

② i 번째 색을 선택하지 않을 경우 이전의 색도 선택될 수 있으며, 이번에 색을 선택하지 않기 때문에 그 전까지

   이미 j 개의 선택이 모두 완료된 상태여야 한다. (dp[i-1][j-2])

 

 

이제 첫 번째 색과 마지막 색이 인접해서는 안된다는 조건을 처리해야한다.

 

① 마지막 색을 선택할 경우, 자기 자신과 좌우로 두 개의 인접한 색을 제외한 범위에서

    k-1개의 색을 선택한 상태여야 한다. (dp[n-3][k-1])

 

② 마지막 색을 선택하지 않을 경우, 그 전까지 j-1 개의 선택을 완료한 상태여야 한다. (dp[n-1][k])

 

이를 구현한 코드는 다음과 같다.

 

import sys

input = sys.stdin.read
mod = 1000000003


def sol2482():
    # 색상의 갯수와 골라야할 색상의 수
    n, k = map(int, input().split())
    
    # dp[i][j] 는 i번까지의 색상리스트중 j개를 고르는 경우의 수
    dp = [[0] * (k + 1) for _ in range(n + 1)]
    
    # 0개를 고르는 경우는 항상 1이며 1개를 고르는 경우는 항상 색상의 갯수와 같다.
    for i in range(n + 1):
        dp[i][0] = 1
        dp[i][1] = i
	
    # 점화식에 따라 dp를 채워나간다
    for i in range(2, n + 1):
        for j in range(2, k + 1):
            dp[i][j] = (dp[i - 2][j - 1] + dp[i - 1][j]) % mod

    # 양 끝이 인접할 경우를 고려하여 경우의 수를 구하여 반환한다
    return (dp[n - 3][k - 1] + dp[n - 1][k]) % mod

 

2) 조합을 활용한 풀이

 n 개의 색상에서 인접하지 않도록 색상을 최대한 많이 고르는 경우는 한칸씩 띄어가며 선택하는 것이다.  즉, 고르려는 색상의 갯수 k 가 n의 절반을 넘을 경우 인접하지 않도록 색상을 고르는 것은 불가능하다. 

if k > n//2:
    return 0

k가 n의 절반 이하라면 경우의 수는 한 개이상 존재한다. 이 때, 그 경우의 수를 다음과 같이 생각해볼 수 있다.

n개의 색상칸에 선택할 색상 k 개를 제외한 선택하지 않을 n-k 개의 색을 먼저 나열해놨다고 할 때(녹색칸), 위 그림과 같이 선택하지 않은 칸을 사이에 두고 슬롯이 n-k+1 개 생긴다.(하얀칸)  선택받은 k개의 색상을 슬롯에 넣는 경우의 수를 생각하면 combination(n-k+1, k) 가 된다.  그런데 이 표는 편의상 리스트형태로 표현했을 뿐 환형이기 때문에 만약 k개의 슬롯을 선택했을 때 양 끝이 모두 선택된다면 양 끝이 맞닿는 문제가 발생한다. 그 경우의 수는 양 끝의 슬롯을 이미 선택했다고 가정하고 남은 n-k-1 개의 슬롯에서 k-2 개의 슬롯을 고르는 경우의 수, 즉 combination(n-k-1, k-2) 가 된다.

둘의 차를 구하면 k개의 색상이 서로 인접하지 않도록 선택하는 모든 경우의 수를 구할 수 있다. 이를 구현한 코드는 다음과 같다.

 

import sys

input = sys.stdin.read
mod = 1000000003


def sol2482():
    # 색상의 총 수 n, 골라야할 색상의 수 k
    n, k = map(int, input().split())
    
    # 골라야할 색상의 수가 총 색상 수의 절반을 넘기면 인접하지 않도록 고르는것은 불가능
    if k > n//2:
        return 0
        
    # 1개의 색을 고르는 경우의 수는 총 색상 수와 같음
    if k == 1:
        return n
        
    # n-k개의 고르지 않은 수의 사이에 존재하는 슬롯 n-k+1 개에서 k개의 슬롯을 골라 수를 삽입
    # 양 끝 슬롯을 함께고르는 경우를 배제하기 위해 양 끝 슬롯을 제외한 n-k-1 개의 슬롯에서 
    # k-2개의 슬롯을 고르는 경우의 수를 감산
    return (bc(n-k+1, k) - bc(n-k-1, k-2)) % mod
    
  
# 이항계수를 반환하는 함수
def bc(a, b):
    b = min(b, a-b)
    u, v = 1, 1
    for i in range(b):
        u *= (a - i)
        v *= (1 + i)
    return u//v

이러한 풀이는 실전에서 떠올릴 수 있다면 좋겠지만 너무 몰입하다간 문제해결에 과하게 시간을 낭비할 수 있기 때문에 조금 시도해보다가 답이 보이지 않는다면 과감하게 포기하고 다른 해결책을 모색해보는 것도 좋을 것 같다.

 

+ Recent posts