Python数据结构与算法(十四、线段树)

保证一周更两篇吧,以此来督促自己好好的学习!代码的很多地方我都给予了详细的解释,帮助理解。好了,干就完了~加油!
声明:本python数据结构与算法是imooc上liuyubobobo老师java数据结构的python改写,并添加了一些自己的理解和新的东西,liuyubobobo老师真的是一位很棒的老师!超级喜欢他~
如有错误,还请小伙伴们不吝指出,一起学习~
No fears, No distractions.

一、线段树

  1. 什么是线段树?
    线段树所记录的是数组(一般是数组)区间内的信息,比如区间内的元素求和、连乘之类的。
    比如树中的根节点就记录了数组A中从索引0到索引9的元素信息,比如求和神马的。线段树的叶子节点只有一个元素。但不是完全二叉树,是一棵平衡二叉树(树的最大深度至多比最小深度多1)。但是一般的,我们认为线段树是一棵满二叉树(为了能像堆那样用索引来表示左、右孩子- -,只不过堆不一定是满二叉树),只需在不存在元素的地方用None填充就好了。线段树一般是以中点索引来分割左、右子树的。
    Python数据结构与算法(十四、线段树)_第1张图片
  2. 为什么要使用线段树?
    对于有一类问题,我们关心的是线段(或者区间)。
  3. 最经典的线段树问题:
    区间染色,区间查询等操作
  4. 使用线段树的前提:
    无添加,删除元素操作,即区间元素空间已经固定。
  5. 既然线段树是针对数组索引来进行相应的功能的,那么开多大空间合适?
    假设传入的一个数组有n个元素,那么一个线段树要开多少个节点合适呢(注意是用数组表示的满二叉树)?
    我们知道,树的第n - 1层(n从0开始)最多能容纳 2 n − 1 2^{n-1} 2n1个元素,此时树中共有 2 n − 1 2^n-1 2n1个元素(等比数列求和),大约是 2 n 2^n 2n,而最后一层就有 2 n − 1 2^{n-1} 2n1个元素,所以除倒数第一层外其余层的节点个数也为 2 n − 1 2^{n-1} 2n1,这就得出了一个结论:满二叉树的最后一层的节点数与前面层的总数近似相等(实则只差1)。
    如果区间有n个元素,这里做个假设,假设n= 2 k 2^k 2k,那么需要2n的空间(因为处最底层的n个元素,上面也有近似n个元素),如果n一旦比 2 k 2^k 2k多了一个1,那么就得再开一层,为了保证是一棵满二叉树,所以又开了2n的空间,所以按最坏了情况算,我们每次开4n的空间,就肯定能够容纳这区间内的n个元素了。
    结论:区间内有n个元素,则线段树(实则是一个数组)的空间应为:4 * n 这么多个空间。

实现

# -*- coding: utf-8 -*-
# Author:           Annihilation7
# Data:             2018-10-27  01:02 am
# Python version:   3.6
# 啊啊啊,没有拿到1024勋章,好难受~~~

class SegmentTree:
    """线段树类"""
    def __init__(self, alist, merger_):
        """
        Description: 线段树的构造函数
        Params:
        - alist: 用户传入的一个list(这里我们就不用以前实现的Arr类了,直接用python的list啦,如果想用的话也是一点问题都没有的~)
        - func: merge函数,用于对实现两个数合成一个数的功能(比如二元操作符加法、乘法……等等)
        """
        self._data = alist[:]   # 所以为了不改变传入的数组,需要传其副本
        self._tree = [None] * 4 * len(self._data)       # 注意是4倍的存储空间,初始化元素全是None
        # self._tree = [None for i in range(len(self._data) * 4)]
        self._merger = merger_   # merger函数,比如两个元素求和函数……,用lambda表达式比较方便

        self._buildSegmentTree(0, 0, len(self._data)-1) # 调用self._buildSegmentTree来构建线段树

    def getSize(self):
        """
        Description: 获取有效元素的个数
        Returns:
        有效元素个数
        """
        return len(self._data)

    def get(self, index):
        """
        Description: 根据索引index获取相应元素
        时间复杂度:O(1)
        Params:
        - index: 传入的索引
        Returns:
        index索引处的元素值
        """
        if index < 0 or index >= len(self._data):
            raise Exception('Index is illegal!')
        return self._data[index]

    def query(self, quaryL, quaryR):
        """
        Description: 查找[quaryL, quaryR]这个左闭右闭区间上的值(例如对于求和操作就是求这个区间上所有元素的和)
        时间复杂度:O(logn)
        Params:
        - quaryL: 区间左端点的索引
        - quaryR: 区间右端点的索引
        Returns:
        [quaryL, quaryR]区间上的值
        """
        if quaryL < 0 or quaryR < 0 or quaryL >= self.getSize() or quaryR >= self.getSize() or quaryR < quaryL:  # 索引合法性检查
            raise Exception('The indexes is illegal!')
        return self._query(0, 0, self.getSize()-1, quaryL, quaryR)  # 调用self._quary函数

    def set(self, index, e):
        """
        Description: 将数组中index位置的元素设为e,因此此时需要对线段树的内容要进行更新操作(也就是线段树的更新操作)
        时间复杂度:O(logn)
        Params:
        - index: 数组中的索引
        - e: 索引index上元素的新值e
        """
        if index < 0 or index >= self.getSize():
            raise Exception('The index is illegal!')
        self._data[index] = e # 更新self._data
        self._set(0, 0, len(self._data) - 1, index, e)  # 调用self._set函数
        

    def printSegmentTree(self):
        """对线段树进行打印"""
        print('[', end=' ')
        for i in range(len(self._tree)):
            if i == len(self._tree) - 1:
                print(self._tree[i], end=' ]')
                break
            print(self._tree[i], end=',')


    # private
    def _leftChild(self, index):
        """
        Description: 和最大堆一样,由于线段树是一颗完全二叉树,所以可以通过索引的方式找到其左、右孩子的索引(元素从索引0开始盛放)
        Params:
        - index: 输入的索引
        Returns:
        左孩子的索引值
        """
        return 2 * index + 1		# 一定要记住线段树是一棵满树哦,所以用数组就能表示这棵树了,索引关系也和堆是一样的,只不过不需要求父亲节点的索引了

    def _rightChild(self, index):
        """
        Description: 和最大堆一样,由于线段树是一颗完全二叉树,所以可以通过索引的方式找到其左、右孩子的索引(元素从索引0开始盛放)
        Params:
        - index: 输入的索引
        Returns:
        右孩子的索引值
        """
        return 2 * index + 2

    def _buildSegmentTree(self, treeIndex, left, right):
        """
        Description: 以根节点索引为treeIndex,构造self._data索引在[left, right]上的线段树
        Params:
        - treeIndex: 线段树根节点的索引
        - left: 数据左边的索引
        - right: 数据右边的索引
        """
        if left == right:       # 递归到底的情况,left == right,此时只有一个元素
            self._tree[treeIndex] = self._data[left]  # 相应的,self._tree上索引为treeIndex的位置的值置为self._data[left]就好
            return 

        leftChild_index = self._leftChild(treeIndex)    # 获取左孩子的索引
        rightChild_index = self._rightChild(treeIndex)  # 获取右孩子的索引
        
        mid = left + (right - left) // 2        # 获取left和right的中间值,在python中,可以用(left + right) // 2的方式来获得mid,因为不存在数值越界问题
        self._buildSegmentTree(leftChild_index, left, mid)  # 递归向左孩子为根的左子树构建线段树
        self._buildSegmentTree(rightChild_index, mid + 1, right)  # 递归向右孩子为的右子树构建线段树
        self._tree[treeIndex] = self._merger(self._tree[leftChild_index], self._tree[rightChild_index]) # 在回归的过程中,用self._merger函数对两个子节点的值进行merger操作,从而完成整棵树的建立
        
    def _query(self, treeIndex, left, right, quaryL, quaryR):
        """
        Description: 在根节点索引为treeindex的线段树上查找索引范围为[quaryL, quaryR]上的值,其中left, right值代表该节点所表示的索引范围(左闭右闭)
        Params:
        - treeIndex: 根节点所在的索引
        - left: 根节点所代表的区间的左端的索引值(注意是左闭右闭区间哦)
        - right: 根节点所代表的区间的右端点的索引值
        - quaryL: 待查询区间的左端的索引值(也是左闭右闭区间)
        - quaryR: 待查询区间的右端的索引值
        """
        if left == quaryL and right == quaryR:      # 递归到底的情况,区间都对上了,直接返回当前treeIndex索引处的值就好
            return self._tree[treeIndex]            # 返回当前树上索引为treeIndex的元素值
        
        mid = left + (right - left) // 2            # 获取TreeIndex索引处所代表的范围的中点
        leftChild_index = self._leftChild(treeIndex)    # 获取左孩子的索引
        rightChild_index = self._rightChild(treeIndex)  # 获取右孩子的索引

        if quaryL > mid:        # 此时要查询的区间完全位于当前treeIndex所带表的区间的右侧
            return self._query(rightChild_index, mid + 1, right, quaryL, quaryR)    # 直接去右子树找[quaryL, quaryR]
        elif quaryR <= mid:     # 此时要查询的区间完全位于当前treIndex所代表的区间的左侧
            return self._query(leftChild_index, left, mid, quaryL, quaryR)      # 直接去左子树找[quaryL, quaryR]
        
        # 此时一部分在[left, mid]上,一部分在[mid + 1, right]上
        leftResult = self._query(leftChild_index, left, mid, quaryL, mid)   # 在左子树找区间[quaryL, mid]
        rightResult = self._query(rightChild_index, mid + 1, right, mid + 1, quaryR)    # 在右子树找区间[mid + 1, quaryR]
        return self._merger(leftResult, rightResult)        # 最后在回归的过程中两个子节点进行merger操作并返回,得到[quaryL, quaryR]区间上的值

    def _set(self, treeIndex, left, right, index, e):
        """
        Description: 在以索引treeIndex为根节点的线段树中将索引为index的位置的元素设为e(此时treeIndex索引处所代表的区间范围为:[left, right]
        params:
        - treeIndex: 传入的线段树的根节点索引值
        - left: 根节点所代表的区间的左端的索引值
        - right: 根节点所代表的区间的右端点的索引值
        - index: 输入的索引值
        - e: 新的元素值
        """
        if left == right:  # 递归到底的情况,也就是在树中找到了索引为index的元素
            self._tree[treeIndex] = e  # 直接替换
            return

        mid = left + (right - left) // 2        # 找到索引中间值
        leftChild_index = self._leftChild(treeIndex)    # 左孩子索引值
        rightChild_index = self._rightChild(treeIndex)  # 右孩子索引值

        if index <= mid:    # index处于当前treeIndex所代表的区间的左半区
            self._set(leftChild_index, left, mid, index, e) # 到左子树去找index
        else:       # 否则index处于当前treeIndex所代表的区间的右半区
            self._set(rightChild_index, mid + 1, right, index, e)   # 到右子树去找index
        self._tree[treeIndex] = self._merger(self._tree[leftChild_index], self._tree[rightChild_index]) # 由于对树的最底层元素进行了更新操作,因此需要对树的上层也进行一次更新,所以每次回归的都调用merger操作进行上层的值的更新操作

三、测试

from segmenttree import SegmentTree     # 我们的SegmentTree类写在了segmenttree.py文件中

input_list = [-2, 0, 3, -5, 2, -1]
test_st = SegmentTree(input_list, merger_=lambda x, y: x+y)	# 这里以求和为例
test_st.printSegmentTree()
print()
print('索引区间[0, 4]上的元素的和为:', test_st.query(0, 4))
print('将索引为0的元素置为10:')
test_st.set(0, 10)
print('此时索引区间[0, 4]上的元素的和为:', test_st.query(0, 4))

四、输出

[ -3,1,-4,-2,3,-3,-1,-2,0,None,None,-5,2,None,None,None,None,None,None,None,None,None,None,None ]
索引区间[0, 4]上的元素的和为: -2
将索引为0的元素置为10:
此时索引区间[0, 4]上的元素的和为: 10

五、总结

  1. 这个数据结构属于高级数据结构,一般面试不会用到,并且还有很多牛逼的操作我并没有实现(能力不足)。
  2. 线段树适用于无元素添加、删除的数组,并且求解的问题是区间内的问题,用线段树要比常规的O(n)解法快很多(比如求区间内元素的和),毕竟是O(logn)的时间复杂度嘛。
  3. 如果不是比赛神马的,没必要掌握的过于深入。

若有还可以改进、优化的地方,还请小伙伴们批评指正!

你可能感兴趣的:(python数据结构与算法,Python数据结构与算法)