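官方题解的做法是：先分别预处理出只对行操作 i 次、只对列操作 i 次（0 ≤ i ≤ k）所能得到的最大值 r[i] 和 c[i]，然后从 0 到 k 枚举行操作的次数 i，求 r[i] + c[k-i] - i*(k-i)*p 的最大值（每一次行操作与每一次列操作恰好交于一个格子，两者中后执行的那一次会少得 p，因此共需扣除 i*(k-i)*p）。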
#include <iostream>
#include <sstream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <climits>

typedef long long LL;

// Greedy table for one axis (rows or columns) considered on its own.
// best[t] (0 <= t <= k) is the total gained by t operations on this axis:
// each step takes the line with the currently largest sum, adds that sum,
// then lowers the line's sum by `decay` (p times the number of cells in it).
// Assumes `sums` is non-empty; main() validates this before calling.
static std::vector<LL> best_prefix(const std::vector<LL>& sums, int k, LL decay)
{
    std::priority_queue<LL> heap(sums.begin(), sums.end());  // max-heap of line sums
    std::vector<LL> best(k + 1, 0);
    for (int t = 1; t <= k; t++) {
        LL top = heap.top();
        heap.pop();
        best[t] = best[t - 1] + top;
        heap.push(top - decay);
    }
    return best;
}

// Reads n, m, k, p and an n x m matrix of integers from stdin, and prints the
// maximum total over exactly k operations, each on a row or a column:
// gain the line's current sum, then every cell of that line drops by p.
int main(void)
{
    int n, m, k, p;
    // Reject unreadable or degenerate input instead of popping an empty heap
    // or indexing with a negative k.
    if (scanf("%d%d%d%d", &n, &m, &k, &p) != 4 || n < 1 || m < 1 || k < 0)
        return 0;

    std::vector<LL> rowSum(n, 0), colSum(m, 0);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            LL a = 0;
            // %lld is the portable long long specifier; %I64d is MSVC-only and
            // is misparsed by glibc (stores an int into a long long).
            scanf("%lld", &a);
            rowSum[i] += a;
            colSum[j] += a;
        }
    }

    // A row operation lowers its row sum by m*p, a column operation by n*p.
    std::vector<LL> r = best_prefix(rowSum, k, 1LL * m * p);
    std::vector<LL> c = best_prefix(colSum, k, 1LL * n * p);

    // Split k into i row ops and k-i column ops. Every (row op, column op)
    // pair shares one cell, and whichever runs second sees it already lowered
    // by p, so the combined total is r[i] + c[k-i] - i*(k-i)*p.
    LL ans = c[k];  // i = 0: column operations only
    for (int i = 1; i <= k; i++)
        ans = std::max(ans, r[i] + c[k - i] - 1LL * i * (k - i) * p);

    printf("%lld\n", ans);
    return 0;
}