본문 바로가기
알고리즘 PS (백준)/🐍 Python (파이썬)

[백준 2887] 행성 터널(크루스칼 알고리즘, Kruskal, 최소신장트리) - 파이썬 Python

by 코딩하는 동현😎 2024. 5. 26.

행성 터널

문제

때는 2040년, 이민혁은 우주에 자신만의 왕국을 만들었다. 왕국은 N개의 행성으로 이루어져 있다. 민혁이는 이 행성을 효율적으로 지배하기 위해서 행성을 연결하는 터널을 만들려고 한다.

행성은 3차원 좌표위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.

민혁이는 터널을 총 N-1개 건설해서 모든 행성이 서로 연결되게 하려고 한다. 이때, 모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 프로그램을 작성하시오.


입력

첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -10^9보다 크거나 같고, 10^9보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이상 있는 경우는 없다. 


출력

첫째 줄에 모든 행성을 터널로 연결하는데 필요한 최소 비용을 출력한다.


예제 입력 1

5
11 -15 -15
14 -5 -15
-1 -1 -5
10 -4 -1
19 -4 19

예제 출력 1

4

풀이

최소 신장 트리는 모든 노드가 연결되도록 간선을 그은 트리의 가중치의 합이 가장 작은 것을 말합니다.

가중치의 값이 최소가 돼야되기 때문에 사이클은 없어야 합니다.

간선으로 모든 행성을 연결하면서, 비용은 최소로 돼야되니까 최소 신장 트리를 구하는 크루스칼 알고리즘이 알맞는 알고리즘입니다.

 

이 문제에선 행성간 간선을 주어지지 않으므로 직접 간선을 만들어야 합니다.

N개의 행성 중 한 행성에서 다른 행성의 간선을 만들어보게 되면 N=100,000이므로 nC2 = (100,000 * 99,999 / 2),

즉 O(N^2)의 공간 복잡도와 시간 복잡도가 걸리게 됩니다.

그러므로 모든 간선을 만들어보지 않고, 선택될 가능성이 있는 간선만 추려서 구성해야 됩니다.

 

비용에 대한 식이 min(|xA-xB|, |yA-yB|, |zA-zB|)이므로 x,y,z상관없이 가장 근거리에 있는 행성들간의 간선을 선택하는게 비용이 제일 적습니다.

그러므로 행성들을 각각 x,y,z좌표를 분리해서 정렬하고 좌표 기준으로 정렬된 행성들을 각각 바로 옆에 인접한 행성들을 가지고 간선을 구하는 것입니다.

 

x좌표 기준으로 (x좌표, 행성번호) 형식의 튜플로 xList안에 넣고 첫번째 원소를 기준으로 정렬을 한다음에, 간선리스트(edge)에 (|xA-xB|, A,  B) 형식으로 엣지리스트에 넣어줍니다. (A와 B는 각각 xList의 i순위의 행성번호와 i+1순위의 행성번호 입니다)

 

x 좌표 기준으로 정렬해서 N개의 행성들끼리 간선을 만들면 N-1개가 나오고 y,z좌표에도 이 과정을 반복하면 총 간선은 3(N-1)개이므로 O(3N), 즉 O(N)의 공간과 시간 복잡도가 나오게 됩니다.


크루스칼 알고리즘

크루스칼 알고리즘은 최소 신장 부분 트리를 찾는 알고리즘으로 가중치가 가장 낮은 엣지를 순서대로 고르지만, 사이클이 안생기도록 고르는 알고리즘입니다.

변의 개수를 E, 꼭짓점의 개수 V라고 하면 이 알고리즘은O(E\log V)의 시간복잡도를 가집니다.

 

1. 엣지리스트 초기화 하기

엣지들을 가중치가 낮은것 순서대로 뽑을 것이므로 리스트에다가 넣고, 나중에 정렬해줍니다.

2. 유니온 파인드 알고리즘 적용시키기

유니온 파인드는 쉽게 말해서 집합을 구현하는 알고리즘입니다. 아래 코드를 보시면 겉핥기로 이해되실 것입니다 (크루스칼 포스트이므로 상세하게 설명은 안하겠습니다.)

n, m = map(int, sys.stdin.readline().rstrip().split())

parent = [i for i in range(n+1)]

def find(x):
    if parent[x] == x:
        return x
    parent[x] = find(parent[x]) # get_parent 거슬러 올라가면서 parent[x] 값도 갱신
    return parent[x]

def union(a, b):
    a = find(a)
    b = find(b)
    if a < b: # 작은 쪽을 부모로 통일
        parent[b] = a
    else:
        parent[a] = b

 

엣지들이 있을때 가중치가 가장 낮은 엣지를 우선적으로 선택해서 차례대로 연결합니다.

만약에 연결하면 사이클이 생기면, 건너뛰고 다음 엣지부터 검토합니다.

사이클이 생기는 지 여부는 유니온 파인드 알고리즘을 통해 판별할 수 있습니다.


 

예를 들면 가중치가 9인 빨간색 엣지의 양끝 노드를 보면 B,D인데 둘다 같은 유니온(집합)에 속해있기 때문에 사이클이 생기므로 건너뛰고 다음 엣지를 살펴봅니다.


N이 노드의 갯수라고 할때, 탐색을 전부완료하면 N-1개의 엣지가 있어야되고, 모든 노드들이 한 유니온으로 연결되어있어야합니다.

 

엣지리스트를 내림차순으로 정렬하고 하나씩 뽑아서 유니온이 아닌 노드를 연결하는 엣지마다 적용 시켜주면 됩니다.

edges.sort()
i = 0
weight = 0
while cnt<n-1:
    w,a,b = edges[i][0] , edges[i][1], edges[i][2]
    if find(a) != find(b):
        cnt+=1
        weight += w
        union(a,b)
    i+=1
print(weight)

파이썬 코드 

import sys
input = sys.stdin.readline

n = int(input())
# 각 노드의 각 x,y,z좌표들을 따로 분류한 리스트
# (나중에 정렬해서 인접한 좌표끼리만 간선을 만들 예정)
xList = []
yList = []
zList = []
for i in range(n):
    x,y,z = map(int , input().split())
    xList.append((x, i)) # (x좌표, 노드번호)
    yList.append((y, i)) # (y좌표, 노드번호)
    zList.append((z, i)) # (z좌표, 노드번호)

edges = []
xList.sort()
yList.sort()
zList.sort()
for i in range(n-1):
    # 한좌표로 정렬 했을때 바로 옆에 있는 행성들끼리만 연결해서 간선 만듦
    # 0,1 1,2 ... n-1,n 짝짓기
    w = abs(xList[i][0] - xList[i+1][0]) # 가중치(거리)
    edges.append((w,xList[i][1], xList[i+1][1]))
    w = abs(yList[i][0] - yList[i+1][0]) # 가중치(거리)
    edges.append((w,yList[i][1], yList[i+1][1]))
    w = abs(zList[i][0] - zList[i+1][0]) # 가중치(거리)
    edges.append((w,zList[i][1], zList[i+1][1]))


edges.sort()
parent = [i for i in range(n)]

def find(x):
    if parent[x] == x:
        return x
    parent[x] = find(parent[x])
    return parent[x]

def union(a, b):
    a = find(a)
    b = find(b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def kruskal():
    i=0
    cnt=0
    distance=0
    while cnt < n-1:
        w,a,b = edges[i]
        i+=1
        if find(a) == find(b):
            continue
        cnt+=1
        distance+=w
        union(a,b)
    return distance

print(kruskal())

 

반응형

댓글