말랑한 하루

[Algorithm] 세그먼트 트리 (Segment Tree) 본문

Algorithm

[Algorithm] 세그먼트 트리 (Segment Tree)

지수는말랑이 2020. 12. 10. 23:21
반응형

[ 개념 ]

이진트리로 구성된 정적트리

[ 시간복잡도 ]

O(n logn)

[ 흐름 ]

(1) 최상위 Root Node가 1부터 시작함

→ 자식노드를 찾아가기 편하게 하기위해서

(2) 자식노드들은 부모노드의 2배이거나 2배+1노드로 분배

→ 왼쪽자식 Left_node = root_node *2 , 오른쪽자식 Right_node = root_node*2+1

(3) 세그먼트트리 전체 노드의 개수는 배열의크기보다 큰 제곱근 중 가장 작은수의 두배이다

→ ArySize = 9 일때, 9<16 (3^2 < 4^2) 이므로 전체크기 = 16*2 = 32

→ 다음과같은 식으로 커버할 수 있다.

#include <cmath>
int ArySize=9;
int height = (int)ceil(log2(ArySize));
int TreeSize = 1 << (height+1)

→ 또는 전체사이즈의 4배를 해주면 어떤경우에라도 모든범위를 커버할 수 있다.

→ ArySize (=9) * 4 =  36 (전체트리사이즈)

 

[ 핵심소스코드 / C++ ]

int ary[];
int segment_tree[];
//or vector <int> segment_tree;
//ㅡㅡㅡㅡㅡㅡㅡㅡㅡSegment Tree Inintㅡㅡㅡㅡㅡㅡㅡㅡㅡ 초기화
*init(0, arraysize-1, 1); // 배열의 시작인덱스, 끝인덱스

int init(int start, int end, int node) {
	// 해당인덱스부분에 도착하면 값삽입
	if (start==end) return tree[node] = ary[start];
	// 해당 부분이 아니라면 두 부분으로 나누는 기준점잡고 나누기
	int mid = (start+end)/2;
	// 두개의 자식으로나눈뒤 그들의 합을 자신으로 하도록 설정
	return tree[node] = init(start, mid, node*2) + init(mid+1, end, node*2+1)

	}
//ㅡㅡㅡㅡㅡㅡㅡㅡㅡSegment Tree Sumㅡㅡㅡㅡㅡㅡㅡㅡㅡ 구간쿼리
*0~12까지의 합 : sum(0, arraysize-1, 1, 0, 12)

int sum(int start, int end, int node, int left, int right) {
	// 범위 밖으로 left와 right 가 나갈경우
	if (left > end || right < start) return 0;
	// 범위 안에 존재할경우
	if (left <= start && end <= right) return tree[node];
	// 그렇지않을경우 두부분으로 나누어서 합구하기
	int mid = (start+end)/2;
	return sum(start, mid, node*2, left, right)+sum(mid+1, end, node*2+1, left, right);
}
//ㅡㅡㅡㅡㅡㅡㅡㅡㅡSegment Tree Updateㅡㅡㅡㅡㅡㅡㅡㅡㅡ 구간업데이트
*3번째 인덱스를 5라는 값으로 치환할때,
pix = abs(ary[3] - 5);
update(0, arraysize-1, 1, 3, pix/*=abs(ary[3]-5)*/);

void update(int start, int end, int node, int index, int pix) {
	// 인덱스가 범위밖일경우 
	if (index < start || index > end ) return 0;
	// 인덱스가 범위 안일경우, 내려가면서 해당 인덱스까지 모든경로를 업데이트
	// 인덱스 자체값 수정이아닌, 인덱스 값에서 일정크기만큼을 수정하는것
	tree[node]+=pix;
	int mid = (start+end)/2;
	update(start, mid, node*2, index, pix);
	update(mid+1, end, node*2+1, index, pix);
}

int main() {
	int total = arraysize;
    int height = (int)ceil(log2(total));
    int TreeSize = i << (height+1);
    // if segmentTree is vector,
    segment_tree.resize(TreeSize);
}

[ 핵심소스코드 / Java ] 

자바엔 C++의 log2()가 존재하지않다. 따라서 log의 성질질을이용해 log2를 구현해야한다

log2() 는 밑이 2이고 지수가 ()인 로그이다. 이를 밑이 10인 로그로 나타낼때 log10() / log10(2)로 변형됨을 이용한다.

Math.ceil(Math.log(N) / Math.log(2)

귀찮으면 그냥 입력받은 수에 X4하면 편하다.

import java.io.*;
import java.util.*;

class SegmentTree {
    int ary[];
    long segment_tree[];

    SegmentTree(int N) {
        ary = new int[N];
        int height = (int) Math.ceil(Math.log(N) / Math.log(2));
        int TreeSize = 1 << (height + 1);
        segment_tree = new long[TreeSize];
    }
    public long make(int start, int end, int node) {
        if (start==end) return segment_tree[node]=ary[start];
        int mid = (start+end)/2;
        return segment_tree[node] = make(start, mid, node*2)+make(mid+1, end, node*2+1); 
    }
    public long query(int start, int end, int node, int left, int right) {
        if (left > end || right < start) return 0;
        if (left<=start && end <= right) return segment_tree[node];
        int mid = (start+end)/2;
        return query(start, mid, node*2, left, right)+query(mid+1, end, node*2+1, left, right);
    }
    public void update(int start, int end, int node, int index, int value) {
        if (index > end || index < start) return;
        segment_tree[node]+=value;
        if (start==end)return;
        int mid = (start+end)/2;
        update(start, mid, node*2, index, value);
        update(mid+1, end, node*2+1, index, value);
    }
    public void print() {
        for(int i=1;i<this.segment_tree.length;i++) {
            System.out.println(this.segment_tree[i]);
        }
    }
}

public class Main {
    public static void main(String[] args) throws IOException {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        SegmentTree st = new SegmentTree(N);
    }
}
반응형
Comments