CS/자료구조

[자료구조] 바이너리 인덱스 트리 (Binary Indexed Tree)

meizzi 2024. 1. 30. 21:03
728x90
반응형

1. 바이너리 인덱스 트리 (Binary Indexed Tree)

  • 2진법 인덱스 구조를 활용하여 구간 합 문제를 효과적으로 해결해줄 수 있는 자료구조
  • 펜윅 트리(fenwick tree) 라고도 함
  • 정수에 따른 2진수 표기
    정수 2진수 표기
    7 00000000 00000000 00000000 00000111
    -7 11111111 11111111 11111111 11111001
  • 0이 아닌 마지막 비트를 찾는 방법
    • 특정한 숫자 K의 0이 아닌 마지막 비트를 찾기 위해서 K & -K 계산
  • K & -K 계산 결과 예시
    정수 K 2진수 표기 K & -K
    0 00000000  00000000  00000000 00000000 0
    1 00000000  00000000  00000000 00000001 1
    2 00000000  00000000  00000000 00000010 2
    3 00000000  00000000  00000000 00000011 1
    4 00000000  00000000  00000000 00000100 4
    5 00000000  00000000  00000000 00000101 1
    6 00000000  00000000  00000000 00000110 2
    7 00000000  00000000  00000000 00000111 1
    8 00000000  00000000  00000000 00001000 8

2. 트리 구조 만들기

  • 0이 아닌 마지막 비트 = 내가 저장하고 있는 값들의 개수

3. 업데이트 (Update)

  • 특정 값을 변경할 때
    • 0이 아닌 마지막 비트만큼 더하면서 구간들의 값을 변경 (예시 = 3rd)
    • 최악의 경우에도 시간 복잡도 O(logN) 보장
      3, 4, 8, 16

4. 누적 합(Prefix Sum)

  • 1부터 N까지의 합(누적 합) 구하기
    • 0이 아닌 마지막 비트만큼 빼면서 구간들의 값의 합 계산 (예시 = 11th)
    • 최악의 경우에도 시간 복잡도 O(logN) 보장

5. 바이너리 인덱스 트리 구현

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

import sys
input = sys.stdin.readline

# n: 데이터의 개수, m: 변경 횟수, 구간 합 계산 횟수: k
n, m, k = map(int, input().split())

# 전체 데이터의 개수는 최대 1,000,000개
arr = [0] * (n+1)
tree = [0] * (n+1)

# i번째 수까지의 누적 합을 계산하는 함수
def prefix_sum(i):
    result = 0
    while i > 0:
        result += tree[i]
        i -= (i & -i) # 0이 아닌 마지막 비트만큼 빼가면서 이동
    return result

# i번째 수를 dif 만큼 더하는 함수
def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

# start부터 end까지의 구간 합을 계산하는 함수
def intercal_sum(start, end):
    return prefix_sum(end) - prefic_sum(start - 1)

for i in range(1, n + 1):
    x = int(input())
    arr[i] = x
    update(i, x)

for i in range(m + k):
    a, b, c = map(int, input().split())
    
    # 업데이트(update) 연산인 경우
    if a == 1:
        update(b, c - arr[b]) # 바뀐 크기(dif)만큼 적용
        arr[b] = c
    # 구간 합(interval sum) 연산인 경우
    else:
        print(interval_sum(b, c))
728x90
반응형