【Algorithm】Segment Tree 简单介绍

文章目录

  • Segment Tree
    • 1 基本概念
    • 2 基本思想
    • 3 适用场景
    • 4 代码示例(区间求和)
    • 5 使用示例
    • 6 使用注意事项
    • 7 进阶拓展

Segment Tree

线段树(Segment Tree)是一种高级数据结构,主要用于在区间范围内高效地进行查询与修改操作。它是一个二叉树结构,每个节点代表一个区间的信息,通常用于解决如下问题:


1 基本概念

线段树是对一个区间 [l, r] 上的数列进行划分,并在每个子区间上维护某种信息(如最值、和、最大公约数等)。

  • 节点表示一个区间
  • 根节点表示整个区间
  • 每个节点的两个子节点分别表示当前区间的左右子区间

2 基本思想

构建线段树是将区间不断对半划分,直到区间中只包含一个元素。每个节点保存其所代表区间的值(如区间和、最大值、最小值等)。

  • 建树复杂度:O(n)
  • 单点更新复杂度:O(log n)
  • 区间查询复杂度:O(log n)
  • 区间修改(懒标记):O(log n)

3 适用场景

线段树适用于以下问题:

场景 描述
区间求和 查询某个区间内所有元素的和
区间最值 查询某个区间内最大值或最小值
区间更新 给某个区间内所有元素加/减某个值
区间 GCD 查询某区间所有元素的最大公约数
动态维护序列统计信息 在元素更新的同时仍能高效查询

4 代码示例(区间求和)

  • C++版本:
#include 
#include 
using namespace std;

class SegmentTree {
private:
    vector<int> tree;
    int n;

    void build(vector<int>& nums, int node, int l, int r) {
        if (l == r) {
            tree[node] = nums[l];
        } else {
            int mid = (l + r) / 2;
            build(nums, 2 * node + 1, l, mid);
            build(nums, 2 * node + 2, mid + 1, r);
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    void update(int index, int value, int node, int l, int r) {
        if (l == r) {
            tree[node] = value;
        } else {
            int mid = (l + r) / 2;
            if (index <= mid)
                update(index, value, 2 * node + 1, l, mid);
            else
                update(index, value, 2 * node + 2, mid + 1, r);
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    int query(int ql, int qr, int node, int l, int r) {
        if (qr < l || ql > r)
            return 0;  // 无交集
        if (ql <= l && r <= qr)
            return tree[node];  // 完全包含
        int mid = (l + r) / 2;
        int left_sum = query(ql, qr, 2 * node + 1, l, mid);
        int right_sum = query(ql, qr, 2 * node + 2, mid + 1, r);
        return left_sum + right_sum;
    }

public:
    SegmentTree(vector<int>& nums) {
        n = nums.size();
        tree.resize(4 * n);
        build(nums, 0, 0, n - 1);
    }

    void update(int index, int value) {
        update(index, value, 0, 0, n - 1);
    }

    int query(int ql, int qr) {
        return query(ql, qr, 0, 0, n - 1);
    }
};
  • python 版本
class SegmentTree:
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)  # 开4倍空间
        self._build(nums, 0, 0, self.n - 1)

    def _build(self, nums, node, l, r):
        if l == r:
            self.tree[node] = nums[l]
        else:
            mid = (l + r) // 2
            self._build(nums, 2 * node + 1, l, mid)
            self._build(nums, 2 * node + 2, mid + 1, r)
            self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def update(self, index, value, node=0, l=0, r=None):
        if r is None:
            r = self.n - 1
        if l == r:
            self.tree[node] = value
        else:
            mid = (l + r) // 2
            if index <= mid:
                self.update(index, value, 2 * node + 1, l, mid)
            else:
                self.update(index, value, 2 * node + 2, mid + 1, r)
            self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def query(self, ql, qr, node=0, l=0, r=None):
        if r is None:
            r = self.n - 1
        if qr < l or ql > r:
            return 0  # 无交集
        if ql <= l and r <= qr:
            return self.tree[node]  # 完全包含
        mid = (l + r) // 2
        left_sum = self.query(ql, qr, 2 * node + 1, l, mid)
        right_sum = self.query(ql, qr, 2 * node + 2, mid + 1, r)
        return left_sum + right_sum

5 使用示例

  • C++版本
int main() {
    vector<int> nums = {1, 3, 5, 7, 9, 11};
    SegmentTree seg(nums);

    cout << "区间 [1, 3] 的和: " << seg.query(1, 3) << endl;  // 输出 15
    seg.update(1, 10);  // 将 nums[1] 改为 10
    cout << "更新后区间 [1, 3] 的和: " << seg.query(1, 3) << endl;  // 输出 22

    return 0;
}
  • python版本
nums = [1, 3, 5, 7, 9, 11]
seg = SegmentTree(nums)
print(seg.query(1, 3))  # 输出:15(3 + 5 + 7)
seg.update(1, 10)
print(seg.query(1, 3))  # 输出:22(10 + 5 + 7)

6 使用注意事项

  1. 空间消耗大:为了避免越界,需要预分配约 4 倍原始数组大小的空间。
  2. 边界细节多:区间划分时容易出错,如 (l + r) // 2、递归终止条件。
  3. 懒标记复杂度高:当要进行区间修改时,引入懒惰标记(Lazy Propagation)会使实现变复杂。
  4. 更新频率高适合使用树状数组(Binary Indexed Tree):如只支持前缀和。

7 进阶拓展

  • 懒惰标记(Lazy Propagation):支持区间更新,优化大量重复更新操作。
  • 可持久化线段树:支持版本控制、历史查询。
  • 线段树合并:用于合并多个线段树数据。
  • 二维线段树:处理二维矩阵中的区间查询。

你可能感兴趣的:(C/C++,c++,算法,python)