코딩테스트/알고리즘

#37 LCA(Lowest Common Ancestor) 알고리즘

Scala0114 2021. 9. 18. 16:03

 트리에서 두 노드의 가장 가까운 공통 조상을 보다 효율적으로 구하기 위한 LCA 알고리즘에 대해 관련 문제를 해결하며 알아본다. 

 

1. 가장 가까운 공통 조상(https://www.acmicpc.net/problem/3584)

 가장 기본적인 LCA 문제이다.  노드의 수가 최대 10,000이며 쿼리의 수도 하나 뿐이기 때문에 효율적인 알고리즘을 생각할 필요 없이 간선 정보를 토대로 각 노드의 부모노드를 저장하고 쿼리로 주어진 두 노드로부터 부모노드를 탐색해나가며 각 노드의 모든 조상노드 리스트를 구하고, 두 리스트를 뒤에서부터 비교해나갔을 때 마지막으로 같았던 두 리스트의 값이 가장 가까운 공통 조상 노드가 된다.  이를 구현한 코드는 다음과 같다.

import sys

input = sys.stdin.readline


# 3584 가장 가까운 공통 조상
# 노드의 수가 적고 쿼리도 케이스당 하나 뿐이기 때문에
# 단순한 방식으로도 해결 가능
def sol3584():
    # 케이스별 정답 리스트
    answer = []
    for t in range(int(input())):
        # 노드의 수
        n = int(input())

        # 각 노드의 부모노드
        parent = [0] * (n + 1)

        # 부모노드와 자식노드를 입력받아 부모노드를 채운다
        for _ in range(n - 1):
            a, b = map(int, input().split())
            parent[b] = a

        # 가장 가까운 공통조상을 구할 두 노드
        u, v = map(int, input().split())

        # 두 노드의 조상 노드 리스트를 구한다.
        up, vp = [], []
        while u:
            up.append(u)
            u = parent[u]
        while v:
            vp.append(v)
            v = parent[v]

        # 두 노드의 조상 노드 리스트를 뒤집는다.
        up, vp = up[::-1], vp[::-1]

        # 루트 노드부터 시작하여 두 노드의 조상 노드가 서로 달라질 때 까지 탐색
        # 두 노드의 조상 노드가 마지막으로 같았을 때의 조상 노드가 가장 가까운 공통 조상 노드가 된다.
        res = 0
        for i in range(min(len(up), len(vp))):
            if up[i] == vp[i]:
                res = up[i]
            else:
                break

        # 정답 리스트에 결과 삽입
        answer.append(res)

    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

 

하지만 위의 방법은 당연하게도 노드의 수와 쿼리의 수가 많아지면 쿼리마다 사용할 수 없다.

 

 

이번엔 다른 방식으로 문제를 해결해보자.  절차는 다음과 같다.

 

① 먼저 트리를 탐색하여 모든 노드의 깊이와 부모노드를 구한다.

② 주어진 두 노드의 깊이가 다르다면 보다 깊이 있는 노드가 다른 노드와 같은 높이가 될 때 까지 거슬러 올라온다.

③ 두 노드가 서로 같아질 때 까지 두 노드 모두 부모노드를 타고 거슬러 올라온다.

④ 두 노드가 서로 같아졌을 때, 그 노드가 가장 가까운 조상 노드가 된다.

 

이를 구현한 코드는 다음과 같다.

 

import sys

input = sys.stdin.readline


def sol3584():
    # 정답 리스트
    answer = []
    
    # 테스트케이스
    for t in range(int(input())):
        # 노드의 수
        n = int(input())
        
        # 그래프 생성, 부모 노드 구하기
        g = [[] for _ in range(n + 1)]
        parent = [0] * (n + 1)
        for _ in range(n-1):
            a, b = map(int, input().split())
            g[a].append(b)
            parent[b] = a
        
        # 루트 노드로부터 bfs탐색으로 각 노드의 깊이를 구한다
        depth = [0] * (n + 1)
        root = [i for i in range(1, n+1) if not parent[i]][0]
        
        q = [root]
        while q:
            nq = []
            for cur in q:
                for c in g[cur]:
                    depth[c] = depth[cur] + 1
                    nq.append(c)
            q = nq
            
        # 주어진 두 노드 u, v
        u, v = map(int, input().split())
        
        # 노드 u의 깊이가 노드 v의 깊이보다 크도록 한다.
        if depth[u] < depth[v]:
            u, v = v, u
            
        # 두 노드의 깊이차가 0이 될 때 까지 깊은 쪽 노드인 u를 끌어올린다
        while depth[u] - depth[v]:
            u = parent[u]
           
        # 두 노드가 서로 같아질 때 까지 두 노드를 끌어올린다
        while u != v:
            u, v = parent[u], parent[v]
          
        # u와 v가 서로 같아지면 u == v == 가장 가까운 공통 조상 노드 가 된다.
        answer.append(u)
       
    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

이 방식 자체로는 보다 많은 노드와 쿼리 수가 주어지는 문제를 해결하기에 딱히 이전의 알고리즘보다 효율적이지 않다.  하지만 이 방식에 sparse table을 적용하면 얘기가 달라진다.  다음 문제에서 이 sparse table에 대해 다룬다.

 

2. 합성함수와 쿼리(https://www.acmicpc.net/problem/17435)

 함수 f(x)의 값들이 주어졌을때, fn(x) 가 f(f(f(...f(x)...))) 와 같이 함수 f를 n중첩한 결과라고 한다.  이 때, 주어진 m개의

n, x 쌍에 대해 fn(x) 를 각각 구하는 문제이다.  LCA 문제는 아니지만 LCA 알고리즘의 핵심인 sparse table을 응용해서

풀 수 있는 문제이다. 절차는 다음과 같다.

 

① 먼저 함수 f의 값을 저장할 2차원 배열을 생성한다.  이 때, f[x][k] 는 함수 f의 k중첩에 x를 넣은 결과이다.  예를들어,

    f[3][4] = f(f(f(f(3)))) 이 된다.

 

② f[x][0] 의 값을 입력받은 f(x)의 값으로 초기화한다.

 

③ f[i][j] 은 f[i][j-1] 을 2^(j-1) 중첩한 값이 되기 때문에, 점화식 f[i][j] = f[f[i][j-1]][j-1] 이 성립한다. 이 점화식에 따라 

    나머지 f[x][1~k] 를 채운다.  여기서 k는 함수 중첩수 최댓값인 n에 로그를 취한 뒤 올림한 값이 된다. 

    => math.ceil(math.log2(n))

 

④ 이제 쿼리로 n, x가 주어지면 ③ 에서 구한 sparse table을 이용하여 빠르게 목표한 값으로 접근한다. 

    n이 0보다 큰 동안, x = f[x][int(log2(n))];  n -= (1 << int(log2(n))) 을 수행하면 가능하다.  n 중첩을

    1씩 좁혀나가는 것이 아니라 2의 승수단위로 좁혀나가는 것으로 탐색시간을 극적으로 줄이는 방식이다.

 

⑤ ④의 작업이 모두 끝났을 때, x의 값이 해당 쿼리에 대한 정답이 된다.

import sys
from math import log2, ceil

input = sys.stdin.readline



def sol17435():
    # 주어진 함수값의 수
    m = int(input())
    
    # 500000 <= 2^k
    k = ceil(log2(500000))
    
    # sparse table :  f[x][k] = fk(x)
    f = [[-1] * k for _ in range(m + 1)]
    fn = list(map(int, input().split()))
    for i in range(m):
        f[i + 1][0] = fn[i]

    # 점화식 f[i][j] = f[f[i][j-1]][j-1] 에 따라 sparse table 나머지 채우기
    for j in range(1, k):
        for i in range(1, m + 1):
            f[i][j] = f[f[i][j - 1]][j - 1]

    # 쿼리의 수
    q = int(input())
    
    # 각 쿼리의 정답 구하기
    answer = []
    for _ in range(q):
        # 함수 중첩수 n, 인자값 x
        n, x = map(int, input().split())
        
        # 중첩수가 0이 되면 종료
        while n:
            # f[x][d] == f2^d(x)
            d = int(log2(n))
            x = f[x][d]
            
            # 2^d 중첩만큼 처리했기 때문에 남은 중첩수에서 감산
            n -= (1 << d)
            
        # 정답리스트에 쿼리의 정답 저장
        answer.append(x)
       
    # 출력 형식에 맞춰 정답 리스트 반환
    return '\n'.join(map(str, answer))

 sparse table은 말그대로 어떤 규칙에 따른 값을 지수단위로 듬성듬성하게 구해두고 원하는 값까지 로그시간으로

빠르게 접근하기 위한 자료구조이다.  이 문제의 경우 sparse table 은 구성하는데에 O(MlogN)의 시간복잡도를,

쿼리 하나를 처리하는데 O(logN) 의 시간복잡도를 보인다.(경우에 따라서는 O(1) 수준의 성능을 보일때도 있다고 한다.)

sparse table은 정적인 데이터에서의 쿼리만을 처리할 수 있기 때문에 동적 데이터(데이터가 삽입, 삭제, 변경되는 경우)

에는 대응이 불가능하다.  만약 정적인 데이터에 대해 대량의 쿼리를 처리해야하는 경우라면 sparse table을 사용하는

것을 고려해볼만 할 것 같다.   

 

 

3. LCA 2 (https://www.acmicpc.net/problem/11438)

 이제 1번 문제에서 두 번째로 사용했던 알고리즘에 sparse table을 적용할 차례다.  노드의 수가 최대 100,000개로

늘어나고 쿼리의 수도 최대 100,000개이기 때문에 O(NM)의 시간복잡도를 보이는 기존의 방식으로는 해결할 수 없다.

두 노드의 깊이를 동일하게 맞추는 작업과 동일한 깊이에 있는 두 노드의 가장 가까운 조상 노드를 찾아가는 작업을 

sparse table을 사용하여 O(logN) 수준으로 최적화한다.  이를 구현한 코드는 다음과 같다.

 

import sys
import math

input = sys.stdin.readline


# sparse table(희소 테이블)을 활용하여 문제를 더 효율적으로 해결할 수 있다. 절차는 다음과 같다.
# 1. 각 노드의 깊이를 나타낼 depth 리스트와 부모노드를 나타낼 parent 리스트를 생성한다
#    이 때, parent 리스트는 2차원 리스트이며 parent[i][k] 는 노드 i의 2^k 번 거슬러올라간 조상노드를 의미한다.
#
# 2. dfs 혹은 bfs 로 트리를 순회하며 각 노드의 1번 거슬러올라간 조상노드(parent[i][0])와 깊이를 구한다.
#    이 때, 점화식 parent[i][j] = parent[parent[i][j-1]][j-1] 를 사용해서 O(NK) 의 시간 복잡도로 작업을 수행한다.
#    K는 log(2, N) 을 올림한 값과 같다.
#
# 3. 각 질의(LCA 를 구할 두 노드 u, v) 에 대해 두 노드의 깊이가 같아질 때 까지
#    보다 깊이가 큰 노드를 끌어올린다.
#
# 4. 두 노드가 서로 같다면 그대로 어느 한쪽 노드를 출력
#
# 5. 그렇지 않다면 두 노드의 가장 먼 공통 조상으로부터 처음으로 두 노드의 조상 노드가 달라질 때 까지 탐색
#    탐색 종료 후 노드 u 또는 v의 부모노드를 출력
def sol11438():
    # 각 노드의 깊이와 첫 번째 부모를 구하기 위한 탐색 함수
    def bfs(root):
        q = [root]
        while q:
            nq = []
            for cur in q:
                for c in g[cur]:
                    if depth[c] < 0:
                        depth[c] = depth[cur] + 1
                        parent[c][0] = cur
                        nq.append(c)
            q = nq

    # 노드의 수
    n = int(input())

    # sparse table 의 크기
    k = math.ceil(math.log2(n))

    # 그래프 생성
    g = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b = map(int, input().split())
        g[a].append(b)
        g[b].append(a)

    # 각 노드의 깊이 리스트
    depth = [-1] * (n + 1)

    # 각 노드의 조상 노드 리스트 (sparse table)
    parent = [[-1] * k for _ in range(n + 1)]

    # 루트 노드인 1의 깊이를 0으로 하고 탐색 시작
    depth[1] = 0
    bfs(1)

    # sparse table 채우기
    for j in range(1, k):
        for i in range(1, n + 1):
            parent[i][j] = parent[parent[i][j - 1]][j - 1]

    # 각 질의에 대해 정답 구하기
    answer = []
    for _ in range(int(input())):
        u, v = map(int, input().split())

        # 노드 u가 노드 v 보다 깊이가 크도록 한다
        if depth[u] < depth[v]:
            u, v = v, u

        # 두 노드의 깊이가 같아질 때 까지 노드 u를 끌어올린다
        while depth[u] - depth[v]:
            u = parent[u][int(math.log2(depth[u] - depth[v]))]

        # 두 노드가 같지 않다면
        if u != v:
            # 가장 먼 공통 조상 노드로부터 처음으로 두 노드의 조상 노드가 달라질 때 까지 탐색
            for j in range(math.ceil(math.log2(depth[u])), -1, -1):
                if parent[u][j] != parent[v][j]:
                    u = parent[u][j]
                    v = parent[v][j]

            # 현재 u와 v는 공통 조상노드의 자식 노드인 상태이므로 둘 중 하나는 부모노드를 방문
            u = parent[u][0]

        # 노드 u 혹은 v를 반환
        answer.append(u)

    # 출력 형식에 맞춰 정답리스트 반환
    return '\n'.join(map(str, answer))