Segment & Fenwick Trees
Advanced Structures for Range Queries
When you need to answer queries over ranges of an array repeatedly while the array itself keeps changing, Segment Trees and Fenwick Trees (Binary Indexed Trees, BIT) are the go-to data structures. Both support updates and queries in O(log n). A plain prefix-sum array answers a range-sum query in O(1), but every point update forces O(n) work to recompute the prefixes after it, so it falls behind as soon as updates are frequent.
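To make the trade-off concrete, here is a minimal prefix-sum sketch. The class `PrefixSums` is a hypothetical helper for illustration only: a query is a single subtraction, but changing one element forces every later prefix to be recomputed.

```python
class PrefixSums:
    """Static prefix sums: O(1) range query, O(n) point update (illustrative helper)."""

    def __init__(self, nums):
        self.nums = list(nums)
        # prefix[k] holds nums[0] + ... + nums[k-1]
        self.prefix = [0] * (len(self.nums) + 1)
        for k, x in enumerate(self.nums):
            self.prefix[k + 1] = self.prefix[k] + x

    def range_sum(self, l, r):
        # Sum of nums[l..r] inclusive, O(1)
        return self.prefix[r + 1] - self.prefix[l]

    def update(self, i, val):
        # Every prefix after position i changes, so this is O(n) in the worst case
        self.nums[i] = val
        for k in range(i, len(self.nums)):
            self.prefix[k + 1] = self.prefix[k] + self.nums[k]


ps = PrefixSums([2, 4, 6, 8])
print(ps.range_sum(1, 3))  # 18
ps.update(2, 0)
print(ps.range_sum(1, 3))  # 12
```

The structures below keep both operations at O(log n) instead of trading one off against the other.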
1. Fenwick Tree (Binary Indexed Tree)
A Fenwick Tree, or BIT, is an elegant data structure for prefix-sum queries combined with point updates; a range sum over [l, r] is then just prefix(r) minus prefix(l - 1). It is shorter to code and uses less memory than a Segment Tree (n + 1 array slots versus roughly 4n), which makes it a great choice for problems that fit its structure.
The Core Idea
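A BIT exploits the binary representation of indices. Each entry tree[i] of the underlying array (indices are 1-based) stores the sum of a contiguous block of elements ending at position i. The length of that block is the value of the least significant set bit (LSB) of i, computed as i & -i; for example, tree[12] covers elements 9 through 12 because 12 = 0b1100 has LSB 4. A prefix query repeatedly subtracts the LSB to hop to the block just before the current one, and an update repeatedly adds the LSB to reach every block that contains the changed position. Either walk visits O(log n) entries.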
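To see which block each entry covers, the following sketch simply evaluates the i & -i rule for the first eight indices. The helpers `lsb` and `covered_range` are illustrative names, not part of the class below.

```python
def lsb(i):
    # In two's complement, i & -i isolates the lowest set bit of i
    return i & -i


def covered_range(i):
    # tree[i] stores the sum of 1-based elements i - lsb(i) + 1 .. i
    return (i - lsb(i) + 1, i)


for i in range(1, 9):
    lo, hi = covered_range(i)
    print(f"tree[{i}] (binary {i:04b}) covers elements {lo}..{hi}")
```

A prefix query for i = 7 therefore reads tree[7], tree[6], and tree[4], which cover 7..7, 5..6, and 1..4: three disjoint blocks that together span 1..7.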
Time: O(log n) for point update & prefix query | Space: O(n)
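```python
class FenwickTree:
    def __init__(self, size):
        # tree[0] is unused; positions 1..size hold the block sums
        self.tree = [0] * (size + 1)

    def update(self, i, delta):
        # Add delta to element i (1-based index)
        while i < len(self.tree):
            self.tree[i] += delta
            i += i & -i  # Add LSB: move to the next block that also covers i

    def query(self, i):
        # Get prefix sum of elements 1..i (1-based)
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i  # Subtract LSB: jump to the block just before this one
        return s
```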
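The class above works with 1-based positions and exposes only prefix sums. Below is a minimal sketch of how range sums and 0-based indices are commonly layered on top; the class name `RangeSumBIT` and its methods `add`, `prefix_sum`, and `range_sum` are hypothetical names chosen for this example.

```python
class RangeSumBIT:
    """Fenwick tree over a 0-based array, exposing range sums (illustrative sketch)."""

    def __init__(self, size):
        self.tree = [0] * (size + 1)  # tree[0] is unused

    def add(self, idx, delta):
        # idx is 0-based; shift to the 1-based layout the BIT relies on
        i = idx + 1
        while i < len(self.tree):
            self.tree[i] += delta
            i += i & -i

    def prefix_sum(self, idx):
        # Sum of elements 0..idx inclusive; idx = -1 yields 0
        i = idx + 1
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s

    def range_sum(self, l, r):
        # Sum of elements l..r inclusive, from two prefix queries
        return self.prefix_sum(r) - self.prefix_sum(l - 1)


nums = [5, 1, 4, 2, 3]
ft = RangeSumBIT(len(nums))
for i, x in enumerate(nums):
    ft.add(i, x)
print(ft.range_sum(1, 3))  # 7
ft.add(2, 10)              # nums[2] is now 14
print(ft.range_sum(1, 3))  # 17
```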
2. Segment Tree
A Segment Tree is a more versatile binary tree that stores aggregate information about intervals (segments) of an array. Each leaf holds one element of the input array, and each internal node holds the combination of its two children's values under some associative operation, such as sum, min, max, or gcd.
Structure and Operations
- Build: The tree is built recursively by splitting the array into two halves until single elements are reached; each internal node stores the result of the operation over its range. Time: O(n).
- Query: To query a range [L, R], we start at the root. If a node's range lies completely within [L, R], we use its stored value; if it is disjoint from [L, R], it contributes the operation's identity (0 for sum); if it partially overlaps, we recurse into both children. At most four nodes per level are visited, so the time is O(log n).
- Point Update: To update an element, we update its leaf and then recompute each ancestor on the path back to the root. Time: O(log n).
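The build and point-update steps are easiest to see in the non-recursive "bottom-up" layout, swapped in here purely as an illustration and not the layout used by the class below: leaves sit at tree[n:2n], node i has children 2i and 2i + 1, and the root is tree[1]. This sketch (the class name `IterativeSegmentTree` is hypothetical) assumes the operation is sum; its query loop combines pieces out of left-to-right order, so it relies on the operation being commutative as well as associative.

```python
class IterativeSegmentTree:
    """Bottom-up sum segment tree with leaves at tree[n:2n] (illustrative sketch)."""

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (2 * self.n)
        # Build: copy the leaves, then fill each parent from its two children, O(n)
        self.tree[self.n:] = nums
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, idx, val):
        # Point update: set the leaf, then recompute every ancestor up to the root
        i = idx + self.n
        self.tree[i] = val
        while i > 1:
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def query(self, l, r):
        # Sum over [l, r] inclusive; both boundaries climb one level per step
        res = 0
        lo, hi = l + self.n, r + self.n + 1  # half-open [lo, hi) over leaf positions
        while lo < hi:
            if lo & 1:  # lo is a right child: take it and step past it
                res += self.tree[lo]
                lo += 1
            if hi & 1:  # hi - 1 is a left child: take it
                hi -= 1
                res += self.tree[hi]
            lo //= 2
            hi //= 2
        return res


st = IterativeSegmentTree([1, 3, 5, 7, 9, 11])
print(st.query(1, 3))  # 15
st.update(1, 10)
print(st.query(1, 3))  # 22
```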
Time: O(n) build; O(log n) for query & point update | Space: 4n array slots, i.e. O(n)
```python
class SegmentTree:
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)  # 4n slots always suffice
        if self.n:  # guard: with an empty array, build would never reach a leaf
            self.build(nums, 0, 0, self.n - 1)

    def build(self, nums, node, start, end):
        # `node` covers nums[start..end]; its children are 2*node+1 and 2*node+2
        if start == end:
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self.build(nums, 2 * node + 1, start, mid)
        self.build(nums, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def update(self, idx, val, node=0, start=0, end=None):
        # Set nums[idx] = val, then recompute the sums on the path back to the root
        if end is None:
            end = self.n - 1
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if start <= idx <= mid:
            self.update(idx, val, 2 * node + 1, start, mid)
        else:
            self.update(idx, val, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def query(self, l, r, node=0, start=0, end=None):
        # Sum of nums[l..r] inclusive
        if end is None:
            end = self.n - 1
        if r < start or end < l:
            return 0  # disjoint: contributes the identity for sum
        if l <= start and end <= r:
            return self.tree[node]  # fully covered: use the stored value
        mid = (start + end) // 2
        p1 = self.query(l, r, 2 * node + 1, start, mid)
        p2 = self.query(l, r, 2 * node + 2, mid + 1, end)
        return p1 + p2
```
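The class above hard-codes addition. Because every step only combines two child values, the same code works for min, max, gcd, or any other associative operation once `+` is replaced by a combine function and the disjoint case returns that operation's identity (0 for sum and gcd, `math.inf` for min). A sketch under those assumptions; `GenericSegmentTree`, `combine`, and `identity` are illustrative names.

```python
import math
import operator


class GenericSegmentTree:
    """Recursive segment tree for an associative `combine` with neutral `identity`."""

    def __init__(self, nums, combine, identity):
        self.n = len(nums)
        self.combine = combine
        self.identity = identity
        self.tree = [identity] * (4 * max(self.n, 1))
        if self.n:
            self._build(nums, 0, 0, self.n - 1)

    def _build(self, nums, node, start, end):
        if start == end:
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node + 1, start, mid)
        self._build(nums, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.combine(self.tree[2 * node + 1], self.tree[2 * node + 2])

    def update(self, idx, val, node=0, start=0, end=None):
        if end is None:
            end = self.n - 1
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self.update(idx, val, 2 * node + 1, start, mid)
        else:
            self.update(idx, val, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.combine(self.tree[2 * node + 1], self.tree[2 * node + 2])

    def query(self, l, r, node=0, start=0, end=None):
        if end is None:
            end = self.n - 1
        if r < start or end < l:
            return self.identity  # disjoint range contributes nothing
        if l <= start and end <= r:
            return self.tree[node]
        mid = (start + end) // 2
        return self.combine(
            self.query(l, r, 2 * node + 1, start, mid),
            self.query(l, r, 2 * node + 2, mid + 1, end),
        )


nums = [12, 18, 7, 30, 24]
sums = GenericSegmentTree(nums, operator.add, 0)
mins = GenericSegmentTree(nums, min, math.inf)
gcds = GenericSegmentTree(nums, math.gcd, 0)
print(sums.query(0, 4))                    # 91
print(mins.query(0, 1), mins.query(1, 4))  # 12 7
print(gcds.query(0, 1), gcds.query(3, 4))  # 6 6
mins.update(2, 50)
print(mins.query(1, 4))                    # 18
```

With `operator.add` and identity 0 this behaves like the sum-only class above.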
3. Segment Tree with Lazy Propagation
The real power of Segment Trees is unlocked with Lazy Propagation, which makes range updates (e.g., "add 5 to all elements from index L to R") efficient. A naive approach that touches every element in [L, R] takes at least O(n) per update; with lazy propagation, a range update takes O(log n), just like a query.
How it Works
When an update covers a node's entire range, we do not push it down to all of its descendants immediately. Instead, we fold it into the node's stored value and record the pending amount in a separate "lazy" array; the children receive it only when a later query or update actually needs to descend below that node. In the class below, _push applies a node's pending value to its own sum and forwards it to its children's lazy entries.
```python
class LazySegmentTree:
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        # lazy[node] = amount to add to every element of node's range,
        # not yet applied to tree[node]
        self.lazy = [0] * (4 * self.n)
        if self.n:  # guard: an empty array has no leaves to build
            self.build(nums, 0, 0, self.n - 1)

    def build(self, nums, node, start, end):
        if start == end:
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self.build(nums, 2 * node + 1, start, mid)
        self.build(nums, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def _push(self, node, start, end):
        # Apply this node's pending add to its sum and defer it to its children
        if self.lazy[node] != 0:
            self.tree[node] += (end - start + 1) * self.lazy[node]
            if start != end:
                self.lazy[2 * node + 1] += self.lazy[node]
                self.lazy[2 * node + 2] += self.lazy[node]
            self.lazy[node] = 0

    def updateRange(self, l, r, val, node=0, start=0, end=None):
        # Add val to every element in [l, r]
        if end is None:
            end = self.n - 1
        self._push(node, start, end)
        if start > end or start > r or end < l:
            return
        if l <= start and end <= r:
            self.lazy[node] += val
            self._push(node, start, end)
            return
        mid = (start + end) // 2
        self.updateRange(l, r, val, 2 * node + 1, start, mid)
        self.updateRange(l, r, val, 2 * node + 2, mid + 1, end)
        # Each child was pushed inside its recursive call, so its sum is current
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def queryRange(self, l, r, node=0, start=0, end=None):
        # Sum of elements in [l, r]
        if end is None:
            end = self.n - 1
        if start > end or start > r or end < l:
            return 0
        self._push(node, start, end)
        if l <= start and end <= r:
            return self.tree[node]
        mid = (start + end) // 2
        p1 = self.queryRange(l, r, 2 * node + 1, start, mid)
        p2 = self.queryRange(l, r, 2 * node + 2, mid + 1, end)
        return p1 + p2
```
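In the class above, lazy[node] is an add that has not yet been applied to tree[node] itself, which is why _push runs before a node's value is read. An equally common convention applies the add to the node's sum at once and keeps in lazy[node] only what the children still owe. A minimal sketch of that alternative, assuming range-add and range-sum; the class name `EagerLazySegmentTree` and the methods `range_add` and `range_sum` are hypothetical.

```python
class EagerLazySegmentTree:
    """Range add / range sum where tree[node] is always current and lazy[node]
    holds an add that has not yet been passed down to the node's children."""

    def __init__(self, nums):
        self.n = len(nums)
        size = 4 * max(self.n, 1)
        self.tree = [0] * size
        self.lazy = [0] * size
        if self.n:
            self._build(nums, 0, 0, self.n - 1)

    def _build(self, nums, node, start, end):
        if start == end:
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node + 1, start, mid)
        self._build(nums, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def _apply(self, node, start, end, val):
        # Fold val into this node's sum now; remember it for the children later
        self.tree[node] += (end - start + 1) * val
        if start != end:
            self.lazy[node] += val

    def _push_down(self, node, start, end):
        # Hand the pending add to both children before visiting them
        if self.lazy[node]:
            mid = (start + end) // 2
            self._apply(2 * node + 1, start, mid, self.lazy[node])
            self._apply(2 * node + 2, mid + 1, end, self.lazy[node])
            self.lazy[node] = 0

    def range_add(self, l, r, val, node=0, start=0, end=None):
        if end is None:
            end = self.n - 1
        if r < start or end < l:
            return
        if l <= start and end <= r:
            self._apply(node, start, end, val)
            return
        self._push_down(node, start, end)
        mid = (start + end) // 2
        self.range_add(l, r, val, 2 * node + 1, start, mid)
        self.range_add(l, r, val, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def range_sum(self, l, r, node=0, start=0, end=None):
        if end is None:
            end = self.n - 1
        if r < start or end < l:
            return 0
        if l <= start and end <= r:
            return self.tree[node]
        self._push_down(node, start, end)
        mid = (start + end) // 2
        return (self.range_sum(l, r, 2 * node + 1, start, mid)
                + self.range_sum(l, r, 2 * node + 2, mid + 1, end))


t = EagerLazySegmentTree([1, 2, 3, 4, 5])
t.range_add(1, 3, 10)
print(t.range_sum(0, 4))  # 45
print(t.range_sum(2, 2))  # 13
```

Both conventions give the same O(log n) bounds; this one lets a fully covered node return its value without any push.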
4. Classic Problems
Range Sum Query - Mutable
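This is the canonical problem for both Fenwick Trees and Segment Trees: you must handle point updates and range-sum queries efficiently. The solution below uses a Fenwick Tree and keeps a copy of the current values so that each update can be turned into a delta.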
```python
from typing import List


class NumArray:
    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.bit = [0] * (self.n + 1)      # 1-based Fenwick array
        self.original_nums = [0] * self.n  # current value of each element
        for i, num in enumerate(nums):
            self.update(i, num)

    def _update_bit(self, i, delta):
        i += 1  # 1-based index
        while i <= self.n:
            self.bit[i] += delta
            i += i & -i

    def update(self, index: int, val: int) -> None:
        # Convert "set to val" into "add delta" for the BIT
        delta = val - self.original_nums[index]
        self.original_nums[index] = val
        self._update_bit(index, delta)

    def _query_bit(self, i):
        # Prefix sum of elements 0..i; i = -1 gives 0
        i += 1  # 1-based index
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & -i
        return s

    def sumRange(self, left: int, right: int) -> int:
        return self._query_bit(right) - self._query_bit(left - 1)
```
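NumArray.__init__ above builds the tree with n calls to update, which costs O(n log n). A Fenwick array can also be built in O(n): process positions in increasing order and add each finished entry into its parent i + (i & -i). A sketch of that construction as standalone functions; `build_bit` and `prefix_sum` are hypothetical helpers, not part of the NumArray interface.

```python
def build_bit(nums):
    """Build a 1-based Fenwick array from nums in O(n) instead of n updates."""
    n = len(nums)
    bit = [0] * (n + 1)
    for i in range(1, n + 1):
        bit[i] += nums[i - 1]  # bit[i] already holds its children's totals
        parent = i + (i & -i)  # next entry whose block also covers position i
        if parent <= n:
            bit[parent] += bit[i]
    return bit


def prefix_sum(bit, idx):
    """Sum of nums[0..idx] (0-based, inclusive), the same walk as _query_bit."""
    i = idx + 1
    s = 0
    while i > 0:
        s += bit[i]
        i -= i & -i
    return s


bit = build_bit([3, 2, -1, 6, 5, 4])
print(bit)                 # [0, 3, 5, -1, 10, 5, 9]
print(prefix_sum(bit, 5))  # 19
```

To use it inside NumArray, one could assign self.bit = build_bit(nums) and self.original_nums = list(nums) in place of the update loop.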
Range Sum Query 2D - Mutable
The 2D version of the problem can be solved with a 2D Fenwick Tree or a 2D Segment Tree; in both cases you essentially have a tree of trees. For an n × m matrix, updates and queries take O(log n · log m). The solution below uses a 2D Fenwick Tree.
```python
from typing import List


class NumMatrix:
    def __init__(self, matrix: List[List[int]]):
        if not matrix or not matrix[0]:
            self.rows, self.cols = 0, 0
            return
        self.rows, self.cols = len(matrix), len(matrix[0])
        self.tree = [[0] * (self.cols + 1) for _ in range(self.rows + 1)]  # 1-based 2D BIT
        self.matrix = [[0] * self.cols for _ in range(self.rows)]          # current values
        for r in range(self.rows):
            for c in range(self.cols):
                self.update(r, c, matrix[r][c])

    def update(self, row: int, col: int, val: int) -> None:
        delta = val - self.matrix[row][col]
        self.matrix[row][col] = val
        r, c = row + 1, col + 1
        # Outer loop walks the row blocks, inner loop the column blocks
        while r <= self.rows:
            c_temp = c
            while c_temp <= self.cols:
                self.tree[r][c_temp] += delta
                c_temp += c_temp & -c_temp
            r += r & -r

    def _query(self, row: int, col: int) -> int:
        # Sum of the rectangle (0, 0)..(row, col); 0 if row or col is -1
        s = 0
        r, c = row + 1, col + 1
        while r > 0:
            c_temp = c
            while c_temp > 0:
                s += self.tree[r][c_temp]
                c_temp -= c_temp & -c_temp
            r -= r & -r
        return s

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        # Inclusion-exclusion over four prefix rectangles
        return (self._query(row2, col2)
                - self._query(row1 - 1, col2)
                - self._query(row2, col1 - 1)
                + self._query(row1 - 1, col1 - 1))
```
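The sumRegion formula is two-dimensional inclusion-exclusion. Writing P(r, c) for the value returned by _query(r, c), the sum of the rectangle from (0, 0) to (r, c) inclusive, and taking P to be 0 whenever r or c is -1 (which is what the loop guards produce after the +1 shift):

```latex
\begin{aligned}
P(r, c) &= \sum_{i=0}^{r} \sum_{j=0}^{c} a_{ij}, \qquad P(-1, c) = P(r, -1) = 0,\\
\sum_{i=r_1}^{r_2} \sum_{j=c_1}^{c_2} a_{ij}
  &= P(r_2, c_2) - P(r_1 - 1, c_2) - P(r_2, c_1 - 1) + P(r_1 - 1, c_1 - 1).
\end{aligned}
```

The two subtracted rectangles both contain the block above and to the left of (r_1, c_1), so that block is removed twice and must be added back once, which is the final term.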