int query(int left, int right) {
// 1st arg : arr 부분합의 처음 index
// 2nd arg : arr 부분합의 마지막 index
// 3rd arg : 구간합이 저장되어 있는 Tree의 index
// 4th arg : 3rd arg가 가리키는 구간합의 처음 index
// 5th arg : 3rd arg가 가리키는 구간합의 마지막 index
return queryRec(left, right, 1, 0, N - 1);
}
int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {
// 구간합을 벗어나면, 0을 return
if (right < nodeLeft || nodeRight < left)
return 0;
// nodeLeft와 nodeRight가 모두 구하고자 하는 구간 안에 있으면, 값을 더해준다
if (left <= nodeLeft && nodeRight <= right)
return tree[node];
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
return merge(queryRec(left, right, node * 2, nodeLeft, mid),
queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
}
Prefix Sum
배열에서 부분합을 미리 계산해 놓은 배열
장점 : 배열의 특정 범위 내 원소들의 합을 빠르게 계산할 수 있음
※ 7번 index부터 14번 index까지의 합 구하기
- Array : 7번부터 14번까지의 합을 for loop을 통해 계산
- Prefix Sum Array : 14번째 요소 - 7번째 요소
시간복잡도
- 배열을 한번 순회하여 Prefix Sum 배열을 계산하는 단계 : O(N)
- Prefix Sum 배열을 이용하여 부분 합을 구하는 단계 : O(1)
구현
#include <iostream>
using namespace std;
int main() {
int arr[5] = {1, 2, 3, 4, 5};
int prefix_sum[5] = {0, };
prefix_sum[0] = arr[0];
for (int i = 1; i < 5; i++) {
prefix_sum[i] = prefix_sum[i - 1] + arr[i];
}
return 0;
}
Segment Tree
배열의 구간 합을 빠르게 구하기 위한 자료구조
총 3가지 메서드 (Build, Query, Update)가 존재
- Build : Array를 이용해 Segment Tree를 구성하는 메서드
- Query : 특정 구간부터 특정 구간까지의 합을 가져오는 메서드
- Update : Array의 특정 값을 Update시켜 Segment Tree를 갱신하는 메서드
Build 과정
Segment Tree를 초기화하는 과정을 살펴보자
struct SegmentTree {
int N; // size
vector<int> tree; // segment tree
void build(const int arr[], int size) {
N = size;
tree.resize(N * 4);
// (대상 배열, Tree idx, 부분합 배열 왼쪽 idx, 부분합 배열 오른쪽 idx)
buildRec(arr, 1, 0, N - 1);
}
. . .
Array를 Segment Tree에 입력 시, 2N-1 크기의 Vector 생성
이후 Recursive하게 Leaf Node에 닿을 때까지 계속 내려감
Query 과정
Array의 index 1 ~ index 4 까지의 합을 구해보자
int query(int left, int right) {
// 1st arg : arr 부분합의 처음 index
// 2nd arg : arr 부분합의 마지막 index
// 3rd arg : 구간합이 저장되어 있는 Tree의 index
// 4th arg : 3rd arg가 가리키는 구간합의 처음 index
// 5th arg : 3rd arg가 가리키는 구간합의 마지막 index
return queryRec(left, right, 1, 0, N - 1);
}
int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {
// 구하고자 하는 구간과 겹치는 부분이 하나도 없으면, 0을 return
if (right < nodeLeft || nodeRight < left)
return 0;
// nodeLeft와 nodeRight가 모두 구하고자 하는 구간 안에 있으면, 값을 더해준다
if (left <= nodeLeft && nodeRight <= right)
return tree[node];
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
return merge(queryRec(left, right, node * 2, nodeLeft, mid),
queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
}
탐색 시작 : Node = 1일 때
해당 구간(0 ~ 5)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과
3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색
Node = 3일 때
해당 구간(3 ~ 5)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과
3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색
Node = 7일 때
해당 구간(5 ~ 5)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 없으므로, 0 반환
Node = 6일 때
해당 구간(3 ~ 4)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그리고 구하고자 하는 구간(1 ~ 4) 안 쪽에 있으므로, 2번째 tree[6] 값 반환
Node = 2일 때
해당 구간(0 ~ 2)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과
3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색
Node = 5일 때
해당 구간(2 ~ 2)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그리고 구하고자 하는 구간(1 ~ 4) 안 쪽에 있으므로, 2번째 return인 tree[5] 값 반환
Node = 4일 때
해당 구간(0 ~ 1)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그런데 구하고자 하는 구간(1 ~ 4) 안 쪽에 있진 않아서, 두 번째 return 통과
3번째 return을 진행해야 하므로, merge의 2번째 argument부터 탐색
Node = 9일 때
해당 구간(0 ~ 1)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 있으므로, 첫 번째 return 통과
그리고 구하고자 하는 구간(1 ~ 4) 안 쪽에 있으므로, 2번째 return인 tree[9] 값 반환
Node = 8일 때
해당 구간(0 ~ 0)이 구하고자 하는 구간(1 ~ 4)과 겹치는 구간이 없으므로, 0 반환
Traverse를 마쳤으니, merge 함수에 의해 모든 함수의 결과값이 합쳐져서 21을 반환
Update 과정
// Array의 index를 newValue로 update했다!
int update(int index, int newValue) {
return updateRec(index, newValue, 1, 0, N - 1);
}
int updateRec(int index, int newValue, int node, int nodeLeft, int nodeRight) {
if (index < nodeLeft || nodeRight < index)
return tree[node];
if (nodeLeft == nodeRight)
return tree[node] = newValue;
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
int leftVal = updateRec(index, newValue, node * 2, nodeLeft, mid);
int rightVal = updateRec(index, newValue, node * 2 + 1, mid + 1, nodeRight);
return tree[node] = merge(leftVal, rightVal);
}
Index가 1인 부분의 값을 4로 변경 시의 과정을 살펴보자
Update 시작 : Node = 1일 때
저장된 구간합 부분(0 ~ 5)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과
단일값(NodeLeft == NodeRight)을 가리키지 않으므로, 두 번째 return 통과
따라서 좌측 구간합을 구하는 재귀 과정을 거친다.
Node = 2일 때
저장된 구간합 부분(0 ~ 2)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과
단일값(NodeLeft == NodeRight)을 가리키지 않으므로, 두 번째 return 통과
따라서 좌측 구간합을 구하는 재귀 과정을 거친다.
Node = 4일 때
저장된 구간합 부분(0 ~ 1)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과
단일값(NodeLeft == NodeRight)을 가리키지 않으므로, 두 번째 return 통과
따라서 좌측 구간합을 구하는 재귀 과정을 거친다.
Node = 8일 때
저장된 구간합 부분(0 ~ 0)이 변경된 index(1)를 포함하지 않으므로, tree[node] 반환
Node = 9일 때
저장된 구간합 부분(1 ~ 1)이 변경된 index(1)를 포함하므로, 첫 번째 return 통과
단일값(NodeLeft == NodeRight)을 가리키므로, tree[node] 값을 newValue로 update 후 반환
Node = 4일 때
좌측 구간합과 우측 구간합을 구했으므로, 둘이 합한 값(merge)을 tree[node]에 저장 후 값을 반환
Node = 5일 때
저장된 구간합 부분(2 ~ 2)이 변경된 index(1)를 포함하지 않으므로, tree[node] 반환
Node = 2일 때
좌측 구간합과 우측 구간합을 구했으므로, 둘이 합한 값(merge)을 tree[node]에 저장 후 값을 반환
Node = 3일 때
저장된 구간합 부분(3 ~ 5)이 변경된 index(1)를 포함하지 않으므로, tree[node] 값 반환
Node = 1일 때
좌측 구간합과 우측 구간합을 구했으므로, 둘이 합한 값(merge)을 tree[node]에 저장 후 값을 반환
전체 코드 구현
#include <iostream>
#include <vector>
using namespace std;
struct SegmentTree {
static const int DEFAULT_VALUE = 0;
//static const int DEFAULT_VALUE = numeric_limits<int>::max(); // for min
//static const int DEFAULT_VALUE = numeric_limits<int>::min(); // for max
// merge operation
int merge(int left, int right) {
return left + right; // sum
//return min(left, right); // min
//return max(left, right); // max
//...
}
int N; // size
vector<int> tree; // segment tree
void build(const int arr[], int size) {
N = size;
tree.resize(N * 4);
buildRec(arr, 1, 0, N - 1);
}
// inclusive
int update(int index, int newValue) {
return updateRec(index, newValue, 1, 0, N - 1);
}
// inclusive
int update(int left, int right, int newValue) {
return updateRec(left, right, newValue, 1, 0, N - 1);
}
// inclusive
int query(int left, int right) {
return queryRec(left, right, 1, 0, N - 1);
}
private:
int buildRec(const int arr[], int node, int nodeLeft, int nodeRight) {
if (nodeLeft == nodeRight)
return tree[node] = arr[nodeLeft];
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
int leftVal = buildRec(arr, node * 2, nodeLeft, mid);
int rightVal = buildRec(arr, node * 2 + 1, mid + 1, nodeRight);
return tree[node] = merge(leftVal, rightVal);
}
int updateRec(int index, int newValue, int node, int nodeLeft, int nodeRight) {
if (index < nodeLeft || nodeRight < index)
return tree[node];
if (nodeLeft == nodeRight)
return tree[node] = newValue;
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
int leftVal = updateRec(index, newValue, node * 2, nodeLeft, mid);
int rightVal = updateRec(index, newValue, node * 2 + 1, mid + 1, nodeRight);
return tree[node] = merge(leftVal, rightVal);
}
int updateRec(int left, int right, int newValue, int node, int nodeLeft, int nodeRight) {
if (right < nodeLeft || nodeRight < left)
return tree[node];
if (nodeLeft == nodeRight)
return tree[node] = newValue;
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
int leftVal = updateRec(left, right, newValue, node * 2, nodeLeft, mid);
int rightVal = updateRec(left, right, newValue, node * 2 + 1, mid + 1, nodeRight);
return tree[node] = merge(leftVal, rightVal);
}
int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {
if (right < nodeLeft || nodeRight < left)
return DEFAULT_VALUE; // default value
if (left <= nodeLeft && nodeRight <= right)
return tree[node];
int mid = nodeLeft + (nodeRight - nodeLeft) / 2;
return merge(queryRec(left, right, node * 2, nodeLeft, mid),
queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
}
};
int main() {
int arr[] = { 1, 3, 5, 7, 9, 11 };
int size = sizeof(arr) / sizeof(int);
SegmentTree st;
st.build(arr, size);
st.update(1, 4); // arr[1] : 3 -> 4
cout << st.query(1, 4) << endl; // 4 + 5 + 7 + 9 = 25
cout << st.query(0, 2) << endl; // 1 + 4 + 5 = 10
return 0;
}
시간복잡도
1. build() : O(N)
- 배열의 원소를 한 번씩 처리
- 총 2N-1개의 Node를 처리하므로, O(N)
2. query() : O(logN)
- 자세히 파고들면 끝이 없으니 Level (최대 logN) 마다 상수 시간의 node 탐색을 진행한다고 생각할 것
3. Update() : O(logN)
- Query의 시간복잡도와 비슷한 이유
값을 변경하지 않는 경우 (immutable) |
값을 변경하는 경우 (mutable) |
Prefix Sum Sparse Table |
Segment Tree Sqrt - Decomposition |
'IT_Study > C++' 카테고리의 다른 글
[C++] 프로그램 컴파일 과정 - 선행처리기(preprocessor), 컴파일러(compiler), 링커(linker) (0) | 2023.03.01 |
---|---|
[C++] Bit operation, 구조체를 활용한 Map 사용법 정리 (0) | 2023.02.23 |
[C++] Dynamic Programming (동적 계획법) 정리 (0) | 2023.02.21 |
[C++] Greedy Algorithm (욕심쟁이 알고리즘) 정리 (0) | 2023.02.20 |
[C++] Binary Search, Parametric Search, Two pointer 알고리즘을 활용한 문제 풀이 방법 (0) | 2023.02.19 |