Java

[자료구조] 세그먼트 트리 (Segment Tree)

에릭 Kim 2024. 2. 23. 11:09
반응형

 

세그먼트 트리란 ?

  • 배열 또는 리스트와 같은 데이터 구조를 이용하여 구간에 대한 질의를 효율적으로 처리하는 자료구조
  • 주어진 구간에 대한 쿼리 연산을 빠르게 수행할 수 있도록 도와줌. 주로 구간 합, 최솟값 최댓값 등.
  • 트리 구조를 사용하여 데이터를 분할하고 각 분할된 영역에 대한 요약정보를 저장. 이를 통해 트정 구간에 대한 연산을 빠르게 수행할 수 있음
  • 트리를 구축하는 초기 비용이 크지만, 이 후 구간 질의에 대해 빠르게 답을 제공할 수 있음으로 효율적

세그먼트 트리 구조

출처: https://yongj.in/data%20structure/Segment-Tree/


시간복잡도

  • 트리를 구축: O(N)
  • 합, 곱 계산(쿼리 연산): O(logN)
  • 값의 갱신(업데이트 연산): O(logN)

=> N은 배열의 크기이며, 구간에 대한 연산을 처리하는 경우 logN만큼의 시간이 들기 때문에 세그먼트 트리의 시잔복잡도는 O(logN)으로 볼 수 있습니다. 


그럼 세그먼트 트리를 왜 쓰나 ? 

누적합

  • 누적합의 시간 복잡도는 배열의 길이에 비례함. 배열을 N이라고 했을 때 누적합을 구하는 시간 복잡도는 O(N)인데, 이는 배열의 각 요소를 한번씩 순회하여 누적합을 계산하기 때문입니다. 
  • 따라서 배열의 크기게 선형적으로 비례하는 시간이 소요됩니다. 

동적 프로그래밍(DP)

  • DP의 시간복잡도는 문제의 특성과 구현 방식에 따라 달라짐. 하지만 일반적으로는 하위 문제들의 수와, 각 하위 문제를 해결하는 데 걸리는 시간에 따라 결정됨.
  • 메모이제이션을 통해 이미 해결된 문제들을 저장하고 재활용하는 경우, 각 하위 문제를 한번씩만 계산하게 되고, 하위 문제들의 수를 N이라고 할 때, 시간 복잡도는 O(N)이 나오게 됩니다. 

=> 세그먼트 트리가 훨씬 빠름 


구현

1. 트리 크기 설정

// 트리 크기 설정

// 생성자 함수로 설정할 때
 public SegmentTree(int arrSize) { 
            int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2)); // 트리 높이
            this.treeSize = (int) Math.pow(2,h+1); // 높이를 통해 배열 크기 구하기
            tree = new long[treeSize];
        }
        
// 약식으로 설정
tree = new long[N*4];

/*
N = 12 (데이터의 개수)
=> 세그먼트 트리의 전체 크기를 구하기 위해서는 2^k로 N보다 바로 큰 값을 만들 수 있는 k를 찾아야 함.
=> k는 4고, 2^k는 16. 16*2 => 32. 
=> 넉넉하게 N에 4를 곱해준 값으로 트리의 크기를 설정. 
*/

 

2. 트리 초기화

// 트리 초기화

/*
세그먼트 트리 초기화 단계
1. 트리의 루트(root) 노드부터 시작합니다.
2. 배열의 전체 구간을 표현하는 루트 노드의 구간을 계산하고 값을 할당합니다.
3. 재귀적으로 왼쪽 서브트리와 오른쪽 서브트리를 초기화합니다.
4. 각 노드의 구간을 계산하고 해당 구간에 속하는 배열 요소들의 값을 합산하여 현재 노드에 저장합니다.
5.리프(leaf) 노드까지 이 과정을 반복합니다.
*/

// tree: 원소 배열, node: 현재 노드, start: 현재 구간의 시작점, edn: 현재 구간의 끝점
public  long init(long[] tree, int node, int start, int end) {
	// 배열의 시작과 끝이 같으면 리프 노드! 
    if (start == end) return tree[node] = arr[start]; 

	// 아니면 자식 노드 합을 저장. 재귀 형식임 
    int mid = (start+end) / 2;
    return tree[node] = init(tree,node*2,start,mid) + init(tree,node*2+1,mid+1,end);
}

 

3. 값 갱신 

// 값 갱신 

public  void update(long[] tree, int node, int start, int end, int idx, long diff) {
    if (idx < start || idx > end) return; // 범위밖에 있는 경우

	// 값이 변경되는 차이만큼 원소 갱신
    tree[node] += diff; 

    int mid = (start+end) / 2;
    if (start != end) { // 재귀로 리프 노드까지 값 변경 진핸
        update(tree,node*2,start,mid,idx,diff);
        update(tree,node*2+1,mid+1,end,idx,diff);
    }
}

 

4. 구간 합 연산

// 구간 합 연산

/*
구간 합 연산 단계
1. 구간이 포함하는 노드들을 찾습니다. 이는 세그먼트 트리의 특성에 따라 수행됩니다.
2. 구간이 포함하는 노드들의 값을 합하여 구간의 부분 합을 계산합니다.
3. 필요에 따라 노드들을 재귀적으로 탐색하여 구간을 더 세분화하고 각 부분 구간의 합을 계산합니다.
4. 최종적으로 모든 부분 구간의 합을 더하여 전체 구간의 합을 얻습니다.
*/

public static long sum(int node, int start, int end, int left, int right) {
	if ( left > end || right < start) return 0; // 범위를 벗어나는 경우
		
	if (left <= start && end <= right) return tree[node]; //범위를 완전히 포함하는 경우
		
    // 재귀로 지속적으로 구간 값 연산 진행
	int mid = (start+end) / 2;
	return sum(node*2,start,mid,left,right) + sum(node*2+1,mid+1,end,left,right);
}

 


기본 문제

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

코드

import java.util.*;
import java.io.*;

class Main {
    static long[] arr;
    static int a,b;
    static long c;

    public static class SegmentTree {
        long[] tree;
        int treeSize;

        public SegmentTree(int arrSize) {
            int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
            this.treeSize = (int) Math.pow(2,h+1);
            tree = new long[treeSize];
        }

        public  long init(long[] tree, int node, int start, int end) {
            if (start == end) return tree[node] = arr[start];

            int mid = (start+end) / 2;
            return tree[node] = init(tree,node*2,start,mid) 
            		+ init(tree,node*2+1,mid+1,end);
        }

        public  void update(long[] tree, int node, int start, int end, int idx, long diff) {
            if (idx < start || idx > end) return;

            tree[node] += diff;

            int mid = (start+end) / 2;
            if (start != end) {
                update(tree,node*2,start,mid,idx,diff);
                update(tree,node*2+1,mid+1,end,idx,diff);
            }
        }

        public  long sum(long[] tree, int node, int start, int end, int left, int right) {
            if (left > end || right < start) return 0;

            if (left <= start && end <= right) {
                return tree[node];
            }

            int mid = (start+end) / 2;
            return sum(tree,node*2,start,mid,left,right) 
            		+ sum(tree,node*2+1,mid+1,end,left,right);
        }

    }

    public static void main(String[] args) throws Exception {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    StringTokenizer st = new StringTokenizer(br.readLine()," ");
    StringBuilder sb = new StringBuilder();

    int N = Integer.parseInt(st.nextToken());
    int M = Integer.parseInt(st.nextToken());
    int k = Integer.parseInt(st.nextToken());
    arr = new long[N+1];

    for (int i=1; i<=N;i++) {
        arr[i] = Long.parseLong(br.readLine());
    }

    SegmentTree sgTree = new SegmentTree(N);

    sgTree.init(sgTree.tree,1,1,N);

    for (int i=0; i< M+k;i++) {
        st = new StringTokenizer(br.readLine());
        a = Integer.parseInt(st.nextToken());
        b = Integer.parseInt(st.nextToken());
        c = Long.parseLong(st.nextToken());

        if (a == 1) { 
             sgTree.update(sgTree.tree,1,1,N,b,c-arr[b]);
             arr[b] = c; // update 연산시, 배열 arr도 함께 바꿔줘야 트리의 값을 변경할 때 반영됨
        }
        else {
            sb.append(sgTree.sum(sgTree.tree,1,1,N,b,(int) c)).append("\n");
        	}
    	}
    System.out.println(sb);
	}
}

 

 

Tips

  • 트리의 크기는 데이터의 개수 * 4로 설정하기 !!
  • 입력으로 초기값이 들어오지 않으면, 초기화하는 과정 필요없음 !! (init 메서드 생략) 
반응형