IT_Study/C++

[C++] Prefix Sum, Segment Tree 구현 및 상세 설명

__Vivacé__ 2023. 3. 28. 21:36

int query(int left, int right) {
        // 1st arg : arr 부분합의 처음 index
        // 2nd arg : arr 부분합의 마지막 index
        // 3rd arg : 구간합이 저장되어 있는 Tree의 index
        // 4th arg : 3rd arg가 가리키는 구간합의 처음 index
        // 5th arg : 3rd arg가 가리키는 구간합의 마지막 index
        return queryRec(left, right, 1, 0, N - 1);
}

int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {

    // 구간합을 벗어나면, 0을 return
    if (right < nodeLeft || nodeRight < left)
        return 0;

    // nodeLeft와 nodeRight가 모두 구하고자 하는 구간 안에 있으면, 값을 더해준다
    if (left <= nodeLeft && nodeRight <= right)
        return tree[node];

    int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
    
    return merge(queryRec(left, right, node * 2, nodeLeft, mid),
        queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
}

Prefix Sum

배열에서 부분합을 미리 계산해 놓은 배열

Prefix Sum 배열은 해당 인덱스의 합을 의미

장점 : 배열의 특정 범위 내 원소들의 합을 빠르게 계산할 수 있음

※ 7번 index부터 14번 index까지의 합 구하기
    - Array : 7번부터 14번까지의 합을 for loop을 통해 계산
    - Prefix Sum Array : 14번째 요소 - 7번째 요소

시간복잡도

    - 배열을 한번 순회하여 Prefix Sum 배열을 계산하는 단계 : O(N)

    - Prefix Sum 배열을 이용하여 부분 합을 구하는 단계 : O(1)

 

구현

#include <iostream>

using namespace std;

int main() {
    int arr[5] = {1, 2, 3, 4, 5};
    int prefix_sum[5] = {0, };

    prefix_sum[0] = arr[0];

    for (int i = 1; i < 5; i++) {
        prefix_sum[i] = prefix_sum[i - 1] + arr[i];
    }

    return 0;
}

 


Segment Tree

배열의 구간 합을 빠르게 구하기 위한 자료구조

총 3가지 메서드 (Build, Query, Update)가 존재

    - Build : Array를 이용해 Segment Tree를 구성하는 메서드

    - Query : 특정 구간부터 특정 구간까지의 합을 가져오는 메서드

    - Update : Array의 특정 값을 Update시켜 Segment Tree를 갱신하는 메서드

 


Build 과정

Segment Tree를 초기화하는 과정을 살펴보자

struct SegmentTree {
    int N;              // size
    vector<int> tree;   // segment tree

    void build(const int arr[], int size) {
        N = size;
        tree.resize(N * 4);
        
        // (대상 배열, Tree idx, 부분합 배열 왼쪽 idx, 부분합 배열 오른쪽 idx)
        buildRec(arr, 1, 0, N - 1);
    }
    
. . .

 

Array를 Segment Tree에 입력 시, 2N-1 크기의 Vector 생성

이후 Recursive하게 Leaf Node에 닿을 때까지 계속 내려감


 

Query 과정

Array의 index 1 ~ index 4 까지의 합을 구해보자

int query(int left, int right) {
        // 1st arg : arr 부분합의 처음 index
        // 2nd arg : arr 부분합의 마지막 index
        // 3rd arg : 구간합이 저장되어 있는 Tree의 index
        // 4th arg : 3rd arg가 가리키는 구간합의 처음 index
        // 5th arg : 3rd arg가 가리키는 구간합의 마지막 index
        return queryRec(left, right, 1, 0, N - 1);
}

int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {

    // 구하고자 하는 구간과 겹치는 부분이 하나도 없으면, 0을 return
    if (right < nodeLeft || nodeRight < left)
        return 0;

    // nodeLeft와 nodeRight가 모두 구하고자 하는 구간 안에 있으면, 값을 더해준다
    if (left <= nodeLeft && nodeRight <= right)
        return tree[node];

    int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
    
    return merge(queryRec(left, right, node * 2, nodeLeft, mid),
        queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
}

탐색 시작 : Node = 1일 때

해당 구간(0 ~ 5)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과

3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색

 Node = 3일 때

해당 구간(3 ~ 5)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과

3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색

 

Node = 7일 때

해당 구간(5 ~ 5)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 없으므로, 0 반환

Node = 6일 때

해당 구간(3 ~ 4)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그리고 구하고자 하는 구간(1 ~ 4) 안 쪽에 있으므로, 2번째 tree[6] 값 반환

Node = 2일 때

해당 구간(0 ~ 2)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과

3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색

Node = 5일 때

해당 구간(2 ~ 2)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그리고 구하고자 하는 구간(1 ~ 4) 안 쪽에 있으므로, 2번째 return인 tree[5] 값 반환

 

Node = 4일 때

해당 구간(0 ~ 1)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과

3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색

Node = 9일 때

해당 구간(0 ~ 1)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과

그리고 구하고자 하는 구간(1 ~ 4) 안 쪽에 있으므로, 2번째 return인 tree[9] 값 반환

Node = 8일 때

해당 구간(0 ~ 0)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 없으므로, 0 반환

Traverse를 마쳤으니, merge 함수에 의해 모든 함수의 결과값이 합쳐져서 21을 반환

 


Update 과정

// Array의 index를 newValue로 update했다!

int update(int index, int newValue) {
    return updateRec(index, newValue, 1, 0, N - 1);
}

int updateRec(int index, int newValue, int node, int nodeLeft, int nodeRight) {
    if (index < nodeLeft || nodeRight < index)
        return tree[node];

    if (nodeLeft == nodeRight)
        return tree[node] = newValue;

    int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
    int leftVal = updateRec(index, newValue, node * 2, nodeLeft, mid);
    int rightVal = updateRec(index, newValue, node * 2 + 1, mid + 1, nodeRight);
    return tree[node] = merge(leftVal, rightVal);
}

 

Index가 1인 부분의 값을 4로 변경 시의 과정을 살펴보자

Update 시작 : Node = 1일 때

저장된 구간합 부분(0 ~ 5)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과

단일값(NodeLeft == NodeRight)을 가리키지 않으므로, 두 번째 return 통과

따라서 좌측 구간합을 구하는 재귀 과정을 거친다.

Node = 2일 때

저장된 구간합 부분(0 ~ 2)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과

단일값(NodeLeft == NodeRight)을 가리키지 않으므로, 두 번째 return 통과

따라서 좌측 구간합을 구하는 재귀 과정을 거친다.

Node = 4일 때

저장된 구간합 부분(0 ~ 1)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과

단일값(NodeLeft == NodeRight)을 가리키지 않으므로, 두 번째 return 통과

따라서 좌측 구간합을 구하는 재귀 과정을 거친다.

Node = 8일 때

저장된 구간합 부분(0 ~ 0)이 변경된 index(1)를 포함하지 않으므로, tree[node] 반환

Node = 9일 때

저장된 구간합 부분(1 ~ 1)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과

단일값(NodeLeft == NodeRight)을 가리키므로, tree[node] 값을 newValue로 update 후 반환

Node = 4일 때

좌측 구간합과 우측 구간합을 구했으므로, 둘이 합한 값(merge)을 tree[node]에 저장 후 값을 반환

Node = 5일 때

저장된 구간합 부분(2 ~ 2)이 변경된 index(1)를 포함하지 않으므로, tree[node] 반환

 

Node = 2일 때

좌측 구간합과 우측 구간합을 구했으므로, 둘이 합한 값(merge)을 tree[node]에 저장 후 값을 반환

Node = 3일 때

저장된 구간합 부분(3 ~ 5)이 변경된 index(1)를 포함하지 않으므로, tree[node] 값 반환

Node = 1일 때

좌측 구간합과 우측 구간합을 구했으므로, 둘이 합한 값(merge)을 tree[node]에 저장 후 값을 반환

 

Update가 끝난 모습

 


전체 코드 구현

#include <iostream>
#include <vector>

using namespace std;


struct SegmentTree {
    static const int DEFAULT_VALUE = 0;
    //static const int DEFAULT_VALUE = numeric_limits<int>::max(); // for min
    //static const int DEFAULT_VALUE = numeric_limits<int>::min(); // for max

    // merge operation
    int merge(int left, int right) {
        return left + right;        // sum
        //return min(left, right);  // min
        //return max(left, right);  // max
        //...
    }

    int N;              // size
    vector<int> tree;   // segment tree

    void build(const int arr[], int size) {
        N = size;
        tree.resize(N * 4);

        buildRec(arr, 1, 0, N - 1);
    }

    // inclusive
    int update(int index, int newValue) {
        return updateRec(index, newValue, 1, 0, N - 1);
    }

    // inclusive
    int update(int left, int right, int newValue) {
        return updateRec(left, right, newValue, 1, 0, N - 1);
    }

    // inclusive
    int query(int left, int right) {
        return queryRec(left, right, 1, 0, N - 1);
    }

    private:
        int buildRec(const int arr[], int node, int nodeLeft, int nodeRight) {
            if (nodeLeft == nodeRight)
                return tree[node] = arr[nodeLeft];

            int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
            int leftVal = buildRec(arr, node * 2, nodeLeft, mid);
            int rightVal = buildRec(arr, node * 2 + 1, mid + 1, nodeRight);
            return tree[node] = merge(leftVal, rightVal);
        }

        int updateRec(int index, int newValue, int node, int nodeLeft, int nodeRight) {
            if (index < nodeLeft || nodeRight < index)
                return tree[node];

            if (nodeLeft == nodeRight)
                return tree[node] = newValue;

            int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
            int leftVal = updateRec(index, newValue, node * 2, nodeLeft, mid);
            int rightVal = updateRec(index, newValue, node * 2 + 1, mid + 1, nodeRight);
            return tree[node] = merge(leftVal, rightVal);
        }

        int updateRec(int left, int right, int newValue, int node, int nodeLeft, int nodeRight) {
            if (right < nodeLeft || nodeRight < left)
                return tree[node];

            if (nodeLeft == nodeRight)
                return tree[node] = newValue;

            int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
            int leftVal = updateRec(left, right, newValue, node * 2, nodeLeft, mid);
            int rightVal = updateRec(left, right, newValue, node * 2 + 1, mid + 1, nodeRight);
            return tree[node] = merge(leftVal, rightVal);
        }

        int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {
            if (right < nodeLeft || nodeRight < left)
                return DEFAULT_VALUE;   // default value

            if (left <= nodeLeft && nodeRight <= right)
                return tree[node];

            int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
            return merge(queryRec(left, right, node * 2, nodeLeft, mid),
                queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
        }
};


int main() {
    int arr[] = { 1, 3, 5, 7, 9, 11 };
    int size = sizeof(arr) / sizeof(int);

    SegmentTree st;
    st.build(arr, size);

    st.update(1, 4); // arr[1] : 3 -> 4

    cout << st.query(1, 4) << endl; // 4 + 5 + 7 + 9 = 25
    cout << st.query(0, 2) << endl; // 1 + 4 + 5 = 10


    return 0;
}

 

 


시간복잡도

1. build() : O(N)

    - 배열의 원소를 한 번씩 처리

    - 총 2N-1개의 Node를 처리하므로, O(N)

 

2. query() : O(logN)

    - 자세히 파고들면 끝이 없으니 Level (최대 logN) 마다 상수 시간의 node 탐색을 진행한다고 생각할 것

 

3. Update() : O(logN)

    - Query의 시간복잡도와 비슷한 이유

 


값을 변경하지 않는 경우
(immutable)
 값을 변경하는 경우
(mutable)
Prefix Sum

Sparse Table
Segment Tree

Sqrt - Decomposition