행렬제곱-비트연산자

문제 분석

9월 26일 한달 전에 푼 행렬 제곱 문제다. DP를 활용해 풀었었는데, 블로그 포스팅을 하다가 비트연산자를 활용한 방법도 발견해서 함께 포스팅 했었다.

그래서 이번에는 이 문제를 비트 연산자를 활용해 접근하고, 다시 한번 내가 직접 풀어보는 시간을 가졌다.

사실 재귀함수&DP를 사용해 푼 방법과 비트연산자를 활용한 방법은 근본적으로는 비슷하다.

  • 비트연산자

이번에 사용한 비트연산자를 활용한 방법에 대해 알아보자. 먼저, 입력으로 받은 곱셈횟수 b를 2진수로 변환한다. 예를 들어서 14라는 수가 주어지면, ‘1110’으로 바뀌는 것이다. 이제, 각 자리수가 0인지 1인지 파악해주면 된다. 14의 경우에는 ‘1110’임으로 행렬^8 X 행렬^4 X 행렬^2 가 되는 것이다.

for i in range(len(bin(b))-1):
    if (1<<i)&b:
        answer=mul(answer,matrix)

    matrix=mul(matrix,matrix)

matrix는 i가 1씩 커짐에 따라 제곱시켜주고, 위와 같이 ‘1«i’ 와 b를 비교해 b의 해당 자리 숫자가 1이라면 answer에 matrix를 곱해준다.

  • DP

재귀&DP를 활용한 방법에서는 아래(코드의 일부분을 가져왔다.)와 같이 multiple함수를 실행시키면, m번 연산한 값이 저장되어 있지 않을 때, m을 2분해서 다시 재귀함수를 호출한다.

def multiple(n,m):
    # 행렬을 m번 연산한 값이 이미 저장되어있으면 바로 return
    if m in save:
        return save[m]
    else:
        #저장되어있지 않을 경우 더 작은 값에 대해 재귀 호출을 한다.
        lista=multiple(n,m//2)
        listb=multiple(n,m-m//2)


CODE - 비트연산자 사용

import sys
read = sys.stdin.readline

n, b = map(int, read().split())
matrix = list(list(map(int, read().split())) for _ in range(n))
# 행렬 곱셈 함수

def mul(a, b):
    new = list(list(0 for _ in range(n)) for _ in range(n))
    for i in range(n):
        for j in range(n):
            for k in range(n):
                new[i][j] += a[i][k]*b[k][j]
            new[i][j] %= 1000
    return new

# I 행렬을 기본값으로 세팅
answer = list(list(0 for _ in range(n)) for _ in range(n))
for i in range(n):
    answer[i][i] = 1
# 포문을 한번 돌 때마다 matrix는 제곱이 되고,
# 이진수로 바꿨을 때, 1이 있는 지점마다 matrix를 answer에 곱해준다.
for i in range(len(bin(b))-1):
    tmp = 1 << i
    if tmp & b:
        answer = mul(answer, matrix)

    matrix = mul(matrix, matrix)

for a in answer:
    print(*a)


import sys
n,m=map(int,sys.stdin.readline().split())
matrix=[list(map(int,sys.stdin.readline().split())) for _ in range(n)]

for i in range(n):
    for j in range(n):
        matrix[i][j]= matrix[i][j]%1000

#행렬을 n 번 곱한 값을 save[n]에 저장
save=dict()
save[1]=matrix

def multiple(n,m):
    # 행렬을 m번 연산한 값이 이미 저장되어있으면 바로 return
    if m in save:
        return save[m]
    else:
        #저장되어있지 않을 경우 더 작은 값에 대해 재귀 호출을 한다.
        lista=multiple(n,m//2)
        listb=multiple(n,m-m//2)
        matrix=list([0 for _ in range(n)] for _ in range(n))

        for i in range(n):
            for j in range(n):
                matrix[i][j]=0
                for x in range(n):
                    matrix[i][j]+=lista[i][x]*listb[x][j]
                matrix[i][j]%=1000

        save[m]=matrix
        return save[m]
#출력
for x in multiple(n,m):
    for i in x:
        print(i,end=' ')
    print()


틀린 부분이 있을 수 있습니다. 피드백 주시면 고치도록 하겠습니다! 감사합니다.👍

[꼭 다시 풀어보기]