可以动态翻转和查询的 01最长不降序列

2020/3/19美团点评的最后一道笔试题(算法方向,第五道题)

题意:给定一个n(1到100000),m(1到100000),一个长度为n的01数组(只有0和1),然后m行操作,操作要么是查询当前最长不降序列长度,要么是给定x和y,翻转数组中x到y的一段(即位于[x,y]区间的数,0变成1,1变成0)。

题解:我用的是线段树,实现起来三四百行,有更好的办法欢迎交流。

首先考虑一个01数组计算最长不降序列的办法。由于是01序列,所以不降序列具有形式00001111,即前面为0,后面为1的序列(可以没有1或0)。所以只需要对每个位置,计算出当前位置前面的0的个数和后面的1的个数的和,取最大值就是最长不降序列的长度。

接着考虑翻转,我们可以维护上面提到的和值,然后查询的时候取最大即可。问题是翻转的时候怎么维护。设翻转区间为[x,y],则整个数组分成3段,[0,x - 1], [x,y], [y+1, end]
设zero[i]为位置i前面(包括自己)0的个数,为位置i后面(包括自己)1的个数,我们维护的是sum[i] = zero[i] + one[i],以及最大值

对于[0, x-1]这一段,zero[i]不变,而one[i]的改变都一致,增加的为[x,y]这一段中原来0比1多的个数为 -diff(x, y)(设diff(x,y)为[x,y]中1比0多的个数)
对于[y + 1,end]也一样,one[i]不变,zero[i]增加的 个数为diff(x,y)

所以我们可以用另外一个线段树来维护每一段0的个数(线段树最简单的应用),翻转的时候个数等于区间长度减去当前0的个数,就可以求出diff(x,y)。
而且有了这个线段树,我们也可以在log(n)的时间内求出一个zero[i]或者one[i]

最后我们要维护[x,y]这一段。考虑一个数组整个翻转的情况,比如
0100110,翻转成1011001,则zero[i]和onei都发生了什么改变?
0到位置i共有i+1个位置,所以我们只要知道前面原来有多少个1就是翻转后0的个数,即 n e w Z e r o [ i ] = i + 1 − z e r o [ i ] newZero[i] = i + 1 - zero[i] newZero[i]=i+1zero[i], 同样 n e w O n e [ i ] = n + 1 − i − o n e [ i ] newOne[i] = n + 1 - i - one[i] newOne[i]=n+1ione[i]
现在考虑[x,y]是位于中间,将下标左移到0,先减去[0,x-1]的影响,然后最后加回来,得 n e w Z e r o [ i ] = i − x + 1 − ( z e r o [ i ] − z e r o [ x − 1 ] ) + z e r o [ x − 1 ] = i − x + 1 − z e r o [ i ] + 2 ∗ z e r o [ x − 1 ] newZero[i] = i - x + 1 - (zero[i] - zero[x-1]) + zero[x-1]=i-x+1-zero[i] +2*zero[x-1] newZero[i]=ix+1(zero[i]zero[x1])+zero[x1]=ix+1zero[i]+2zero[x1]
同样 n e w O n e [ i ] = y − i + 1 − o n e [ i ] + 2 ∗ o n e [ y + 1 ] newOne[i] = y - i + 1 - one[i] + 2*one[y+1] newOne[i]=yi+1one[i]+2one[y+1]

虽然newZero和newOne都和具体的下标i相关,但是它们的和 n e w S u m = n e w Z e r o + n e w O n e [ i ] = ( y − x ) + 2 − s u m [ i ] + 2 ∗ ( z e r o [ x − 1 ] + o n e [ y + 1 ] ) newSum = newZero + newOne[i] = (y-x) + 2 - sum[i]+ 2 * (zero[x - 1] + one[y + 1]) newSum=newZero+newOne[i]=(yx)+2sum[i]+2(zero[x1]+one[y+1])
和下标 i i i无关

将这个过程分为两步,第一步,区间翻转,sum[i]变成-sum[i]。第二,区间增加,增加 y − x + 2 + 2 ∗ ( z e r o [ x − 1 ] + o n e [ y + 1 ] ) y-x+2+2*(zero[x-1] + one[y + 1]) yx+2+2(zero[x1]+one[y+1])

单个操作或者可交换的操作,这里是两种操作,而且不可交换,好在它们有一定的性质,可以使得他们操作合并
每个线段树的节点维护一个值,最大值,同时有两种延迟操作,延迟加和延迟翻转,延迟加维护一个延迟累计值用于向下更新。这里保证翻转总是先进行的,这样,当向下更新时,可以先翻转。
对于和的更新,每次先将节点向下更新,保证当前及祖先没有翻转延迟操作,这样区间和值就可以延迟操作。对于翻转,我们做同样的工作。
最重要的就是向下更新的问题,我们总保证翻转先进行,所以先对子节点翻转,翻转除了,由于翻转的时候区间和最大值和最小值会调换,所以除了维护最小值还需要维护最大值。翻转除了对最大最小值翻转外,还需要对延迟和值翻转。两次翻转会抵消。对于延迟和向下更新,由于翻转已经做完,直接累计在延迟和值上即可。细节见代码
一开始为了逻辑更清晰用了两颗线段树,实际上由于它们维护的东西不会互相干扰,很容易合并到一起,这里就不实现了。
有暴力解法对比测试,应该没有问题,本地测试注意将N改小

#include 

using namespace std;
const int N = 1e5+1;

class SegTreeForNum{
public:
    struct Node; 
    SegTreeForNum(vector<int> &arr):usedNode(0), root(nullptr){
	    tot = static_cast<int>(arr.size()) - 1; 
	    root = buildTree(arr, 0, tot);
    }

    int qZero(const int leftPos, const int rightPos){
	    if(rightPos < leftPos) 
		    return 0;
	    return query(leftPos, rightPos, 0, tot, root);	
    }
    int diff(const int leftPos, const int rightPos){
	    return rightPos - leftPos + 1 - 2 * qZero(leftPos, rightPos); 
    }

    int qOne(const int leftPos, const int rightPos){
	    if(rightPos < leftPos) 
		    return 0;
	    return rightPos - leftPos + 1 - qZero(leftPos, rightPos);
    }

    void flip(const int leftPos, const int rightPos){
	    assert(leftPos <= rightPos);
	    doflip(leftPos, rightPos, 0, tot, root);
    } 

private:
    void flipNode(Node *p){
	    p->flip ^= 1;
	    p->zeroNum = p->len - p->zeroNum;
    }

    void push(Node *p){
	    if(!p->flip) return;
	    flipNode(p->lch);
	    flipNode(p->rch);
		p->flip = false;
		p->zeroNum = p->lch->zeroNum + p->rch->zeroNum;
    }

    int  query(const int leftPos, const int rightPos, int left, int right, Node *p){
	    if(left > rightPos) return 0;
	    if(right < leftPos) return 0;
	    assert(p != nullptr);
	    assert(left <= right);
	    if(leftPos <= left && right <= rightPos){
		    return p->zeroNum;
	    }
	    push(p);
	    int mid = (left + right) / 2;
	    return query(leftPos, rightPos, left, mid, p->lch) + 
		    query(leftPos, rightPos, mid + 1, right, p->rch);
    }


    void doflip(const int leftPos, const int rightPos, int left, int right, Node *p){
	    if(left > rightPos) return;
	    if(right < leftPos) return;
	    assert(p != nullptr);
	    assert(left <= right);

	    if(leftPos <= left && right <= rightPos){
		    flipNode(p);
		    return;
	    }

	    push(p);
	    int mid = (left + right) / 2;
	    doflip(leftPos, rightPos, left, mid, p->lch);
	    doflip(leftPos, rightPos, mid + 1, right, p->rch);
	    p->zeroNum = p->lch->zeroNum + p->rch->zeroNum;
    } 

    Node* buildTree(vector<int> &arr, int left, int right){
	    Node *p = &nodes[usedNode++]; 
	    assert(left <= right);
	    p->len = right - left + 1;
	    if(left == right){
		    p->zeroNum = (arr[left] == 0);
	    }else{
		    int mid = (right + left) / 2;
		    p->lch = buildTree(arr, left, mid);
		    p->rch = buildTree(arr, mid + 1, right);
		    p->zeroNum = p->lch->zeroNum + p->rch->zeroNum;
	    }
	    return p;
    }

    struct Node{
	    Node *lch, *rch;
	    bool flip;
	    int zeroNum, len;
	    Node():lch(nullptr), rch(nullptr), flip(false), zeroNum(0), len(0){}
    } nodes[N * 4];

    int usedNode, tot;
    Node *root;
};//end of SegTreeForNum


class SegTree{
	struct Node;
	public:

		SegTree(vector<int> &arr): tree(arr), usedNode(0){
    	tot = static_cast<int>(arr.size()) - 1;
        vector<int> val(arr.size(), 0);
        
        val[0] = arr[0]^1;
        for(size_t i = 1;i < arr.size(); ++i){
            val[i] = val[i - 1] + (arr[i]^1);
			
        }
       
       int oneNum = 0; 
       
       for(int i = static_cast<int>(arr.size()) - 1; i >= 0; --i){
            oneNum += arr[i];
            val[i] += oneNum;
       } 
        
       root = buildTree(val, 0, tot);

    }

    int getAns(){
        return root->maxVal;
    }
    
    void flip(int left, int right){
        assert(left <= right);
        assert(0 <= left);
        assert(right <= tot);
       
        int diff = tree.diff(left, right);
        add(0, left - 1, -diff, 0, tot, root);
        add(right + 1, tot, diff, 0, tot, root);
        int leftZeroNum = tree.qZero(0, left - 1);
        int rightOneNum = tree.qOne(right + 1, tot);
        //y - x + 2 - val + 2 * (lval + rval) 
        flipupdate(left, right, 0, tot, root);
        add(left, right, right - left + 2 + 2 * (leftZeroNum + rightOneNum), 0, tot, root);
        tree.flip(left, right);
    }
private:
    void maintain(Node *p){
        p->maxVal = max(p->lch->maxVal, p->rch->maxVal);
        p->minVal = min(p->lch->minVal, p->rch->minVal);
    }
    
    Node* buildTree(vector<int> &val, int left, int right){
        assert(left <= right);

        Node *p = &nodes[usedNode++]; 
 	    if(left == right){ 
	        p->minVal = p->maxVal = val[left];
	        return p;
	    }

	    int mid = (left + right) / 2;
	    p->lch = buildTree(val, left, mid);
	    p->rch = buildTree(val, mid + 1, right);
        maintain(p);
	    return p;
    }

    void push(Node *p){
        if(p->lch == nullptr){
            assert(p->lch == nullptr);
            p->lazyFlip = p->lazyAdd = false; 
            return;
        }

        if(p->lazyFlip){
            flipNode(p->lch);
            flipNode(p->rch);
            p->lazyFlip = false;
        }
        
        if(p->lazyAdd){
            p->lch->maxVal += p->lazyNum;
            p->lch->minVal += p->lazyNum;
            p->lch->lazyAdd = true;
            p->lch->lazyNum += p->lazyNum;
            p->rch->maxVal += p->lazyNum;
            p->rch->minVal += p->lazyNum;
            p->rch->lazyAdd = true;
            p->rch->lazyNum += p->lazyNum;
            p->lazyNum = 0;
            p->lazyAdd = false;
        }    
    }
    
    void flipNode(Node *p){
        int t = p->minVal;
        p->minVal = -p->maxVal;
        p->maxVal = -t; 
        p->lazyFlip ^= 1;
        p->lazyNum = -p->lazyNum;
    }
    
    void flipupdate(const int leftPos, const int rightPos, int left, int right, Node *p){

	    if(left > rightPos) return;
	    if(right < leftPos) return;	
	    push(p);
        if(left == right){
            flipNode(p);
            return;
        } 

    	if(leftPos <= left && right <= rightPos){
            flipNode(p);
	        return ;
        }
        
        int mid = (left + right) / 2;

        flipupdate(leftPos, rightPos, left, mid, p->lch);
        flipupdate(leftPos, rightPos, mid + 1, right, p->rch);
        maintain(p);

    }
    void add(const int leftPos, const int rightPos, const int val, int left, int right, Node *p){
	    if(left > rightPos) return;
	    if(right < leftPos) return;	
	    push(p);
	    if(left == right){
            p->maxVal += val;
            p->minVal += val;
            return;
    	}
	
    	if(leftPos <= left && right <= rightPos){
    	    p->lazyAdd = true;
            p->lazyNum += val;
            p->maxVal += val; 
            p->minVal += val;
	        return ;
	    }
	    int mid = (left + right) / 2;
	    add(leftPos, rightPos, val, left, mid, p->lch);
	    add(leftPos, rightPos, val, mid + 1, right, p->rch);
        maintain(p);
    }


private:
    struct Node{
	    Node *lch, *rch;
	    int  maxVal, minVal;
        int  val;
		int lazyNum;
	    bool lazyAdd, lazyFlip;
	    Node():lch(nullptr), rch(nullptr), lazyNum(0), val(0),minVal(0), maxVal(0), lazyAdd(false), lazyFlip(false){}
	    
    } nodes[N * 4];
    int tot, usedNode;
    Node *root;
    SegTreeForNum tree;
};//end of SegTree

#include 
#include 
#include 
//default_random_engine e;
int (*e)() = rand;
vector<int> genData(int n){
  vector<int> arr(n, 0);
  for(int i = 0;i < n; ++i){
       arr[i] = e() & 1;
  }
  return arr;
}

void testSegTreeForNum(int n = 1000, int q = 1000){
  n = max(100, n);
  vector<int> arr = genData(n);
  SegTreeForNum tree(arr);
  while(q--){
      int x = e() % n, y = e() % n;
      if(x > y) swap(x, y);
      if(e()&1){
        int ans1 = tree.qZero(x, y);
        int ans2 = 0;
        for(int i = x; i <= y; ++i)
             ans2 += (arr[i]^1);
        if(ans1 != ans2){
            cerr<<"error"<<q<<endl;
            exit(2);
		}
      }else{
          tree.flip(x,y);
          for(int i = x; i <= y; ++i)
               arr[i] ^= 1;
      }
  }
}
int bruteforce(vector<int> &arr){
    vector<int> val(arr.size(), 0);

    val[0] = (arr[0]^1);
    for(int i = 1;i < arr.size(); ++i){
        val[i] = val[i - 1] + (arr[i] ^ 1);
    }
    int oNum = 0, res = 0;
    for(int i = arr.size() - 1; i>=0 ; --i){
        oNum += arr[i];
        val[i] += oNum;
        res = max(res, val[i]);
    }

    return res;
}
void testSeg(int n = 1000, int q = 1000){
    vector<int> arr = genData(n);
    
    SegTree tree(arr);
    while(q--){
        int x = e() % n, y = e() % n;
        if(x > y) swap(x, y);
        if(e()&1){
            int ans1 = tree.getAns();
            int ans2 = bruteforce(arr);

            if(ans1 != ans2){
                cerr<<"error"<< " "<<q<<" "<<ans1<<" "<<ans2<<endl;
                exit(3);
            }
        }else{
            tree.flip(x,y);
            for(int i = x; i <= y; ++i)
               arr[i] ^= 1;
        }
    }
	cout<<"n = "<<n<<" test2 done!"<<endl;
}

int main(){

    srand(time(NULL) % 1007);

    cout<<"test1"<<endl;
	for(int i = 0; i < 100; ++i)
    	testSegTreeForNum(e()%956);
    cout<<"test2"<<endl;
	for(int i = 0;i < 100; ++i)
    	testSeg();
    cout<<"testdone!"<<endl;

    int n, m, x, y;
    cin >> n >> m;
    vector<int> arr(n, 0);
    for(size_t i = 0;i < n; ++i){
        cin >> arr[i];
    }

    SegTree tree(arr);
    char op;
    while(n--){
        cin >> op;
        if(op == 'Q'){
            cout<<tree.getAns()<<endl;
        }else{
            cin>>x>>y;
            tree.flip(x, y);
        }
    }

    return 0;	
}

你可能感兴趣的:(算法和实现)