Given $N$ integers in the range $[-50000, 50000]$, how many ways are there to pick three integers $a_i$, $a_j$, $a_k$ such that $i$, $j$, $k$ are pairwise distinct and $a_i + a_j = a_k$? Two ways are different if their ordered triples $(i, j, k)$ of indices are different.
The first line of input consists of a single integer $N$ ($1 \le N \le 200000$). The next line consists of $N$ space-separated integers $a_1, a_2, \ldots, a_N$.
Output an integer representing the number of ways.
Sample Input 1 | Sample Output 1 |
---|---|
4<br>1 2 3 4 | 4 |
Sample Input 2 | Sample Output 2 |
---|---|
6<br>1 1 3 3 4 6 | 10 |
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
using namespace std;
typedef long long ll;  // at least 64 bits: the answer can reach N(N-1)(N-2), about 8e15 for N = 200000
// ---- FFT template begins ----
// pi, taken as the principal value of arccos(-1).
static const double PI = acos(-1.0);
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
struct Virt
{
    double r, i;

    // Default-constructs 0 + 0i; a single argument gives a purely real value.
    Virt(double re = 0.0, double im = 0.0) : r(re), i(im) {}

    // The operators are const members so they also work on const operands.
    Virt operator + (const Virt &x) const
    {
        return Virt(r + x.r, i + x.i);
    }
    Virt operator - (const Virt &x) const
    {
        return Virt(r - x.r, i - x.i);
    }
    // (r + i*I)(x.r + x.i*I) = (r*x.r - i*x.i) + (i*x.r + r*x.i)*I
    Virt operator * (const Virt &x) const
    {
        return Virt(r * x.r - i * x.i, i * x.r + r * x.i);
    }
};
// Rader's algorithm: permutes F[0 .. len-1] into bit-reversed index order,
// which is the input order the in-place iterative FFT() needs.
// len is assumed to be a power of two.
void Rader(Virt F[], int len)
{
    const int top = len >> 1;  // highest bit of an index in [0, len)
    int rev = top;             // bit-reversal of index 1
    for (int idx = 1; idx + 1 < len; ++idx)
    {
        // Bit reversal is an involution, so swap each pair once, from its smaller index.
        if (idx < rev)
            swap(F[idx], F[rev]);
        // Advance rev to the bit-reversal of idx + 1: carry from the top bit
        // downwards, clearing 1-bits until the first 0-bit, then set that bit.
        int bit = top;
        for (; rev >= bit; bit >>= 1)
            rev -= bit;
        if (rev < bit)
            rev += bit;
    }
}
// In-place iterative radix-2 FFT over F[0 .. len-1].
//   len : transform length; must be a power of two (Rader() relies on this).
//   on  : +1 for the forward transform (kernel e^{-2*pi*i*jk/len}),
//         -1 for the inverse transform, which also divides every element by len.
//
// Each twiddle factor is evaluated directly with cos/sin instead of being
// accumulated as w = w * wn. Repeated multiplication compounds the rounding
// error of wn up to len/2 times. That matters here because the transformed
// pair counts reach about N^2 = 4e10 and must still round to exact integers.
void FFT(Virt F[], int len, int on)
{
    Rader(F, len);  // bit-reversed order lets the butterflies run in place
    for (int h = 2; h <= len; h <<= 1)
    {
        // Merge pairs of length-h/2 DFTs into length-h DFTs.
        const int half = h >> 1;
        const double ang = -on * 2 * PI / h;
        for (int t = 0; t < half; t++)
        {
            // Twiddle e^{i*ang*t}, shared by the t-th butterfly of every length-h block.
            Virt w(cos(ang * t), sin(ang * t));
            for (int j = t; j < len; j += h)
            {
                Virt u = F[j];
                Virt v = w * F[j + half];
                F[j] = u + v;  // butterfly
                F[j + half] = u - v;
            }
        }
    }
    if (on == -1)
    {
        // Normalize both parts so the inverse is exact for complex data as well.
        for (int i = 0; i < len; i++)
        {
            F[i].r /= len;
            F[i].i /= len;
        }
    }
}
// Cyclic convolution via the FFT. Transforms a[] and b[] in place (length n,
// a power of two), multiplies the spectra pointwise into a[], then
// inverse-transforms a[]. Afterwards a[i].r is the real part of the i-th
// coefficient of the cyclic convolution (for real inputs, the coefficient
// itself). b[] is left holding its forward transform.
void Conv(Virt a[], Virt b[], int n)
{
    const int forward = 1;
    const int inverse = -1;
    FFT(a, n, forward);
    FFT(b, n, forward);
    for (int idx = n - 1; idx >= 0; --idx)
        a[idx] = a[idx] * b[idx];  // convolution theorem: pointwise product of spectra
    FFT(a, n, inverse);
}
// ---- FFT template ends ----

// Problem constants.
const int MAXN = 200000;  // maximum N, the number of input integers
const int T = 50000;      // |a_i| <= T; the value v is stored at index v + T

// Input values in read order.
ll num[MAXN + 5];
// Histogram of the input: cnt[v + T] counts the inputs equal to v.
ll cnt[MAXN + 5];

// Complex work buffers passed to Conv().
Virt a[2 * MAXN + 5];
Virt b[2 * MAXN + 5];

// Rounded convolution output: pair-sum counts indexed by (sum + 2*T).
ll ans[2 * MAXN + 5];
int main()
{
int n;
scanf("%d",&n);
int zero = 0;
for(int i = 0; i < n; i++)
{
scanf("%lld",&num[i]);
if(num[i] == 0) zero++;
cnt[num[i] + T]++;
}
int len = 1;
while(len < MAXN) len <<= 1;
for(int i = 0; i < MAXN; i++)
{
a[i] = b[i] = Virt(1.0 * cnt[i], 0.0);
}
Conv(a, b, len);
for(int i = 0; i < len; i++)
{
ans[i] = (ll)(a[i].r + 0.5);
}
for(int i = 0; i < n; i++)
{
ans[(num[i] + T) * 2]--;
}
ll res = 0;
for(int i = 0; i < n; i++)
{
res += ans[num[i]+T*2];
res -= (zero - (num[i] == 0)) * 2;
}
printf("%lld\n",res);
return 0;
}