2020/3/19美团点评的最后一道笔试题(算法方向,第五道题)
题意:给定一个n(1到100000),m(1到100000),一个长度为n的01数组(只有0和1),然后m行操作,操作要么是查询当前最长不降序列长度,要么是给定x和y,翻转数组中x到y的一段(即位于[x,y]区间的数,0变成1,1变成0)。
题解:我用的是线段树,实现起来三四百行,有更好的办法欢迎交流。
首先考虑一个01数组计算最长不降序列的办法。由于是01序列,所以不降序列具有形式00001111,即前面为0,后面为1的序列(可以没有1或0)。所以只需要对每个位置,计算出当前位置前面的0的个数和后面的1的个数的和,取最大值就是最长不降序列的长度。
接着考虑翻转,我们可以维护上面提到的和值,然后查询的时候取最大即可。问题是翻转的时候怎么维护。设翻转区间为[x,y],则整个数组分成3段,[0,x - 1], [x,y], [y+1, end]
设zero[i]为位置i前面(包括自己)0的个数,为位置i后面(包括自己)1的个数,我们维护的是sum[i] = zero[i] + one[i],以及最大值
对于[0, x-1]这一段,zero[i]不变,而one[i]的改变都一致,增加的为[x,y]这一段中原来0比1多的个数为 -diff(x, y)(设diff(x,y)为[x,y]中1比0多的个数)
对于[y + 1,end]也一样,one[i]不变,zero[i]增加的 个数为diff(x,y)
所以我们可以用另外一个线段树来维护每一段0的个数(线段树最简单的应用),翻转的时候个数等于区间长度减去当前0的个数,就可以求出diff(x,y)。
而且有了这个线段树,我们也可以在log(n)的时间内求出一个zero[i]或者one[i]
最后我们要维护[x,y]这一段。考虑一个数组整个翻转的情况,比如
0100110,翻转成1011001,则zero[i]和onei都发生了什么改变?
0到位置i共有i+1个位置,所以我们只要知道前面原来有多少个1就是翻转后0的个数,即 n e w Z e r o [ i ] = i + 1 − z e r o [ i ] newZero[i] = i + 1 - zero[i] newZero[i]=i+1−zero[i], 同样 n e w O n e [ i ] = n + 1 − i − o n e [ i ] newOne[i] = n + 1 - i - one[i] newOne[i]=n+1−i−one[i]
现在考虑[x,y]是位于中间,将下标左移到0,先减去[0,x-1]的影响,然后最后加回来,得 n e w Z e r o [ i ] = i − x + 1 − ( z e r o [ i ] − z e r o [ x − 1 ] ) + z e r o [ x − 1 ] = i − x + 1 − z e r o [ i ] + 2 ∗ z e r o [ x − 1 ] newZero[i] = i - x + 1 - (zero[i] - zero[x-1]) + zero[x-1]=i-x+1-zero[i] +2*zero[x-1] newZero[i]=i−x+1−(zero[i]−zero[x−1])+zero[x−1]=i−x+1−zero[i]+2∗zero[x−1]
同样 n e w O n e [ i ] = y − i + 1 − o n e [ i ] + 2 ∗ o n e [ y + 1 ] newOne[i] = y - i + 1 - one[i] + 2*one[y+1] newOne[i]=y−i+1−one[i]+2∗one[y+1]
虽然newZero和newOne都和具体的下标i相关,但是它们的和 n e w S u m = n e w Z e r o + n e w O n e [ i ] = ( y − x ) + 2 − s u m [ i ] + 2 ∗ ( z e r o [ x − 1 ] + o n e [ y + 1 ] ) newSum = newZero + newOne[i] = (y-x) + 2 - sum[i]+ 2 * (zero[x - 1] + one[y + 1]) newSum=newZero+newOne[i]=(y−x)+2−sum[i]+2∗(zero[x−1]+one[y+1])
和下标 i i i无关
将这个过程分为两步,第一步,区间翻转,sum[i]变成-sum[i]。第二,区间增加,增加 y − x + 2 + 2 ∗ ( z e r o [ x − 1 ] + o n e [ y + 1 ] ) y-x+2+2*(zero[x-1] + one[y + 1]) y−x+2+2∗(zero[x−1]+one[y+1])
单个操作或者可交换的操作,这里是两种操作,而且不可交换,好在它们有一定的性质,可以使得他们操作合并
每个线段树的节点维护一个值,最大值,同时有两种延迟操作,延迟加和延迟翻转,延迟加维护一个延迟累计值用于向下更新。这里保证翻转总是先进行的,这样,当向下更新时,可以先翻转。
对于和的更新,每次先将节点向下更新,保证当前及祖先没有翻转延迟操作,这样区间和值就可以延迟操作。对于翻转,我们做同样的工作。
最重要的就是向下更新的问题,我们总保证翻转先进行,所以先对子节点翻转,翻转除了,由于翻转的时候区间和最大值和最小值会调换,所以除了维护最小值还需要维护最大值。翻转除了对最大最小值翻转外,还需要对延迟和值翻转。两次翻转会抵消。对于延迟和向下更新,由于翻转已经做完,直接累计在延迟和值上即可。细节见代码
一开始为了逻辑更清晰用了两颗线段树,实际上由于它们维护的东西不会互相干扰,很容易合并到一起,这里就不实现了。
有暴力解法对比测试,应该没有问题,本地测试注意将N改小
#include
using namespace std;
const int N = 1e5+1;
class SegTreeForNum{
public:
struct Node;
SegTreeForNum(vector<int> &arr):usedNode(0), root(nullptr){
tot = static_cast<int>(arr.size()) - 1;
root = buildTree(arr, 0, tot);
}
int qZero(const int leftPos, const int rightPos){
if(rightPos < leftPos)
return 0;
return query(leftPos, rightPos, 0, tot, root);
}
int diff(const int leftPos, const int rightPos){
return rightPos - leftPos + 1 - 2 * qZero(leftPos, rightPos);
}
int qOne(const int leftPos, const int rightPos){
if(rightPos < leftPos)
return 0;
return rightPos - leftPos + 1 - qZero(leftPos, rightPos);
}
void flip(const int leftPos, const int rightPos){
assert(leftPos <= rightPos);
doflip(leftPos, rightPos, 0, tot, root);
}
private:
void flipNode(Node *p){
p->flip ^= 1;
p->zeroNum = p->len - p->zeroNum;
}
void push(Node *p){
if(!p->flip) return;
flipNode(p->lch);
flipNode(p->rch);
p->flip = false;
p->zeroNum = p->lch->zeroNum + p->rch->zeroNum;
}
int query(const int leftPos, const int rightPos, int left, int right, Node *p){
if(left > rightPos) return 0;
if(right < leftPos) return 0;
assert(p != nullptr);
assert(left <= right);
if(leftPos <= left && right <= rightPos){
return p->zeroNum;
}
push(p);
int mid = (left + right) / 2;
return query(leftPos, rightPos, left, mid, p->lch) +
query(leftPos, rightPos, mid + 1, right, p->rch);
}
void doflip(const int leftPos, const int rightPos, int left, int right, Node *p){
if(left > rightPos) return;
if(right < leftPos) return;
assert(p != nullptr);
assert(left <= right);
if(leftPos <= left && right <= rightPos){
flipNode(p);
return;
}
push(p);
int mid = (left + right) / 2;
doflip(leftPos, rightPos, left, mid, p->lch);
doflip(leftPos, rightPos, mid + 1, right, p->rch);
p->zeroNum = p->lch->zeroNum + p->rch->zeroNum;
}
Node* buildTree(vector<int> &arr, int left, int right){
Node *p = &nodes[usedNode++];
assert(left <= right);
p->len = right - left + 1;
if(left == right){
p->zeroNum = (arr[left] == 0);
}else{
int mid = (right + left) / 2;
p->lch = buildTree(arr, left, mid);
p->rch = buildTree(arr, mid + 1, right);
p->zeroNum = p->lch->zeroNum + p->rch->zeroNum;
}
return p;
}
struct Node{
Node *lch, *rch;
bool flip;
int zeroNum, len;
Node():lch(nullptr), rch(nullptr), flip(false), zeroNum(0), len(0){}
} nodes[N * 4];
int usedNode, tot;
Node *root;
};//end of SegTreeForNum
class SegTree{
struct Node;
public:
SegTree(vector<int> &arr): tree(arr), usedNode(0){
tot = static_cast<int>(arr.size()) - 1;
vector<int> val(arr.size(), 0);
val[0] = arr[0]^1;
for(size_t i = 1;i < arr.size(); ++i){
val[i] = val[i - 1] + (arr[i]^1);
}
int oneNum = 0;
for(int i = static_cast<int>(arr.size()) - 1; i >= 0; --i){
oneNum += arr[i];
val[i] += oneNum;
}
root = buildTree(val, 0, tot);
}
int getAns(){
return root->maxVal;
}
void flip(int left, int right){
assert(left <= right);
assert(0 <= left);
assert(right <= tot);
int diff = tree.diff(left, right);
add(0, left - 1, -diff, 0, tot, root);
add(right + 1, tot, diff, 0, tot, root);
int leftZeroNum = tree.qZero(0, left - 1);
int rightOneNum = tree.qOne(right + 1, tot);
//y - x + 2 - val + 2 * (lval + rval)
flipupdate(left, right, 0, tot, root);
add(left, right, right - left + 2 + 2 * (leftZeroNum + rightOneNum), 0, tot, root);
tree.flip(left, right);
}
private:
void maintain(Node *p){
p->maxVal = max(p->lch->maxVal, p->rch->maxVal);
p->minVal = min(p->lch->minVal, p->rch->minVal);
}
Node* buildTree(vector<int> &val, int left, int right){
assert(left <= right);
Node *p = &nodes[usedNode++];
if(left == right){
p->minVal = p->maxVal = val[left];
return p;
}
int mid = (left + right) / 2;
p->lch = buildTree(val, left, mid);
p->rch = buildTree(val, mid + 1, right);
maintain(p);
return p;
}
void push(Node *p){
if(p->lch == nullptr){
assert(p->lch == nullptr);
p->lazyFlip = p->lazyAdd = false;
return;
}
if(p->lazyFlip){
flipNode(p->lch);
flipNode(p->rch);
p->lazyFlip = false;
}
if(p->lazyAdd){
p->lch->maxVal += p->lazyNum;
p->lch->minVal += p->lazyNum;
p->lch->lazyAdd = true;
p->lch->lazyNum += p->lazyNum;
p->rch->maxVal += p->lazyNum;
p->rch->minVal += p->lazyNum;
p->rch->lazyAdd = true;
p->rch->lazyNum += p->lazyNum;
p->lazyNum = 0;
p->lazyAdd = false;
}
}
void flipNode(Node *p){
int t = p->minVal;
p->minVal = -p->maxVal;
p->maxVal = -t;
p->lazyFlip ^= 1;
p->lazyNum = -p->lazyNum;
}
void flipupdate(const int leftPos, const int rightPos, int left, int right, Node *p){
if(left > rightPos) return;
if(right < leftPos) return;
push(p);
if(left == right){
flipNode(p);
return;
}
if(leftPos <= left && right <= rightPos){
flipNode(p);
return ;
}
int mid = (left + right) / 2;
flipupdate(leftPos, rightPos, left, mid, p->lch);
flipupdate(leftPos, rightPos, mid + 1, right, p->rch);
maintain(p);
}
void add(const int leftPos, const int rightPos, const int val, int left, int right, Node *p){
if(left > rightPos) return;
if(right < leftPos) return;
push(p);
if(left == right){
p->maxVal += val;
p->minVal += val;
return;
}
if(leftPos <= left && right <= rightPos){
p->lazyAdd = true;
p->lazyNum += val;
p->maxVal += val;
p->minVal += val;
return ;
}
int mid = (left + right) / 2;
add(leftPos, rightPos, val, left, mid, p->lch);
add(leftPos, rightPos, val, mid + 1, right, p->rch);
maintain(p);
}
private:
struct Node{
Node *lch, *rch;
int maxVal, minVal;
int val;
int lazyNum;
bool lazyAdd, lazyFlip;
Node():lch(nullptr), rch(nullptr), lazyNum(0), val(0),minVal(0), maxVal(0), lazyAdd(false), lazyFlip(false){}
} nodes[N * 4];
int tot, usedNode;
Node *root;
SegTreeForNum tree;
};//end of SegTree
#include
#include
#include
//default_random_engine e;
int (*e)() = rand;
vector<int> genData(int n){
vector<int> arr(n, 0);
for(int i = 0;i < n; ++i){
arr[i] = e() & 1;
}
return arr;
}
void testSegTreeForNum(int n = 1000, int q = 1000){
n = max(100, n);
vector<int> arr = genData(n);
SegTreeForNum tree(arr);
while(q--){
int x = e() % n, y = e() % n;
if(x > y) swap(x, y);
if(e()&1){
int ans1 = tree.qZero(x, y);
int ans2 = 0;
for(int i = x; i <= y; ++i)
ans2 += (arr[i]^1);
if(ans1 != ans2){
cerr<<"error"<<q<<endl;
exit(2);
}
}else{
tree.flip(x,y);
for(int i = x; i <= y; ++i)
arr[i] ^= 1;
}
}
}
int bruteforce(vector<int> &arr){
vector<int> val(arr.size(), 0);
val[0] = (arr[0]^1);
for(int i = 1;i < arr.size(); ++i){
val[i] = val[i - 1] + (arr[i] ^ 1);
}
int oNum = 0, res = 0;
for(int i = arr.size() - 1; i>=0 ; --i){
oNum += arr[i];
val[i] += oNum;
res = max(res, val[i]);
}
return res;
}
void testSeg(int n = 1000, int q = 1000){
vector<int> arr = genData(n);
SegTree tree(arr);
while(q--){
int x = e() % n, y = e() % n;
if(x > y) swap(x, y);
if(e()&1){
int ans1 = tree.getAns();
int ans2 = bruteforce(arr);
if(ans1 != ans2){
cerr<<"error"<< " "<<q<<" "<<ans1<<" "<<ans2<<endl;
exit(3);
}
}else{
tree.flip(x,y);
for(int i = x; i <= y; ++i)
arr[i] ^= 1;
}
}
cout<<"n = "<<n<<" test2 done!"<<endl;
}
int main(){
srand(time(NULL) % 1007);
cout<<"test1"<<endl;
for(int i = 0; i < 100; ++i)
testSegTreeForNum(e()%956);
cout<<"test2"<<endl;
for(int i = 0;i < 100; ++i)
testSeg();
cout<<"testdone!"<<endl;
int n, m, x, y;
cin >> n >> m;
vector<int> arr(n, 0);
for(size_t i = 0;i < n; ++i){
cin >> arr[i];
}
SegTree tree(arr);
char op;
while(n--){
cin >> op;
if(op == 'Q'){
cout<<tree.getAns()<<endl;
}else{
cin>>x>>y;
tree.flip(x, y);
}
}
return 0;
}