字典树(Trie) 理论知识复习及精选例题解析

字典树 理论知识复习及精选例题解析

  • 一、字典树理论知识
  • 二、精选例题解析
    • 例题 1. P8306【模板】字典树
    • 例题 2. P2580 于是他错误的点名开始了
    • 例题 3. P10471 最大异或对 The XOR Largest Pair
  • 三、字典树的使用思路和细节
    • 使用思路
    • 细节注意
  • 四、总结

一、字典树理论知识

1. 定义

字典树( T r i e ) 字典树(Trie) 字典树(Trie,又称前缀树,是一种树形数据结构,用于 高效地存储和检索字符串集合。它的每个节点代表一个字符,从根节点到某一节点的路径上经过的字符连接起来,即为该节点对应的字符串。

2. 特点

  • 空间换时间: 通过牺牲一定的空间来换取快速的字符串查找和插入操作。

  • 前缀匹配高效: 可以在 O ( m ) O(m) O(m) 的时间复杂度内完成字符串的插入、查找和删除操作,其中 m 是字符串的长度。

3. 基本操作

  • 插入(Insert): 将一个字符串插入到字典树中。

  • 查找(Query): 查询一个字符串是否存在于字典树中,或者查询以某个字符串为前缀的字符串数量。

二、精选例题解析

例题 1. P8306【模板】字典树

P8306【模板】字典树

题目描述

给定 T 组数据,每组数据包含 n 个字符串和 q 次查询,每次查询要求输出以某个字符串为前缀的字符串数量。

思路讲解

  • 构建字典树时,使用一个 p a s s pass pass 数组来记录每个节点被经过的次数。

  • 插入字符串时,从根节点开始,依次遍历字符串的每个字符,若该字符对应的路径不存在,则创建新路径,并将经过的节点的 p a s s pass pass 值加一。

  • 查询时,从根节点开始,依次遍历查询字符串的每个字符,若路径不存在,则返回 0 0 0,否则返回查询字符串最后一个字符所在节点的 p a s s pass pass 值。


AC代码

C++ 代码

#include 
#include 
#include 

const int N = 3e6 + 10;
int T, n, q, idx;
int trie[N][26 + 26 + 10], pass[N]; //只需维护一个 pass 数组,用来操作前缀字符串

inline int get_num(char ch) { //获取字符的 ASCII 码值
    if(ch >= 'a' && ch <= 'z') {
        return ch - 'a';
    } else if(ch >= 'A' && ch <= 'Z') {
        return ch - 'A' + 26;
    } else { //这块映射 ASCII 码值时需要加上 26 或 52,避免将多个字符映射到 0 ~ 25 上的相同位置
        return ch - '0' + 26 + 26;
    }
}

inline void insert(std::string& str) {
    int cur = 0;
    pass[cur]++;//根节点的 pass 加一

    for(char ch : str) {
        int path = get_num(ch);

        if(trie[cur][path] == 0) { //没有现成的路径,创建新路径
            trie[cur][path] = ++idx;
        }

        cur = trie[cur][path]; //找到下一个节点
        pass[cur]++; //途经的节点的 pass 值加一
    }
}

inline int query(std::string& str) {
    int cur = 0;

    for(char ch : str) {
        int path = get_num(ch);

        if(trie[cur][path] == 0) { //该字符串不存在于字典树中
            return 0;
        } else { //找到下一个节点
            cur = trie[cur][path];
        }
    }
    return pass[cur]; //str 为前缀串的数量,返回的是 str 串的最后一个字符的节点处的 pass 值
}

int main(void) {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0); std::cout.tie(0);

    std::cin >> T;

    while(T--) {
        std::cin >> n >> q;

        std::string str;

        //n 次插入字符串
        while(n--) {
            std::cin >> str;
            insert(str);
        }

        //q 次查询前缀数
        while(q--) {
            std::cin >> str;
            std::cout << query(str) << std::endl;
        }

        //重置数组
        for(int i = 0; i <= idx; ++i) { //要从 0 开始到 idx,因为 0 位置代表第一个字符的选取,也需要重置
            for(int j = 0; j < 62; ++j) {
                trie[i][j] = 0;
            }
            pass[i] = 0;
        }
        idx = 0;
    }
    return 0;
}

Python 代码

T = int(input())

for _ in range(T):
    n, q = map(int, input().split())
    trie = {}
    pass_count = {}

    def get_num(ch):
        if 'a' <= ch <= 'z':
            return ord(ch) - ord('a')
        elif 'A' <= ch <= 'Z':
            return ord(ch) - ord('A') + 26
        else:
            return ord(ch) - ord('0') + 26 + 26

    def insert(str):
        cur = 0
        if cur not in pass_count:
            pass_count[cur] = 0
        pass_count[cur] += 1

        for ch in str:
            path = get_num(ch)
            if cur not in trie:
                trie[cur] = {}
            if path not in trie[cur]:
                trie[cur][path] = len(trie) + 1
            cur = trie[cur][path]
            if cur not in pass_count:
                pass_count[cur] = 0
            pass_count[cur] += 1

    def query(str):
        cur = 0
        for ch in str:
            path = get_num(ch)
            if cur not in trie or path not in trie[cur]:
                return 0
            cur = trie[cur][path]
        return pass_count.get(cur, 0)

    # n 次插入字符串
    for _ in range(n):
        str = input()
        insert(str)

    # q 次查询前缀数
    for _ in range(q):
        str = input()
        print(query(str))

    # 重置
    trie = {}
    pass_count = {}

Java 代码

import java.util.*;

public class Main {
    static final int N = 3000010;
    static int T, n, q, idx;
    static int[][] trie = new int[N][62];
    static int[] pass = new int[N];

    static int get_num(char ch) {
        if (ch >= 'a' && ch <= 'z') {
            return ch - 'a';
        } else if (ch >= 'A' && ch <= 'Z') {
            return ch - 'A' + 26;
        } else {
            return ch - '0' + 26 + 26;
        }
    }

    static void insert(String str) {
        int cur = 0;
        pass[cur]++;
        for (char ch : str.toCharArray()) {
            int path = get_num(ch);
            if (trie[cur][path] == 0) {
                trie[cur][path] = ++idx;
            }
            cur = trie[cur][path];
            pass[cur]++;
        }
    }

    static int query(String str) {
        int cur = 0;
        for (char ch : str.toCharArray()) {
            int path = get_num(ch);
            if (trie[cur][path] == 0) {
                return 0;
            }
            cur = trie[cur][path];
        }
        return pass[cur];
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        T = scanner.nextInt();
        while (T-- > 0) {
            n = scanner.nextInt();
            q = scanner.nextInt();
            scanner.nextLine();

            // n 次插入字符串
            for (int i = 0; i < n; i++) {
                String str = scanner.nextLine();
                insert(str);
            }

            // q 次查询前缀数
            for (int i = 0; i < q; i++) {
                String str = scanner.nextLine();
                System.out.println(query(str));
            }

            // 重置数组
            for (int i = 0; i <= idx; i++) {
                Arrays.fill(trie[i], 0);
                pass[i] = 0;
            }
            idx = 0;
        }
        scanner.close();
    }
}

例题 2. P2580 于是他错误的点名开始了

P2580 于是他错误的点名开始了

题目描述

给定 n n n 个学生的名字,然后进行 m m m 次点名,每次点名需要判断该名字是否存在,若存在且是第一次被点到,则输出 OK,若不存在则输出 WRONG,若已经被点到过则输出 REPEAT

思路讲解

  • 构建字典树时,使用一个 e n d end end 数组来记录每个字符串的结束情况。

  • 插入字符串时,从根节点开始,依次遍历字符串的每个字符,若该字符对应的路径不存在,则创建新路径,最后将字符串结束节点的 e n d end end 值加一。

  • 查询时,从根节点开始,依次遍历查询字符串的每个字符,若路径不存在,则返回 0 0 0,否则返回查询字符串最后一个字符所在节点的 e n d end end 值,并将该节点的 e n d end end 值置为 − 1 -1 1,表示该字符串已经被查询过。


AC代码

C++ 代码

#include 
#include 

const int N = 5e5 + 10;
int n, m, idx;
int trie[N][26], end[N]; //只需维护一个 end 数组,用来操作完整的字符串

inline void insert(std::string& str) {
    int cur = 0;

    for(char ch : str) {
        int path = ch - 'a';

        if(trie[cur][path] == 0) {
            trie[cur][path] = ++idx;
        }

        cur = trie[cur][path];
    }
    end[cur]++;
}

inline int query(std::string& str) {
    int cur = 0;

    for(char ch : str) {
        int path = ch - 'a';

        if(trie[cur][path] == 0) {
            return 0;
        } else {
            cur = trie[cur][path];
        }
    }
    int ret = end[cur]; //保存结果
    end[cur] = -1; //第一次查询后将该位置置为 -1,便于判断是否重复查询

    return ret;
}

int main(void) {
    std::cin >> n;

    std::string str;

    //插入 n 个字符串
    while(n--) {
        std::cin >> str;
        insert(str);
    }

    std::cin >> m;

    //查询 m 次字符串
    while(m--) {
        std::cin >> str;

        int val = query(str); //按返回值判断结果

        if(val > 0) { //第一次查询到
            std::cout << "OK" << std::endl;
        } else if(val == 0) { //未查询到
            std::cout << "WRONG" << std::endl;
        } else { //重复查询
            std::cout << "REPEAT" << std::endl;
        }
    }
    return 0;
}

Python 代码

n = int(input())
trie = {}
end = {}

def insert(str):
    cur = 0
    for ch in str:
        path = ord(ch) - ord('a')
        if cur not in trie:
            trie[cur] = {}
        if path not in trie[cur]:
            trie[cur][path] = len(trie) + 1
        cur = trie[cur][path]
    if cur not in end:
        end[cur] = 0
    end[cur] += 1

def query(str):
    cur = 0
    for ch in str:
        path = ord(ch) - ord('a')
        if cur not in trie or path not in trie[cur]:
            return 0
        cur = trie[cur][path]
    ret = end.get(cur, 0)
    end[cur] = -1
    return ret

# 插入 n 个字符串
for _ in range(n):
    str = input()
    insert(str)

m = int(input())
# 查询 m 次字符串
for _ in range(m):
    str = input()
    val = query(str)
    if val > 0:
        print("OK")
    elif val == 0:
        print("WRONG")
    else:
        print("REPEAT")

Java 代码

import java.util.*;

public class Main {
    static final int N = 500010;
    static int n, m, idx;
    static int[][] trie = new int[N][26];
    static int[] end = new int[N];

    static void insert(String str) {
        int cur = 0;
        for (char ch : str.toCharArray()) {
            int path = ch - 'a';
            if (trie[cur][path] == 0) {
                trie[cur][path] = ++idx;
            }
            cur = trie[cur][path];
        }
        end[cur]++;
    }

    static int query(String str) {
        int cur = 0;
        for (char ch : str.toCharArray()) {
            int path = ch - 'a';
            if (trie[cur][path] == 0) {
                return 0;
            }
            cur = trie[cur][path];
        }
        int ret = end[cur];
        end[cur] = -1;
        return ret;
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();
        scanner.nextLine();

        // 插入 n 个字符串
        for (int i = 0; i < n; i++) {
            String str = scanner.nextLine();
            insert(str);
        }

        m = scanner.nextInt();
        scanner.nextLine();

        // 查询 m 次字符串
        for (int i = 0; i < m; i++) {
            String str = scanner.nextLine();
            int val = query(str);
            if (val > 0) {
                System.out.println("OK");
            } else if (val == 0) {
                System.out.println("WRONG");
            } else {
                System.out.println("REPEAT");
            }
        }
        scanner.close();
    }
}

例题 3. P10471 最大异或对 The XOR Largest Pair

P10471 最大异或对 The XOR Largest Pair

题目描述

给定一个长度为 n n n 的整数序列,要求找出序列中两个数的异或值的最大值。

思路讲解

  • 将每个整数的二进制表示插入到字典树中。

  • 对于每个整数,从最高位开始,在字典树中尝试寻找与该位相反的位,若存在则将异或结果的该位置为 1 1 1,否则保持为 0 0 0


AC代码

C++ 代码

#include 

const int N = 1e5 + 10;
int n, idx;
int trie[N * 32][2], a[N]; //因为保存的是二进制位,故每个数字都要有 32 (4 bytes * 8 bits) 个位置

void insert(int x) {
    int cur = 0;

    //将 x 的二进制表示插入字典树中
    for(int i = 31; i >= 0; --i) {
        int path = (x >> i) & 1;

        if(trie[cur][path] == 0) {
            trie[cur][path] = ++idx;
        }

        cur = trie[cur][path];
    }
}

int find(int x) {
    int cur = 0, ret = 0;

    //从最高位按位查找
    for(int i = 31; i >= 0; --i) {
        int path = (x >> i) & 1;

        if(trie[cur][path ^ 1] != 0) { //当前位的值为 path,尝试寻找 path ^ 1 之后的结果,也就是 0 -> 1, 1 -> 0,因为这样就能得到异或后的该位为 1
            ret = ret | (1 << i); //将该位置为 1
            cur = trie[cur][path ^ 1]; //找到下一个节点
        } else {
            cur = trie[cur][path]; //没有相反位,则保持该位为 0,继续找下一个节点
        }
    }
    return ret;
}

int main(void) {
    std::cin >> n;

    //读取序列保存,并插入字典树中
    for(int i = 1; i <= n; ++i) {
        std::cin >> a[i];
        insert(a[i]);
    }

    //查找最大异或对
    int maxi = 0;
    for(int i = 1; i <= n; ++i) {
        maxi = std::max(maxi, find(a[i]));
    }
    std::cout << maxi << std::endl;

    return 0;
}

Python 代码

n = int(input())
trie = {}
a = []

def insert(x):
    cur = 0
    for i in range(31, -1, -1):
        path = (x >> i) & 1
        if cur not in trie:
            trie[cur] = {}
        if path not in trie[cur]:
            trie[cur][path] = len(trie) + 1
        cur = trie[cur][path]

def find(x):
    cur = 0
    ret = 0
    for i in range(31, -1, -1):
        path = (x >> i) & 1
        if cur in trie and (path ^ 1) in trie[cur]:
            ret |= (1 << i)
            cur = trie[cur][path ^ 1]
        else:
            cur = trie[cur].get(path, 0)
    return ret

# 读取序列保存,并插入字典树中
for _ in range(n):
    x = int(input())
    a.append(x)
    insert(x)

# 查找最大异或对
maxi = 0
for x in a:
    maxi = max(maxi, find(x))

print(maxi)

Java 代码

import java.util.*;

public class Main {
    static final int N = 100010;
    static int n, idx;
    static int[][] trie = new int[N * 32][2];
    static int[] a = new int[N];

    static void insert(int x) {
        int cur = 0;
        for (int i = 31; i >= 0; i--) {
            int path = (x >> i) & 1;
            if (trie[cur][path] == 0) {
                trie[cur][path] = ++idx;
            }
            cur = trie[cur][path];
        }
    }

    static int find(int x) {
        int cur = 0, ret = 0;
        for (int i = 31; i >= 0; i--) {
            int path = (x >> i) & 1;
            if (trie[cur][path ^ 1] != 0) {
                ret |= (1 << i);
                cur = trie[cur][path ^ 1];
            } else {
                cur = trie[cur][path];
            }
        }
        return ret;
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();

        // 读取序列保存,并插入字典树中
        for (int i = 1; i <= n; i++) {
            a[i] = scanner.nextInt();
            insert(a[i]);
        }

        // 查找最大异或对
        int maxi = 0;
        for (int i = 1; i <= n; i++) {
            maxi = Math.max(maxi, find(a[i]));
        }
        System.out.println(maxi);
        scanner.close();
    }
}

三、字典树的使用思路和细节

使用思路

  1. 确定字符集(String): 根据题目中字符串的字符范围,确定字典树每个节点的子节点数量。

  2. 插入操作(Insert): 从根节点开始,依次遍历字符串的每个字符,若该字符对应的路径不存在,则创建新路径,最后标记字符串的结束。

  3. 查找操作(Query): 从根节点开始,依次遍历查询字符串的每个字符,若路径不存在,则查询失败,否则继续查找,直到找到字符串的结束标记或满足查询条件。

细节注意

  • 字符映射: 对于不同的字符集,需要将字符映射到合适的索引位置,避免冲突。

  • 空间复杂度: 字典树的空间复杂度较高,需要根据实际情况合理分配内存。

  • 重置操作: 在多组数据的情况下,需要及时重置字典树的数组,避免影响后续操作。同时也要注意第一题模板题的重置操作,数组所有位置重置会超时,需要根据上一次使用的范围 [ 0 , i d x ] [0, idx] [0,idx] 部分重置。

四、总结

字典树 ( T r i e ) 字典树(Trie) 字典树(Trie) 是一种非常实用的数据结构,特别适用于 字符串的存储和检索。通过上述例题可以看出,字典树可以高效地解决 前缀匹配、字符串查找和最大异或对等问题。在使用字典树时,需要根据具体问题选择合适的数组来维护节点信息,如 p a s s pass pass 数组用于前缀计数, e n d end end 数组用于标记字符串的结束。同时,要 注意字符映射和空间复杂度的问题,以确保算法的正确性和效率。


博客中的 数据结构和算法模板以及算法题 的全部 C++ 代码 和部分 PythonJava 代码实现都在作者的 Github 仓库中能找到,后续会补充 PythonJava 实现
感谢阅读

你可能感兴趣的:(算法,java,c++,数据结构,python,leetcode,vscode)