线段树(Segment Tree)是一种高级数据结构,主要用于在区间范围内高效地进行查询与修改操作。它是一个二叉树结构,每个节点代表一个区间的信息,通常用于解决如下问题:
线段树是对一个区间 [l, r]
上的数列进行划分,并在每个子区间上维护某种信息(如最值、和、最大公约数等)。
构建线段树是将区间不断对半划分,直到区间中只包含一个元素。每个节点保存其所代表区间的值(如区间和、最大值、最小值等)。
O(n)
O(log n)
O(log n)
O(log n)
线段树适用于以下问题:
场景 | 描述 |
---|---|
区间求和 | 查询某个区间内所有元素的和 |
区间最值 | 查询某个区间内最大值或最小值 |
区间更新 | 给某个区间内所有元素加/减某个值 |
区间 GCD | 查询某区间所有元素的最大公约数 |
动态维护序列统计信息 | 在元素更新的同时仍能高效查询 |
#include
#include
using namespace std;
class SegmentTree {
private:
vector<int> tree;
int n;
void build(vector<int>& nums, int node, int l, int r) {
if (l == r) {
tree[node] = nums[l];
} else {
int mid = (l + r) / 2;
build(nums, 2 * node + 1, l, mid);
build(nums, 2 * node + 2, mid + 1, r);
tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
}
}
void update(int index, int value, int node, int l, int r) {
if (l == r) {
tree[node] = value;
} else {
int mid = (l + r) / 2;
if (index <= mid)
update(index, value, 2 * node + 1, l, mid);
else
update(index, value, 2 * node + 2, mid + 1, r);
tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
}
}
int query(int ql, int qr, int node, int l, int r) {
if (qr < l || ql > r)
return 0; // 无交集
if (ql <= l && r <= qr)
return tree[node]; // 完全包含
int mid = (l + r) / 2;
int left_sum = query(ql, qr, 2 * node + 1, l, mid);
int right_sum = query(ql, qr, 2 * node + 2, mid + 1, r);
return left_sum + right_sum;
}
public:
SegmentTree(vector<int>& nums) {
n = nums.size();
tree.resize(4 * n);
build(nums, 0, 0, n - 1);
}
void update(int index, int value) {
update(index, value, 0, 0, n - 1);
}
int query(int ql, int qr) {
return query(ql, qr, 0, 0, n - 1);
}
};
class SegmentTree:
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n) # 开4倍空间
self._build(nums, 0, 0, self.n - 1)
def _build(self, nums, node, l, r):
if l == r:
self.tree[node] = nums[l]
else:
mid = (l + r) // 2
self._build(nums, 2 * node + 1, l, mid)
self._build(nums, 2 * node + 2, mid + 1, r)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
def update(self, index, value, node=0, l=0, r=None):
if r is None:
r = self.n - 1
if l == r:
self.tree[node] = value
else:
mid = (l + r) // 2
if index <= mid:
self.update(index, value, 2 * node + 1, l, mid)
else:
self.update(index, value, 2 * node + 2, mid + 1, r)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
def query(self, ql, qr, node=0, l=0, r=None):
if r is None:
r = self.n - 1
if qr < l or ql > r:
return 0 # 无交集
if ql <= l and r <= qr:
return self.tree[node] # 完全包含
mid = (l + r) // 2
left_sum = self.query(ql, qr, 2 * node + 1, l, mid)
right_sum = self.query(ql, qr, 2 * node + 2, mid + 1, r)
return left_sum + right_sum
int main() {
vector<int> nums = {1, 3, 5, 7, 9, 11};
SegmentTree seg(nums);
cout << "区间 [1, 3] 的和: " << seg.query(1, 3) << endl; // 输出 15
seg.update(1, 10); // 将 nums[1] 改为 10
cout << "更新后区间 [1, 3] 的和: " << seg.query(1, 3) << endl; // 输出 22
return 0;
}
nums = [1, 3, 5, 7, 9, 11]
seg = SegmentTree(nums)
print(seg.query(1, 3)) # 输出:15(3 + 5 + 7)
seg.update(1, 10)
print(seg.query(1, 3)) # 输出:22(10 + 5 + 7)
(l + r) // 2
、递归终止条件。