权值线段树和可持久化线段树(主席树)

目录

  • 权值线段树
    • 权值线段树的基本概念
    • 权值线段树的构建
    • 权值线段树的操作
      • 添加元素
      • 查询区间 [l, r] 的元素个数
      • 查询整个集合中第 k 小(或第 k 大)的元素值
    • 例题
    • 代码实现
  • 可持久化线段树(主席树)
    • 例题1
      • 代码实现
    • 例题2
      • 思路
      • 代码实现

权值线段树

权值线段树的基本概念

权值线段树是一种特殊的线段树,它的叶子节点存储的是某个元素的权值(通常是该元素的出现次数),而不是元素本身。分支节点则存储其子节点权值的某种集合值(如和、最值等)。

举例说明
假设我们有一个数组 [1, 2, 1, 3, 3, 5],我们可以通过统计每个元素的出现次数来构建一个“桶”:
元素 1 出现了 2 次
元素 2 出现了 1 次
元素 3 出现了 2 次
元素 4 出现了 0 次
元素 5 出现了 1 次
因此,我们得到的桶为 [2, 1, 2, 0, 1],表示每个数字出现的次数。权值线段树就是基于这个桶来建立和维护的。

权值线段树的构建

构建过程
确定区间范围:首先,我们需要确定元素的取值范围。假设元素的取值范围是 [1, n],那么我们可以将线段树的叶子节点对应到每个元素的值。

初始化桶:根据给定的数组,统计每个元素的出现次数,得到一个桶。

构建线段树:基于这个桶,构建线段树。每个叶子节点存储对应元素的出现次数,分支节点存储其子节点的权值之和。

示例
假设我们有数组 [1, 2, 1, 3, 3, 5],元素的取值范围是 [1, 5]。我们首先统计每个元素的出现次数,得到桶 [2, 1, 2, 0, 1]。然后,我们基于这个桶构建线段树。

权值线段树的操作

添加元素

向权值线段树中添加一个元素,类似于线段树的单点更新操作。

定位到叶子节点:根据元素的值,定位到对应的叶子节点。

更新权值:将该叶子节点的权值加 1。

更新父节点:递归更新其父节点的权值,直到根节点。

查询区间 [l, r] 的元素个数

查询区间 [l, r] 的元素个数,类似于线段树的区间查询操作。

递归查询:从根节点开始,递归查询区间 [l, r]。

合并结果:如果当前节点的区间完全包含在 [l, r] 内,则直接返回该节点的权值;否则,递归查询左右子节点,并将结果合并。

查询整个集合中第 k 小(或第 k 大)的元素值

从根节点开始:每个节点记录了元素大小在 [s, e] 之间的元素数量。

判断左子树的元素数量:设左子树的元素数量为 leftsum

如果 leftsum >= k,说明第 k 小的元素在左子树中,递归查询左子树。

如果 leftsum < k,说明第 k 小的元素在右子树中,递归查询右子树,并将 k 减去 leftsum

到达叶子节点:当递归到某个叶子节点时,该叶子节点对应的元素值即为第 k 小的元素。

例题

题目描述

你有一个空的多重集合(允许元素多次出现),执行以下操作共 q q q 次:

  1. 1 x:向集合中新增元素 x x x
  2. 2 l r:查询大小在 [ l , r ] [l,r] [l,r] 的元素的个数之和。
  3. 3 k:查询集合中第 k k k 小的元素,保证 k k k 小于等于此时集合大小。

输入描述
第一行一个整数表示 q q q 1 ≤ n , q ≤ 2 × 1 0 5 1 \leq n, q \leq 2 \times 10^5 1n,q2×105)。

接下来 q q q 行,每行一个操作。
1 ≤ o p ≤ 3 , 1 ≤ l ≤ r ≤ n , 1 ≤ x , k ≤ n 1 \leq op \leq 3, 1 \leq l \leq r \leq n, 1 \leq x, k \leq n 1op3,1lrn,1x,kn

输出描述
对于每次操作 23,输出一行结果。

输入样例

6
1 1
1 2
2 1 2
1 1
2 1 2
3 3

输出样例

2
3
2

代码实现

#include
using namespace std;

const int N = 2e5 + 5;
int val[N << 2];

void insert(int x, int s, int e, int idx) {
    if (s == e)
        return val[idx]++, void();
    int mid = (s + e) >> 1;
    if (x <= mid)
        insert(x, s, mid, idx << 1);
    else
        insert(x, mid + 1, e, idx << 1 | 1);
    val[idx] = val[idx << 1] + val[idx << 1 | 1];
}

int queryCnt(int l, int r, int s, int e, int idx) {
    if (l <= s && e <= r)
        return val[idx];
    int mid = (s + e) >> 1;
    int res = 0;
    if (l <= mid)
        res += queryCnt(l, r, s, mid, idx << 1);
    if (mid + 1 <= r)
        res += queryCnt(l, r, mid + 1, e, idx << 1 | 1);
    return res;
}

int queryVal(int k, int s, int e, int idx) {
    if (s == e)
        return s;
    int mid = (s + e) >> 1;
    int left_num = val[idx << 1];
    if (left_num >= k)
        return queryVal(k, s, mid, idx << 1);
    else
        return queryVal(k - left_num, mid + 1, e, idx << 1 | 1);
}

void solve() {
    int q, n = 2e5;
    cin >> q;
    while (q--) {
        int op;
        cin >> op;
        if (op == 1) {
            int x;
            cin >> x;
            insert(x, 1, n, 1);
        } else if (op == 2) {
            int l, r;
            cin >> l >> r;
            cout << queryCnt(l, r, 1, n, 1) << '\n';
        } else {
            int k;
            cin >> k;
            cout << queryVal(k, 1, n, 1) << '\n';
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    solve();
    return 0;
}

可持久化线段树(主席树)

与上述线段树思路一样,通过桶计数来维护第 k k k 小,不过要维护前缀和关系,通过类似 t[r] - t[l] 这样的关系求出区间 [l, r] 的第 k k k 小,因此要可持久化,来保留历史版本 t[l]t[r]。而上述不可持久化的线段树,只能存有当前最新的版本, t[n] 即整个集合的第 k k k 小,无法存下历史版本。

例题1

题目描述

给定一个长度为 n n n 的数组 a a a,有 q q q 次询问,每次询问区间 [ l , r ] [l, r] [l,r] 中排名为 k k k 的元素值(即第 k k k 小的元素)。

输入格式

  • 第一行两个整数 n n n q q q ( 1 ≤ n , q ≤ 2 × 1 0 5 1 \leq n, q \leq 2 \times 10^5 1n,q2×105)。
  • 第二行 n n n 个整数,表示数组 a a a ( 1 ≤ a i ≤ 1 0 9 1 \leq a_i \leq 10^9 1ai109)。
  • 接下来 q q q 行,每行一个询问 l l l, r r r, k k k ( 1 ≤ l ≤ r ≤ n 1 \leq l \leq r \leq n 1lrn, 1 ≤ k ≤ r − l + 1 1 \leq k \leq r - l + 1 1krl+1)。

输出格式

对于每次询问,输出一行结果。

样例输入

6 3
1 3 2 5 4 6
1 3 2
2 5 3
3 6 4

样例输出

2
4
6

代码实现

#include
using namespace std;

const int N=1e6 + 9;
int n, q;
int a[N];
int rt[N], idx;
vector<int> X;

int bin(int x) {
    return lower_bound(X.begin(), X.end(), x) - X.begin() + 1;
}

struct node {
    int ls, rs, val;
} t[N * 30];

void insert(int& o, int pre, int val, int s = 1, int e = n) {
    o = ++idx;
    t[o] = t[pre];
    t[o].val++;
    if (s == e)
        return;
    int mid = (s + e) >> 1;
    if (val <= mid)
        insert(t[o].ls, t[pre].ls, val, s, mid);
    else
        insert(t[o].rs, t[pre].rs, val, mid + 1, e);
}

int queryVal(int lo, int ro, int k, int s = 1, int e = n) {
    if (s == e)
        return s;
    int left_sum = t[t[ro].ls].val - t[t[lo].ls].val;
    int mid = (s + e) >> 1;
    if (k <= left_sum)
        return queryVal(t[lo].ls, t[ro].ls, k, s, mid); 
    return queryVal(t[lo].rs, t[ro].rs, k - left_sum, mid + 1, e);
}

void solve() {
    cin >> n >> q;
    for (int i = 1; i <= n; ++i)
        cin >> a[i];

    // 离散化
    X.resize(n);
    for (int i = 1; i <= n; ++i)
        X[i - 1] = a[i];
    sort(X.begin(), X.end());
    X.erase(unique(X.begin(), X.end()), X.end());

    for (int i = 1; i <= n; ++i)
        insert(rt[i], rt[i - 1], bin(a[i]));
    while (q--) {
        int l, r, k;
        cin >> l >> r >> k;
        cout << X[queryVal(rt[l - 1], rt[r], k) - 1] << '\n';
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    solve();
    return 0;
}

例题2

问题描述

给定一个长度为 n n n 的数组 a a a,有 q q q 次询问,每次询问区间 [ l , r ] [l, r] [l,r] 中不同的数字的个数。

输入格式

  • 第一行两个整数 n n n q q q,其中 1 ≤ n ≤ 2 × 1 0 5 1 \leq n \leq 2 \times 10^5 1n2×105 1 ≤ q ≤ 2 × 1 0 5 1 \leq q \leq 2 \times 10^5 1q2×105
  • 第二行输入 n n n 个整数表示数组 a a a,其中 1 ≤ a i ≤ 1 0 9 1 \leq a_i \leq 10^9 1ai109
  • 接下来 q q q 行,每行一个询问 [ l , r ] [l, r] [l,r],其中 1 ≤ l ≤ r ≤ n 1 \leq l \leq r \leq n 1lrn

输出格式:

  • 对于每次询问,输出一个整数表示区间 [ l , r ] [l, r] [l,r] 中不同数字的个数。

输入样例

5 2
1 2 1 2 2
1 3
4 5

输出样例

2
1

思路

需要统计数组 a 的区间 [l, r] 中不同元素的种类数目。方法是维护一个 lst 数组,其中 lst[i] 表示 a[i] 上一次出现的位置(若未出现过则为 0),然后统计 lst[l..r] 中值 < l 的元素个数。

  1. lst 数组的定义:

    • 对于 i1n
      • 如果 a[i] 之前未出现过,则 lst[i] = 0
      • 否则,lst[i]a[i] 上一次出现的下标。
  2. 区间 [l, r] 的不同元素种类数目:
    f ( l , r ) = ∑ i = l r [ l s t [ i ] < l ] f(l, r) = \sum_{i=l}^r [lst[i] < l] f(l,r)=i=lr[lst[i]<l]
    其中 [lst[i] < l] 是指示函数,当 lst[i] < l 时为 1,否则为 0

a = [1, 2, 1, 2, 2] 为例:

  1. 计算 lst

    • a[1] = 1:未出现过 → lst[1] = 0
    • a[2] = 2:未出现过 → lst[2] = 0
    • a[3] = 1:上一次出现在 1lst[3] = 1
    • a[4] = 2:上一次出现在 2lst[4] = 2
    • a[5] = 2:上一次出现在 2lst[4] = 4
    • 因此,lst = [0, 0, 1, 2, 4]
  2. 统计区间 [2, 4] 的不同元素数目:

    • lst[2..4] = [0, 1, 2]
    • 统计 < 2 的值:01 → 共 2 个。
    • a[2..4] = [2, 1, 2],不同元素为 12 → 共 2 种。匹配。

代码实现

#include
using namespace std;

const int N=1e6 + 9;
int n, q;
int a[N];
int rt[N], idx;
vector<int> X;

int bin(int x){
    return lower_bound(X.begin(), X.end(), x) - X.begin() + 1;
}

struct node{
    int ls, rs, val;
} t[N * 20];

void insert(int& o, int pre, int val, int s = 0, int e = n) {
    o = ++idx;
    t[o] = t[pre];
    t[o].val++;
    if (s == e)
        return;
    int mid = (s + e) >> 1;
    if (val <= mid)
        insert(t[o].ls, t[pre].ls, val, s, mid);
    else
        insert(t[o].rs, t[pre].rs, val, mid+1, e);
}

int query(int lo, int ro, int l, int r, int s = 0, int e = n) {
    if (l <= s && e <= r)
        return t[ro].val - t[lo].val;
    int mid = (s + e) >> 1, res = 0;
    if (mid >= l)
        res += query(t[lo].ls, t[ro].ls, l, r, s, mid);
    if (mid + 1 <= r)
        res += query(t[lo].rs, t[ro].rs, l, r, mid + 1, e);
    return res;
}

void solve() {
    cin >> n >> q;
    for (int i = 1; i <= n; ++i)
        cin >> a[i];

    // 离散化
    X.resize(n);
    vector<int> lst(n + 1);
    for (int i = 1; i <= n; ++i)
        X[i - 1] = a[i];
    sort(X.begin(), X.end());
    X.erase(unique(X.begin(), X.end()), X.end());

    for (int i = 1; i <= n; ++i) {
        insert(rt[i], rt[i - 1], lst[bin(a[i])]);
        lst[bin(a[i])] = i;
    }


    while (q--) {
        int l, r;
        cin >> l >> r;
        cout << query(rt[l - 1], rt[r], 0, l - 1) << '\n';
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    solve();
    return 0;
}

你可能感兴趣的:(算法,算法,数据结构)