Disjoint Set 은 서로 중복되지 않는 부분집합들을 저장, 관리하기 위한 자료구조이다.
간단한 예시로 A = {1, 2, 3, 4, 5} 에 대해서 B = {1, 3}, C = {2, 4, 5} 라면 B와 C가 A의 disjoint set이 된다. 이 disjoint set을 활용하여 문제를 해결하는 방식이 이번에 알아볼 disjoint set 알고리즘이 되겠다. 집합을 서로 중복되지 않는 부분집합들로 분할하기 위한 핵심 연산이 union 과 find 연산이기에 Union-Find 알고리즘이라고도 부른다.
1. 집합의 표현
Disjoint Set을 표현하는 데는 여러 방법이 있지만 가장 간단하고 효율적인 트리를 사용하는 방법을 알아본다.
disjoint_set = [-1] * (node_num + 1)
먼저 각 노드의 부모 노드를 관리하는 리스트(배열)를 선언한다. 자기자신이 루트노드인 경우에는 해당 트리에 속한 자신을 포함한 노드의 갯수에 음수를 취한 값을 저장한다. 아직 분할이 이뤄지지 않은 처음에는 모두가 자기 자신만이 존재하는 트리의 루트노드이기에 초기값을 -1로 한다.
def union(disjoint_set, A, B):
A = find(disjoint_set, A)
B = find(disjoint_set, B)
if A != B:
if disjoint_set[A] < disjoint_set[B]:
disjoint_set[A] += disjoint_set[B]
disjoint_set[B] = A
else:
disjoint_set[B] += disjoint_set[A]
disjoint_set[A] = B
다음은 union 연산의 구현이다. A 와 B에 대해 find 연산을 통해 A와 B가 속한 부분집합의 루트 노드를 구한다.
만약 두 루트노드가 서로 같다면 A와 B는 이미 같은 집합에 속한 원소이기에 더이상 union 연산을 진행할 필요가 없다.
서로 다르다면 한쪽의 루트노드가 다른 한쪽의 자식 노드로 추가되어야 한다. 이 때, 보다 깊이가 큰 노드에 작은 노드가 매달려야 트리의 깊이를 최소화할 수 있기에, disjoint_set 값이 보다 작은쪽에 큰 쪽을 추가한다. (루트노드의 disjoint_set 값은 트리의 크기에 음수를 취한것이기 때문에 그 값이 더 작다는 것은 트리의 크기(깊이)가 더 크다는 것을 의미한다.)
def find(disjoint_set, x):
if disjoint_set[x] < 0:
return x
disjoint_set[x] = find(disjoint_set, disjoint_set[x])
return disjoint_set[x]
마지막으로 find 연산의 구현이다. 만약 x의 disjoint_set 값이 음수라면 자신이 루트노드이기 때문에 스스로를 반환한다.
음수가 아니라면 자신의 부모노드에 대해 find 함수를 재귀호출하여 루트노드를 구하여 반환한다. 이 때, disjoint_set[x] 를 루트노드로 갱신하여 트리의 깊이를 낮추는 것으로 find 연산에 걸리는 시간을 단축할 수 있다.
Disjoint Set 은 여러개의 노드간의 연결관계가 주어질 때 그 노드들이 같은 네트워크에 속하는지 알아내는 종류의 문제에 유용하다. 아래는 그와같은 유형의 문제(https://www.acmicpc.net/problem/4195)를 Disjoint Set을 사용하여 해결한 코드이다.
import sys
input = sys.stdin.readline
# 4195 친구 네트워크
# 친구관계 (연결관계)가 주어질 때마다 친구네트워크의 크기, 즉 친구가 된 둘이 속한 친구집합의 크기를 구하는 문제
def sol4195():
answer = []
for t in range(int(input())):
# 친구의 이름에 매칭되는 숫자를 저장
friend = {}
u = []
res = []
for _ in range(int(input())):
# 서로 친구가 된 두 명의 이름
a, b = input().split()
# 처음 등장한 이름이 있다면 번호를 매겨주고 새로운 부분집합으로 추가한다
if a not in friend:
friend[a] = len(friend)
u.append(-1)
if b not in friend:
friend[b] = len(friend)
u.append(-1)
# 두 사람이 속한 친구 네트워크를 합치기 위해 union 연산 실행
a = friend[a]
b = friend[b]
# 두 사람이 속한 친구 네트워크의 크기를 구함
res.append(-union(u, a, b))
answer.append('\n'.join(map(str, res)))
return '\n'.join(answer)
# union 연산
def union(u, a, b):
a = find(u, a)
b = find(u, b)
if a != b:
if u[a] < u[b]:
u[a] += u[b]
u[b] = a
return u[a]
else:
u[b] += u[a]
u[a] = b
return u[b]
return u[a]
# find 연산
def find(u, x):
if u[x] < 0:
return x
u[x] = find(u, u[x])
return u[x]
또한 노드간의 연결이 계속해서 주어질 때, 사이클이 발생하는 연결인지의 탐지 또한 가능하다. 중복되지 않는 두 부분집합간의 union 에서는 사이클이 발생하지 않지만 이미 같은 부분집합에 속한 두 노드가 연결될 경우 둘 사이의 경로가 하나 더 생기기 떄문에 사이클이 발생한다. 아래는 사이클의 발생을 탐지하는 문제(https://www.acmicpc.net/problem/20040)를 Disjoint Set으로 해결한 코드이다.
import sys
input = sys.stdin.readline
# 20040 사이클 게임
# 주어진 점들을 선분으로 이어나가며 사이클이 생기는 차례를 구하는 문제
# 사이클이 생기려면 이미 같은 집합(네트워크)에 속해있는 점끼리 이어져야 한다.
# 이를 이용하면 유니온 파인드를 통해 간단히 해결 가능하다.
def sol20040():
n, m = map(int, input().split())
u = [-1] * (n + 1)
check = 0
for turn in range(1, m + 1):
s, e = map(int, input().split())
if not union(u, s, e):
check = turn
break
return check
def union(u, a, b):
a = find(u, a)
b = find(u, b)
if a == b:
return False
if u[a] < u[b]:
u[a] += u[b]
u[b] = a
else:
u[b] += u[a]
u[a] = b
return True
def find(u, x):
if u[x] < 0:
return x
u[x] = find(u, u[x])
return u[x]
단순히 union 연산이 발생했을 때, 인자로 받은 a, b가 같은 부분집합에 속해있는 경우 그 연산이 발생한 차례를 출력하면 해결 가능한 문제이다.
두 문제 모두 Disjoint Set에 대해 알고있다면 굉장히 쉽게 해결 가능하지만 몰랐다면 상당히 고전했을 문제들이었다. 충분히 실전에서 볼 수 있을법한 유형의 문제들이기에 Disjoint Set을 확실히 이해해둘 필요가 있겠다.
'코딩테스트 > 알고리즘' 카테고리의 다른 글
#7 투 포인터(Two Pointers) (0) | 2021.08.02 |
---|---|
#6 LCS 알고리즘 (0) | 2021.07.25 |
#4 최장 증가 부분수열(Longest Increasing Subsequence) 알고리즘 (0) | 2021.07.18 |
#3 최단거리 알고리즘 - Floyd-Warshall Algorithm (0) | 2021.07.14 |
#2 최단거리 알고리즘 - Bellman-Ford Algorithm (0) | 2021.07.13 |