본문 바로가기

CS Study/Algorithm(Coding Test)

[BOJ] 2042. 구간 합 구하기 (Python)

728x90
반응형
 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

이번에는 문제의 목적에 맞게 세그먼트리를 이용해서 구간 합을 구해보자!

이번 문제는 앞선 문제에서 구간 합을 업데이트 하는 기능까지 추가되었다.

구간 합 구하기 문제는 세그먼트 트리의 가장 대표적인 문제 유형이라고 볼 수 있다.

 

나는 안경잡이 개발자님의 블로그를 참고해서 세그먼트 트리 개념을 익혔다. 

원리와 코드가 매우 잘 설명이 되어 있다.

세그먼트 트리에 대해 처음 접하는 분들은 먼저 아래 블로그 글로 공부를 해보시는 것을 추천한다. 

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

 

 

나는 여기에다가 코드에 대한 부가설명을 조금 더 추가하겠다.

굉장히 자세한 설명임에도 처음에 코드의 원리가 약간 와닿지 않았기 때문에..!

 

먼저, 세그먼트 트리를 만드는 코드이다. 

 

def init(node, start, end) :
    if start == end : # 리프노드
        tree[node] = array[start]
        return tree[node]
    mid = (start + end) // 2
    tree[node] = init(node * 2, start, mid) + init(node * 2 + 1, mid + 1, end)
    return tree[node]

 

여기서의 node는 배열에 저장되어 있는 숫자를 의미하는 것이 아니라 우리가 만들 "세그먼트 트리"에 대한 것임을 헷갈리면 안된다.

node, start, end에 들어가는 값의 의미는

세그먼트 트리의 node 번째 노드에 저장되는 숫자의 범위는 (배열의) start부터 (배열의) end까지 입니다. 

라는 뜻이다!

따라서 처음 init 함수를 호출 할 때에는

 

init(1, 0, n-1) # n은 배열의 길이

 

이렇게 호출하면 된다. 

첫 번째 노드(루트 노드)에는 배열의 0 ~ n-1번째 숫자들의 합이 들어가기 때문이다.

그럼 init 함수 내에서 재귀적으로 호출하며 부분 구간 합을 구하게 된다. 

코드 자체는 이해하기에 어렵지 않다.

 

 

다음은 이렇게 만든 세그먼트 트리의 구간 합을 구하는 코드이다.

 

# start, end : 노드에게 주어진 범위
# left, right : 내가 찾고자 하는 범위
def sum(node, start, end, left, right) :
    if start > right or end < left : # 내가 찾고자 하는 범위와 전혀 상관 없음
        return 0
    elif left <= start and end <= right : # 내가 찾고자 하는 범위가 노드에 완전히 포함이면 그대로 리턴 (자식 노드를 굳이 볼 필요가 없음)
        return tree[node]
    else :
        mid = (start + end) // 2
        return sum(2*node, start, mid, left, right) + sum(2*node+1, mid+1, end, left, right)

 

여기서도 주의할 부분은 start, end, left, right의 의미이다.

node, start, end, left, right을 파라미터로 갖는 sum의 의미는

세그먼트 트리의 node 번째에 있는 숫자는 start ~ end 까지의 부분합입니다. 그리고 내가 구하고자 하는 값은 left ~ right 범위의 부분합니다.

라는 뜻이다.

나는 처음에 세그먼트 트리 i 번째 노드에 있는 숫자가 어디부터 어디까지의 부분합인지 어떻게 알지?? 라는 부분이 직관적으로 와닿지 않았다. (나만 그럴수도..)

그리고 역시나 함수에서는 알 수 없다.

따라서 알려주는 것이다. node번째 노드에는 어디서부터 어디까지의 부분합이 들어있어! 라고!

그래서 내가 찾고자하는 범위(left ~ right) 값이 node번째 노드에 전혀 들어있지 않으면 더 이상 볼 필요가 없으니 return 0을 해준다.

만약 완전히 포함이 되어있다면 node 번째 노드의 값을 더해준다. (자식 노드를 봐봤자 node 번째 노드를 나눈 값들이니 굳이 볼 필요 X)

그리고 일부가 포함되어 있다면 이제 리프노드까지 봐서 필요한 부분합을 데려오면 된다.

 

호출 할 때에도 sum(1, 0, n-1, left, right)로 호출하면된다.

일단 루트노드에서부터 시작해서 한 칸씩 내려가면서 확인해야하기 때문에!

 

 

마지막으로 세그먼트 트리의 값을 업데이트 하는 부분이다.

 

def update(node, start, end, idx, diff) :
    if start > idx or end < idx : # 범위에 포함되지 않으면 교체할 필요 없음
        return 
    tree[node] += diff # 범위에 포함되니 수정
    if start != end : # 리프노드가 아니면 자식 노드도 수정
        mid = (start + end) // 2
        update(2 * node, start, mid, idx, diff)
        update(2* node + 1, mid + 1, end, idx, diff)    
    return

 

업데이트 코드는 부분합을 구하는 부분을 완벽히 이해했다면 어렵지 않을 것이다.

변경된 idx가 포함된 부분합에 해당하는 세그먼트 트리의 노드들만 업데이트를 해주면 된다.

업데이트 역시 기존 숫자로부터 얼마나 변경되었는지만 수정해주면 된다.

따라서 루트노드부터 차례로 변경 된 idx가 포함되어 있으면 diff(얼마나 변경되었는지)만큼 업데이트,

만약 포함되지 않은 범위면 그냥 패스해주면 된다.

 

 

2042번 문제의 전체 코드는 아래와 같다. 

당분간은 세그먼트 트리 문제만 풀면서 코드와 좀 더 친해져봐야겠다.

 

import sys
input = sys.stdin.readline

def init(node, start, end) :
    if start == end :
        tree[node] = array[start]
        return tree[node]
    mid = (start + end) // 2
    tree[node] = init(node * 2, start, mid) + init(node * 2 + 1, mid + 1, end)
    return tree[node]

# start, end : 노드에게 주어진 범위
# left, right : 내가 찾고자 하는 범위
def sum(node, start, end, left, right) :
    if start > right or end < left : # 내가 찾고자 하는 범위와 전혀 상관 없음
        return 0
    elif left <= start and end <= right : # 내가 찾고자 하는 범위가 노드에 완전히 포함이면 그대로 리턴 (자식 노드를 굳이 볼 필요가 없음)
        return tree[node]
    else :
        mid = (start + end) // 2
        return sum(2*node, start, mid, left, right) + sum(2*node+1, mid+1, end, left, right)

def update(node, start, end, idx, diff) :
    if start > idx or end < idx : # 범위에 포함되지 않으면 교체할 필요 없음
        return 
    tree[node] += diff # 범위에 포함되니 수정
    if start != end : # 리프노드가 아니면 자식 노드도 수정
        mid = (start + end) // 2
        update(2 * node, start, mid, idx, diff)
        update(2* node + 1, mid + 1, end, idx, diff)    
    return

n, m, k = map(int, input().split()) # 배열의 숫자 개수, 수의 변경 횟수, 구간의 합 횟수
array = list()
tree = [0] * (4*n)
for _ in range(n) :
    array.append(int(input()))

init(1, 0, n-1) # 트리 초기화

for _ in range(m + k) :
    a, b, c = map(int, input().split())
    if a == 1 : # 구간 숫자 바꾸기
        diff = c - array[b-1]
        array[b-1] = c
        update(1, 0, n-1, b-1, diff)
    elif a == 2 : # 구간 합 구하기
        print(sum(1, 0, n-1, b-1, c-1))

 

 

 

728x90
반응형