Disjoint Set (Union-Find)
Managing Disjoint Sets Efficiently
The Disjoint Set Union (DSU), also known as Union-Find, is a specialized data structure for tracking a collection of disjoint (non-overlapping) sets. It's famous for its near-constant amortized time per operation when fully optimized, making it ideal for problems involving connectivity and grouping.
1. The Core Idea
Imagine you have a group of people and you want to keep track of their friendships. Initially, everyone is in their own group of one. When two people become friends, you merge their groups. The DSU answers two questions very quickly:
- Find: Are these two people in the same friend group?
- Union: Make these two people (and their entire respective friend groups) become one large friend group.
We represent these sets as trees. Each set has a single "representative" or "root" node. To check if two elements are in the same set, we just check if they have the same root.
2. Naive Implementation
A simple DSU can be implemented with a parent array, where parent[i] stores the parent of element i. If parent[i] == i, then i is the root of its set.
- Find(i): Traverse up from i using the parent array until you reach the root.
- Union(i, j): Find the root of i (say, root_i) and the root of j (say, root_j). If they are different, set the parent of one root to be the other (e.g., parent[root_i] = root_j).
However, this can lead to tall, skinny trees that look like linked lists, making the find operation take O(n) time in the worst case.
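As a rough illustration of the naive scheme and its worst case, here is a minimal sketch. The class name NaiveDSU and the depth helper are hypothetical names introduced only for this example; the union deliberately always attaches root_i under root_j, as in the example above.

```python
class NaiveDSU:
    """Parent-array disjoint set with no optimizations (illustrative only)."""

    def __init__(self, n):
        # Every element starts as the root of its own one-element set.
        self.parent = list(range(n))

    def find(self, i):
        # Walk parent pointers until reaching a self-parented root.
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i != root_j:
            # Arbitrary choice: always hang root_i under root_j.
            self.parent[root_i] = root_j

    def depth(self, i):
        # Hypothetical helper: number of hops from i up to its root.
        hops = 0
        while self.parent[i] != i:
            i = self.parent[i]
            hops += 1
        return hops


naive = NaiveDSU(6)
for k in range(5):
    naive.union(k, k + 1)  # the root of {0..k} is attached under k + 1
chain_depth = naive.depth(0)  # path is 0 -> 1 -> 2 -> 3 -> 4 -> 5
```

With this union order the tree degenerates into a single chain, so a find on element 0 has to walk through every other element.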
3. The Optimizations (Crucial!)
Two key optimizations transform DSU into a nearly constant-time powerhouse.
Path Compression
During a find(i) operation, after we've found the ultimate root of i, we re-traverse the path from i to the root and make every node on that path point directly to the root. This dramatically flattens the tree for future lookups.
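The two-pass description above can be sketched iteratively. This is one possible formulation, assuming a plain parent list passed in as an argument; the standalone function signature is a choice made for this example.

```python
def find(parent, i):
    """Return the root of i, compressing the path along the way."""
    # First pass: walk up to locate the root.
    root = i
    while parent[root] != root:
        root = parent[root]
    # Second pass: re-traverse the path and point each node at the root.
    while i != root:
        next_node = parent[i]
        parent[i] = root
        i = next_node
    return root


parent = [1, 2, 3, 3]     # a chain 0 -> 1 -> 2 -> 3, with 3 as the root
root = find(parent, 0)
# After the call, every node on the path points straight at the root.
```

An iterative version like this avoids deep recursion on a long chain, at the cost of a few more lines than a recursive one.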
Union by Size/Rank
When performing a union, instead of arbitrarily connecting the trees, we use a heuristic to keep them as flat as possible.
- Union by Size: Track the number of nodes in each set. Always attach the root of the smaller tree to the root of the larger tree.
- Union by Rank: Track the "rank" (an upper bound on the tree's height). Always attach the shorter tree to the root of the taller one.
When both optimizations are used, the total time for m operations on n elements becomes O(m * α(n)), where α(n) is the extremely slow-growing inverse Ackermann function, which is less than 5 for any practical value of n.
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.num_components = n

    def find(self, i):
        if self.parent[i] == i:
            return i
        # Path compression
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Union by size
            if self.size[root_i] < self.size[root_j]:
                root_i, root_j = root_j, root_i
            self.parent[root_j] = root_i
            self.size[root_i] += self.size[root_j]
            self.num_components -= 1
            return True
        return False
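The class above uses union by size. For comparison, a union-by-rank variant might look like the sketch below. The class name RankDSU is hypothetical, and the rank update rule (increment only when two roots of equal rank are merged) is the usual textbook convention rather than something spelled out in this article.

```python
class RankDSU:
    """Disjoint set with path compression and union by rank (sketch)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n  # upper bound on the height of each root's tree

    def find(self, i):
        if self.parent[i] != i:
            # Path compression: point i directly at its root.
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False
        # Make root_i the root with the larger (or equal) rank.
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        # The bound only grows when two trees of equal rank are merged.
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


d = RankDSU(4)
d.union(0, 1)
d.union(2, 3)
d.union(0, 2)
already_joined = not d.union(1, 3)  # 1 and 3 already share a root
```

Because path compression can shorten trees without updating rank, rank stays an upper bound on height rather than the exact height.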
4. Classic Problems
Number of Connected Components
Given n nodes and a list of edges, find the number of connected components. You can initialize a DSU with n components. For each edge (u, v), perform a union operation. The number of times the union is successful (i.e., when u and v were not already connected) is the number of merges. The final number of components is n - successful_unions.
from typing import List

class Solution:
    def countComponents(self, n: int, edges: List[List[int]]) -> int:
        dsu = DSU(n)
        for u, v in edges:
            dsu.union(u, v)
        return dsu.num_components
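To make the n - successful_unions formula concrete, here is a small self-contained sketch that counts merges explicitly. The function name count_components is hypothetical, and the inline find skips path compression and union by size purely to keep the example short.

```python
def count_components(n, edges):
    """Count connected components as n minus the number of successful unions."""
    parent = list(range(n))

    def find(i):
        # Plain walk to the root; optimizations omitted for brevity.
        while parent[i] != i:
            i = parent[i]
        return i

    successful_unions = 0
    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            successful_unions += 1
    return n - successful_unions


two_groups = count_components(5, [[0, 1], [1, 2], [3, 4]])    # {0,1,2} and {3,4}
with_cycle = count_components(5, [[0, 1], [1, 2], [0, 2]])    # {0,1,2}, {3}, {4}
```

In the second call the edge (0, 2) joins two nodes that are already connected, so it does not count as a merge.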
Redundant Connection
In this problem, you are given a graph that started as a tree with n nodes and had one extra edge added. Find the edge that can be removed so that the graph becomes a tree. DSU is perfect for this. Iterate through the edges. For each edge (u, v), check if u and v are already in the same set using find. If they are, this edge is redundant because adding it would form a cycle. Since the problem guarantees a single redundant edge, this is the one to return.
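from typing import List

class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        # Size len(edges) + 1 so that node labels up to len(edges) are valid indices.
        dsu = DSU(len(edges) + 1)
        for u, v in edges:
            # union returns False when u and v are already connected.
            if not dsu.union(u, v):
                return [u, v]
        # Not reached when the input contains a redundant edge.
        return []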