여러개의 정점이 주어졌을 때, 간선들의 가중치의 합이 최소가 되도록 하면서 모든 정점을 연결한 형태를

최소 신장 트리(Minimum Spanning Tree)라고 한다.  가중치는 실제 문제에서 보통 거리, 비용 등의 형태로 주어진다. 

이번에는 최소 신장 트리를 생성하는 대표적인 두 가지 알고리즘을 알아본다.

 

1. Kruskal's Algorithm

 Kruskal's Algorithm은 n개의 정점이 주어졌을 때 n-1개의 간선을 선택하는 방식으로 MST를 구한다.

절차는 아래와 같다.

 

① 모든 간선을 가중치 기준 오름차순 정렬한다.

② 가장 가중치가 작은 간선부터 선택하여 트리를 만들어나간다.

③ 만약 간선이 이미 연결된 정점 간의 간선이라면 필요가 없기 때문에 넘어간다.

④ ②,③을 간선이 n-1개가 되거나 더이상 간선이 존재하지 않을 때 까지 계속한다.

⑤ n-1개의 간선을 선택하는 데 성공했다면 그 간선들이 MST를 이루는 간선들이 된다.

⑥ 그렇지 않다면 이 그래프에서는 MST를 만드는 것이 불가능하다.

 

이 과정에서 이미 연결된 정점간의 간선인지의 판단은 기존에 알아본 바 있는 Disjoint Set 알고리즘이 활용된다.  이를 구현하여 최소신장트리 문제를 해결한 코드는 다음과 같다.

import sys

input = sys.stdin.readline


def sol1197():
    v, e = map(int, input().split())
    edges = [list(map(int, input().split())) for _ in range(e)]
    # 1
    edges.sort(key = lambda x:x[2])
    u = [-1] * (v+1)
    cost = 0
    # 4
    for a, b, c in edges:
    	# 2, 3
        if union(u, a, b):
            cost += c
    # 5
    return cost

   
def union(u, a, b):
    a = find(u, a)
    b = find(u, b)
    if a == b:
        return False
    if u[a] < u[b]:
        u[a] += u[b]
        u[b] = a
    else:
        u[b] += u[a]
        u[a] = b
    return True
   

def find(u, x):
    if u[x] < 0:
        return x
    u[x] = find(u, u[x])
    return u[x]
    

if __name__ == '__main__':
    print(sol1197())

 Kruskal's Algorithm은 모든 간선을 정렬하는 것이 필수이기 때문에 정점의 갯수에 비해 간선의 갯수가 적은 케이스에 사용하기 좋다.

 

 

2. Prim's Algorithm

 Prim's Algorithm은 n개의 정점이 주어졌을 때 가장 가까운 정점부터 하나씩 추가해나가는 방식으로 최소신장트리를 구한다. 절차는 아래와 같다.

 

① 정점 하나를 최소 신장 트리의 첫 정점으로 잡고 트리로부터 각 정점으로의 거리를 구한다.(초기값 INF)

② 트리에 포함되지 않은 모든 정점 중 가장 트리에 가까운 것을 선택하여 트리에 포함시키고 가중치를 더한다.

③ 추가된 정점으로부터 다른 정점까지의 거리로 트리로부터 각 정점으로의 거리를 갱신한다.

④ ②, ③을 n-1 번 반복한다.

⑤ 가중치의 합이 INF가 아니라면 최소 신장 트리가 만들어진다.  그렇지 않다면 최소 신장 트리를 만들 수 없다.

 

이를 구현하여 최소 신장 트리 문제(https://www.acmicpc.net/problem/20390)를 해결한 코드는 다음과 같다.

import sys


def sol20390():
    input = sys.stdin.read
    n, a, b, c, d, *vertex = map(int, input().split())
    INF = 2000000000000

    include = [False] * n
    inq = [INF] * n
    q = [(0, 0)]
    include[0] = True
    inq[0] = 0
    res = 0
    # 1
    for i in range(1, n):
        inq[i] = ((vertex[0] * a + vertex[i] * b) % c) ^ d
    # 4
    for i in range(n - 1):
        v, dist = 0, INF
        # 2
        for j in range(n):
            if not include[j] and inq[j] < dist:
                v, dist = j, inq[j]
        include[v] = True
        res += dist
        
        # 3
        for j in range(n):
            if not include[j]:
                dst = ((vertex[v] * a + vertex[j] * b) % c) ^ d if v < j else ((vertex[j] * a + vertex[v] * b) % c) ^ d
                inq[j] = min(inq[j], dst)
    # 5
    print(res)


if __name__ == '__main__':
    sol20390()

 2, 3 과정은 heapq를 사용하여 더 간단하게 만들 수도 있지만 이 문제는 메모리 제한이 16mb이기에 정석대로 트리로부터의 최단거리 배열과 반복문을 사용하여 해결하였다. Python3으로는 시간이 부족하기에 PyPy3으로 제출하였다.

'코딩테스트 > 알고리즘' 카테고리의 다른 글

#10 정수론 - 나머지의 성질  (0) 2021.08.10
#9 정수론 - 최대공약수와 최소공배수  (0) 2021.08.09
#7 투 포인터(Two Pointers)  (0) 2021.08.02
#6 LCS 알고리즘  (0) 2021.07.25
#5 Disjoint Set 알고리즘  (0) 2021.07.24

+ Recent posts