구간 합 구하기
Segment Tree를 이용해 구간 합을 구하는 문제다.
Segment Tree의 리프 노드(자식 노드가 없는 노드)는 주어진 배열의 숫자를 저장하고, 나머지 노드는 자식 노드들의 합을 저장한다.
Segment Tree 구현
‘a = [1, 2, 3, 4, 5, 6]’로 Segment Tree를 구성하면 아래와 같다.
- 가장 위에 위치한 1번 노드는 0번째 ~ 5번째 원소의 합을 저장한다.
- 자식 노드인 2번 노드와 3번 노드에는 0번째 ~ (0+5)/2번째 원소의 합과 (0+5)/2+1번째 ~ 5번째 원소의 합을 저장한다.
- 이렇게 합을 저장하며 내려가다가 더 이상 나뉠 수 없을 때 해당 노드에 a[i]값을 저장하게 된다.
- 노드 번호는 위에서 부터, 왼쪽에서 오른쪽으로 1씩 증가하며 매겨진다.
- 노드 번호가 X일때, 자식 노드의 번호는 ‘X2’와 ‘X2+1’이다.
리스트에 해당 노드 번호 정보를 이용해 값을 저장한다.
‘tree’리스트를 만들어 여기에 각 노드의 값들을 저장했다. 이때, ‘원소의 개수*4’ 만큼의 길이로 tree 리스트를 만들었다.
f(x)는 x개의 원소를 세그먼트 트리로 구현할 때 필요한 노드의 개수를 의미한다. 위의 식을 통해 f(x)가 항상 x*4 이하임을 알 수 있다.
n,m,k=map(int,read().split())
a=list(int(read()) for _ in range(n))
tree = [0]*(4*n)
#세그먼트 트리 구성
def tree_init(start, end, node):
if start == end:
tree[node] = a[start]
return tree[node]
mid = (start + end)//2
#특정 노드의 하위 노드 번호는 node*2와 node*2+1이다.
tree[node] = tree_init(start, mid, node*2) + tree_init(mid + 1, end, node*2+1)
return tree[node]
tree_init(0, n-1, 1)
Segment Tree 구간합 구하기
‘a = [1, 2, 3, 4, 5, 6]’에서 0번째 원소부터 4번째 원소까지의 구간합을 구해보자.
- 1번째 노드부터 탐색을 시작한다.
- 0 - 5는 0 - 4와 겹치긴 하지만 범위를 벗어남으로 자식노드로 이동해 탐색을 이어간다.
- 0 - 2는 0 - 4의 범위에 포함된다. 따라서 해당 노드의 값을 더해준다.
- 3 - 5는 0 - 4와 겹치긴 하지만 범위를 벗어남으로 자식노드로 이동해 탐색을 이어간다.
- 3 - 4는 0 - 4의 범위에 포함된다. 따라서 해당 노드의 값을 더해준다.
- 5 - 5는 0 - 4의 범위를 벗어남으로 0을 return해 아무것도 더해주지 않는다.
#구간합 구하는 함수
def tree_sum(start, end, node, left, right):
#범위를 벗어날 때
if left > end or right < start:
return 0
#범위가 left와 right사이에 있을 때는 해당 노드의 값을 return해 더해준다.
if left <= start and end <= right:
return tree[node]
mid = (start + end)//2
#범위가 겹치긴 하지만 포함되지는 않는 경우 하위 노드를 탐색한다.
return tree_sum(start, mid, node*2, left, right) + tree_sum(mid + 1, end, node*2+1, left, right)
Segment Tree 원소의 값 변경하기
‘a = [1, 2, 3, 4, 5, 6]’에서 4번째 원소의 값을 5에서 7로 변경해보자.
- 기존 4번째 원소의 값과 7의 차이는 +2이다. 따라서 4를 포함하는 모든 노드의 값에 2를 더해주면 된다.
- 값을 업데이트 하는 과정 역시 1번 노드부터 탐색해보자.
- 1번 노드는 당연히 4번째 원소를 포함한다. 2를 더해주고 자식 노드로 이동한다.
- 2번 노드는 4번째 원소를 포함하지 않는다. return으로 탐색을 종료한다.
- 3번 노드는 4번째 원소를 포함한다. 2를 더해주고 자식노드로 이동하자.
- 위와 같은 방법으로 리프 노드에 도착할 때 까지 탐색을 진행한다.
#값을 업데이트해주는 함수
def tree_update(start,end,node,index,diff):
#범위가 index를 포함하지 않으면 별도로 수정할 필요가 없다.
if index < start or index > end:
return
#위의 if문에서 return되지 않은 경우 index를 포함하고 있다는 의미이다.
tree[node] += diff
#start와 end가 같아야 하위 노드까지 도착했다는 뜻이다.
#start와 end가 같을 때 까지 하위 노드로 이동해 해당노드에 diff를 더해준다.
if start != end:
mid = (start + end)//2
tree_update(start, mid ,node*2, index, diff)
tree_update(mid+1, end ,node*2+1, index, diff)
CODE
import sys
read=sys.stdin.readline
#세그먼트 트리 구성
def tree_init(start, end, node):
if start == end:
tree[node] = a[start]
return tree[node]
mid = (start + end)//2
#특정 노드의 하위 노드 번호는 node*2와 node*2+1이다.
tree[node] = tree_init(start, mid, node*2) + tree_init(mid + 1, end, node*2+1)
return tree[node]
#구간합 구하는 함수
def tree_sum(start, end, node, left, right):
#범위를 벗어날 때
if left > end or right < start:
return 0
#범위가 left와 right사이에 있을 때는 해당 노드의 값을 return해 더해준다.
if left <= start and end <= right:
return tree[node]
mid = (start + end)//2
#범위가 겹치긴 하지만 포함되지는 않는 경우 하위 노드를 탐색한다.
return tree_sum(start, mid, node*2, left, right) + tree_sum(mid + 1, end, node*2+1, left, right)
#값을 업데이트해주는 함수
def tree_update(start,end,node,index,diff):
#범위가 index를 포함하지 않으면 별도로 수정할 필요가 없다.
if index < start or index > end:
return
#위의 if문에서 return되지 않은 경우 index를 포함하고 있다는 의미이다.
tree[node] += diff
#start와 end가 같아야 하위 노드까지 도착했다는 뜻이다.
#start와 end가 같을 때 까지 하위 노드로 이동해 해당노드에 diff를 더해준다.
if start != end:
mid = (start + end)//2
tree_update(start, mid ,node*2, index, diff)
tree_update(mid+1, end ,node*2+1, index, diff)
n,m,k=map(int,read().split())
a=list(int(read()) for _ in range(n))
tree = [0]*(4*n)
tree_init(0, n-1, 1)
for _ in range(m+k):
x,y,z=map(int,read().split())
if x == 1:
y-=1
diff = z - a[y]
a[y] = z
tree_update(0, n-1, 1, y, diff)
else:
print(tree_sum(0, n-1, 1, y-1, z-1))
틀린 부분이 있을 수 있습니다. 피드백 주시면 고치도록 하겠습니다!
감사합니다.👍
[꼭 다시 풀어보기]