\(n\times m\) 的网格中,在第 \(i\) 行 \(j\) 列有 \(a[i][j]\) 个泡泡,每次可以收割一行或一列的泡泡,最多收割 \(4\) 次,问最多可以收割到多少泡泡。\(nm \leq 10^5\)
Solution
讨论答案的各种情况
- 四行,这种情况下直接求和取前 \(4\) 个最大值即可
- 三行一列,枚举取哪一列,然后每次暴力提取前 \(3\) 个行最大值
- 两行两列,显然 \(n,m\) 中必有一个 \(\leq \sqrt{10^5}\),设它是行,则暴力枚举选哪两行,然后仍然按照前述方法计算答案即可
其余情况可以由上面三种基本情况旋转得到
复杂度 \(\mathcal{O} (nm \min(n,m))\)
#include
using namespace std;
#define int long long
int n,m,**a,**b;
int solve1a() {
int ans=0;
vector v;
for(int i=1;i<=n;i++) {
int sum=0;
for(int j=1;j<=m;j++) sum+=a[i][j];
v.push_back(sum);
}
sort(v.begin(),v.end());
for(int i=0;i<4;i++) {
if(v.size()) ans+=v.back(), v.pop_back();
}
return ans;
}
int solve1b() {
int ans=0;
vector v;
for(int i=1;i<=m;i++) {
int sum=0;
for(int j=1;j<=n;j++) sum+=a[j][i];
v.push_back(sum);
}
sort(v.begin(),v.end());
for(int i=0;i<4;i++) {
if(v.size()) ans+=v.back(), v.pop_back();
}
return ans;
}
int solve2a() {
int ans=0;
int *sum,*tmp;
sum=new int[n+1];
tmp=new int[n+1];
for(int i=1;i<=n;i++) sum[i]=0;
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
sum[i]+=a[i][j];
}
}
for(int k=1;k<=m;k++) {
int tot=0;
for(int i=1;i<=n;i++) tot+=a[i][k];
for(int i=1;i<=n;i++) {
tmp[i]=sum[i]-a[i][k];
}
for(int i=0;i<3;i++) {
tot+=*max_element(tmp+1,tmp+n+1);
*max_element(tmp+1,tmp+n+1)=0;
}
ans=max(ans,tot);
}
return ans;
}
int solve2b() {
int ans=0;
int *sum,*tmp;
sum=new int[m+1];
tmp=new int[m+1];
for(int i=1;i<=m;i++) sum[i]=0;
for(int i=1;i<=m;i++) {
for(int j=1;j<=n;j++) {
sum[i]+=a[j][i];
}
}
for(int k=1;k<=n;k++) {
int tot=0;
for(int i=1;i<=m;i++) tot+=a[k][i];
for(int i=1;i<=m;i++) tmp[i]=sum[i]-a[k][i];
for(int i=0;i<3;i++) {
tot+=*max_element(tmp+1,tmp+m+1);
*max_element(tmp+1,tmp+m+1)=0;
}
ans=max(ans,tot);
}
return ans;
}
int solve3a() {
int ans=0;
int *sum,*tmp;
sum=new int[m+1];
tmp=new int[m+1];
for(int i=1;i<=m;i++) sum[i]=0;
for(int i=1;i<=m;i++) {
for(int j=1;j<=n;j++) {
sum[i]+=a[j][i];
}
}
for(int k=1;k<=n;k++) {
for(int l=1;l<=n;l++) if(k!=l) {
int tot=0;
for(int i=1;i<=m;i++) tot+=a[k][i]+a[l][i];
for(int i=1;i<=m;i++) tmp[i]=sum[i]-a[k][i]-a[l][i];
for(int i=0;i<2;i++) {
tot+=*max_element(tmp+1,tmp+m+1);
*max_element(tmp+1,tmp+m+1)=0;
}
ans=max(ans,tot);
}
}
return ans;
}
int solve3b() {
int ans=0;
int *sum,*tmp;
sum=new int[n+1];
tmp=new int[n+1];
for(int i=1;i<=n;i++) sum[i]=0;
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
sum[i]+=a[i][j];
}
}
for(int k=1;k<=m;k++) {
for(int l=1;l<=m;l++) if(k!=l) {
int tot=0;
for(int i=1;i<=n;i++) tot+=a[i][k]+a[i][l];
for(int i=1;i<=n;i++) tmp[i]=sum[i]-a[i][k]-a[i][l];
for(int i=0;i<2;i++) {
tot+=*max_element(tmp+1,tmp+n+1);
*max_element(tmp+1,tmp+n+1)=0;
}
ans=max(ans,tot);
}
}
return ans;
}
signed main() {
ios::sync_with_stdio(false);
cin>>n>>m;
a=new int*[n+1];
b=new int*[n+1];
for(int i=0;i<=n;i++) {
a[i]=new int[m+1];
b[i]=new int[m+1];
}
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
cin>>a[i][j];
b[i][j]=a[i][j];
}
}
int ans=0;
ans=max(ans,solve1a());
ans=max(ans,solve1b());
ans=max(ans,solve2a());
ans=max(ans,solve2b());
if(n