최단경로
다익스트라 알고리즘을 이용해 최단거리를 구하는 문제이다. 어제 처음으로 해설을 보고 풀어본 후 오늘 바로 다시 풀어봤는데도 푸는데 꽤 힘들었다.
컨셉 자체는 바로 이해할 수 있었지만 이를 최적화 시키는데 아이디어들이 꽤 필요한 듯 했다.
문제 분석
가중치가 있는 간선들을 입력받은 후, 시작점에서 부터 각 점까지의 최단경로를 구하는 문제이다. 만약 갈 수 없다면 INF를 출력하면 된다.
가중치가 있기 때문에 당연히 단순 BFS로는 못 풀고, 다익스트라 알고리즘을 사용해 풀어야 한다.
매 순간 출발지점으로 부터의 거리가 최소인 점으로 이동해야 한다.
다른 글 들에 더 자세하게 나와있지만, 입력 예제에 대해 한번 설명해 보겠다.
- 먼저 점들의 거리를 저장할 dp배열을 하나 만들자. dp[0,inf,inf,inf,inf] - 1번점에서 출발함으로 첫번째 값은 0, 나머지는 무한으로 세팅해 놓았다.
- 1에서 출발하면 2만큼 이동해 2번 점으로 3만큼 이동해 3번 점으로 이동 가능하다. dp[0,min(inf,0+2),min(inf,0+3),inf,inf] => dp[0,2,3,inf,inf]
- 1번 점은 이미 방문했음으로 나머지 점들 중 가장 최단거리인 두번째 점으로 이동한다.
- 2번 점에서는 4만큼 더 이동해 3번 점으로, 5만큼 더 이동해 4번점으로 갈 수 있다. dp[0,2,min(3,2+4),min(inf,2+5),inf] => dp[0,2,3,7,inf]
- 1번 점, 2번 점을 제외한 점들 중 최단거리인 세번째 점으로 이동해보자.
- 3번 점에서는 6만큼 더 이동해 4번 점으로 갈 수 있다. dp[0,2,3,min(7,3+6),inf] => dp[0,2,3,7,inf]
- 이제 방문하지 않은 점들 중 최솟값을 가지는 4번 점으로 이동하자. 그런데, 4번 점에서는 갈 수 있는 점이 없다. 추가로, 5번 점은 아직 inf로 도달하지 못한 채 남아있음으로, 계산은 여기서 종료된다.
dp를 차례로 출력하면 정답을 얻을 수 있다.
TIP
여기까지 유튜브로 해설을 듣고 구현을 했는데, 시간초과와 메모리 초과가 나를 괴롭혔다.
먼저, 간선들의 가중치를 저장할 때, (간선의 개수)V * V 행렬을 만들어 각각 저장을 했었다. 하지만, 이렇게 하면 가중치가 0이 아닌(간선이 있는)곳을 탐색하는데 추가로 시간이 소모됨으로 비효율적이다.
weight=list([] for _ in range(V))
for _ in range(e):
u,v,w=map(int,read().split())
weight[u-1].append([v-1,w])
위와 같은 방식으로 코드를 작성하면 시간 복잡도, 공간 복잡도를 함께 줄일 수 있다.
특정 점을 방문할 때 마다, dp를 탐색해 점들의 최단거리를 업데이트 해주고 아직 방문하지 않은 점들 중 최단거리에 있는 점을 다음으로 방문한다.
이때, 방문하지 않은 점들 중 최소값을 계속 뽑아줘야 하는데, 리스트의 정렬을 사용한다면 이 부분 또한 시간 복잡도를 굉장히 증가시킨다.
우리가 궁금한 것은 최소값임으로 heapq 자료형을 사용해 최소값을 계속 뽑을 수 있다. 이 방식을 이용해야 시간초과가 발생하지 않을 것이다.
CODE
import sys
read=sys.stdin.readline
inf=2147000000
V,e=map(int,read().split())
k=int(read())-1
# 간선의 가중치를 V*V배열을 만들어 저장했었다. -> 메모리 초과가 발생했다.
weight=list([] for _ in range(V))
for _ in range(e):
u,v,w=map(int,read().split())
weight[u-1].append([v-1,w])
dp=list(inf for _ in range(V))
dp[k]=0
import heapq
# heapq를 사용해야 시간복잡도를 줄일 수 있다.
# 방문하지 않은 점들 중 최단거리 인 점들을 뽑아야 하는데
# 이때, 최소값만 항상 뽑아주면 됨으로 최소힙이 적격하다.
save=[]
heapq.heappush(save,[0,k])
while save:
w,p=heapq.heappop(save)
for p_new,w_new in weight[p]:
if dp[p_new]>dp[p]+w_new:
dp[p_new]=dp[p]+w_new
heapq.heappush(save,[dp[p_new],p_new])
for i in dp:
if i==inf: print('INF')
else: print(i)
[다시 풀어보기]