행렬 곱셈 순서

혹시 시간 초과로 고민하시는 분들은 pypy3로 제출해보시기 바랍니다. 시간 초과 때문에 고민고민 하다가 결국 검색해봤는데, python3로는 통과하신 분이 거의 없는듯 했습니다. (다시 확인해보니 단 한명도 없네요…)

문제

DP 유형들 중 Matrix Chain Multiplication 관련 문제다. 행렬을 곱할 때 곱하는 순서에 따라 연산 횟수가 달라진다.

예제 입력을 통해 이해해보자.

5 3
3 2
2 6

첫번째, 두번째 행렬부터 계산하면 532 + 526 = 90 두번째, 세번째 행렬부터 계산하면 326 + 536 = 126 90과 126으로 연산 횟수가 다름을 알 수 있다.

DP 점화식을 만들어 보자. 연산 횟수를 저장할 dp는 n*n 이차원 리스트이고, 행렬이 들어있는 리스트는 ‘m’이다. 0번째 행렬 부터 4번째 행렬 까지 있다고 가정하고, dp[a][b]에는 a번째 행렬부터 b번째 행렬까지의 최소 연산 횟수를 입력한다. (a와 b가 같으면 0이 입력된다.)

0번째 행렬 부터 4번째 행렬 까지 곱할때 연산의 최솟값은

  • dp[0][0] + dp[1][4] + m[0][0] _ m[0][1] _ m[4][1]
  • dp[0][1] + dp[2][4] + m[0][0] _ m[1][1] _ m[4][1]
  • dp[0][2] + dp[3][4] + m[0][0] _ m[2][1] _ m[4][1]
  • dp[0][3] + dp[4][4] + m[0][0] _ m[3][1] _ m[4][1]

다음 값들 중 최솟값이다.

a번째 행렬 부터 b번째 행렬 까지로 일반화 시키면 아래와 같다.

  • dp[a][0] + dp[1][b] + m[a][0] _ m[1][1] _ m[b][1]
  • dp[a][1] + dp[2][b] + m[a][0] _ m[2][1] _ m[b][1]
  • dp[a][2] ….
  • ~
  • dp[a][b-3] ….
  • dp[a][b-2] + dp[b-1][4] + m[a][0] _ m[b-2][1] _ m[b][1]
  • dp[a][b-1] + dp[b][b] + m[a][0] _ m[b-1][1] _ m[b][1]


CODE

import sys
sys.stdin = open('input.txt','rt')

n=int(sys.stdin.readline())
m=list(list(map(int,sys.stdin.readline().split())) for _ in range(n))

dp=list(list(0 for _ in range(n)) for _ in range(n))

for i in range(1,n):
    for j in range(n-i):
        dp[j][j+i]=2147000000
        for k in range(i):
            if dp[j][j+i]<dp[j][j+k]+dp[j+k+1][i+j]: continue
            tmp=dp[j][j+k]+dp[j+k+1][i+j]+m[j][0]*m[j+k][1]*m[i+j][1]
            if tmp < dp[j][j+i]:
                dp[j][j+i]=tmp
# for x in dp:
#     print(x) #이 코드를 통해 dp를 확인해보자
print(dp[0][-1])

일주일 정도 DP를 공부하며 슬슬 감을 잡았다고 생각했는데 오산이었다. DP로 푸는 문제라는 것을 알고 접근해도, 점화식을 생각하고 코드로 구현하는 것이 만만치 않다고 느꼈다.


DP 대표 유형

[다시 풀어보기]