For any group element $g \in G$ and input/output spaces $V_{in}, V_{out}$, the map $f$ satisfies:

$$f(\rho_{in}(g)\,x) = \rho_{out}(g)\,f(x)$$

where $\rho$ denotes a group representation. In the protein-folding setting, let the atom positions be $x_i$ and the features $h_i$; the attention score is:

$$\alpha_{ij} = \mathrm{Softmax}_j\!\left( \frac{Q(h_i)^{T} K(h_j)}{\sqrt{d}} \cdot e^{-\|x_i - x_j\|^2 / 2\sigma^2} \right)$$

The equivariant update rule is:

$$\Delta x_i = \sum_j \alpha_{ij}\, V(h_j) \otimes (x_i - x_j)$$

(where $\otimes$ denotes the outer product, which preserves equivariance)
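To see why this update is equivariant, consider a rigid motion $x_i \mapsto R x_i + t$ with $R$ a rotation: the features $h_i$ and the pairwise distances $\|x_i - x_j\|$ are unchanged, so $\alpha_{ij}$ and $V(h_j)$ are invariant, while each relative vector transforms as $x_i - x_j \mapsto R(x_i - x_j)$. The geometric factor of every term, and hence $\Delta x_i$, therefore rotates with the input and is unaffected by translations.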
```python
import torch
from torch import nn
from torch.nn import functional as F


class EquivariantLinear(nn.Module):
    """Equivariance-preserving linear layer: mixes features, leaves geometry alone."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x, coord):
        # x: [B, N, C], coord: [B, N, 3]
        # weight is [out_dim, in_dim], so contract over the input-channel axis
        h = torch.einsum('bnc,dc->bnd', x, self.weight) + self.bias
        # coordinates pass through untouched, which trivially preserves equivariance
        return h, coord
```
```python
class EquiAttention(nn.Module):
    """Equivariant multi-head attention layer."""
    def __init__(self, dim, heads=4, sigma=1.0):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.sigma = sigma  # width of the Gaussian distance kernel

    def forward(self, x, coord):
        B, N, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # reshape to [B, H, N, D/H] so each head attends independently
        q, k, v = map(lambda t: t.reshape(B, N, self.heads, -1).transpose(1, 2), qkv)
        # pairwise relative vectors and squared distances
        rel = coord.unsqueeze(2) - coord.unsqueeze(1)        # [B, N, N, 3]
        dist2 = rel.pow(2).sum(dim=-1)                       # [B, N, N]
        # attention logits scaled by the Gaussian distance kernel from the formula above
        attn = (q @ k.transpose(-2, -1)) * self.scale        # [B, H, N, N]
        attn = attn * torch.exp(-dist2 / (2 * self.sigma ** 2)).unsqueeze(1)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)   # [B, N, D]
        # coordinate update built from relative vectors (x_i - x_j), averaged over heads
        delta_coord = torch.einsum('bhnm,bnmc->bnc', attn, rel) / self.heads
        return out, coord + delta_coord
```
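A quick sanity check of the claimed symmetry, assuming the two layers above are defined as written: rotate and translate the input coordinates and verify that the output features are unchanged while the updated coordinates transform the same way. The helper below is illustrative only.

```python
# Numerical equivariance check for EquiAttention (illustrative).
def random_rotation():
    # QR of a random matrix gives an orthogonal matrix; flip a column if det = -1
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

torch.manual_seed(0)
layer = EquiAttention(dim=32, heads=4)
x = torch.randn(2, 10, 32)       # features  [B, N, C]
coord = torch.randn(2, 10, 3)    # positions [B, N, 3]
R, t = random_rotation(), torch.randn(3)

h1, c1 = layer(x, coord)
h2, c2 = layer(x, coord @ R.T + t)

print(torch.allclose(h1, h2, atol=1e-4))            # features: invariant
print(torch.allclose(c1 @ R.T + t, c2, atol=1e-4))  # coordinates: equivariant
```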
Approach: build the structure-prediction network with equivariant attention as the core module (a stacking sketch follows this list)
Dataset: CASP14 competition data (178 proteins)
Evaluation metrics:
Case study: a pharmaceutical company used equivariant attention for molecular generation
Results:
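As a rough illustration of the "equivariant attention as core module" approach above (not the authors' actual architecture; the input feature size, width, and depth are placeholders), a prediction network can simply stack the blocks defined earlier so that features and coordinates are refined jointly:

```python
class EquiFoldNet(nn.Module):
    """Hypothetical stack of equivariant attention blocks (illustrative only)."""
    def __init__(self, in_dim=21, dim=128, depth=6, heads=4):
        super().__init__()
        self.embed = EquivariantLinear(in_dim, dim)
        self.blocks = nn.ModuleList([EquiAttention(dim, heads) for _ in range(depth)])

    def forward(self, x, coord):
        # x: residue/atom features [B, N, in_dim]; coord: initial positions [B, N, 3]
        h, coord = self.embed(x, coord)
        for block in self.blocks:
            h, coord = block(h, coord)   # each block refines features and geometry
        return h, coord                  # coord is the predicted structure
```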
```python
# Example of automatic hyperparameter tuning with Optuna
import optuna

def objective(trial):
    config = {
        "num_layers": trial.suggest_int("num_layers", 6, 12),
        "hidden_dim": trial.suggest_categorical("hidden_dim", [128, 256, 512]),
        "attention_heads": trial.suggest_int("heads", 4, 8),
        "l_max": trial.suggest_int("l_max", 1, 3),  # max rotation (spherical-harmonic) order
    }
    model = build_model(config)   # user-defined: construct the network from the config
    return train_eval(model)      # user-defined: train and return the validation metric

study = optuna.create_study(direction="maximize")  # assumes a higher metric is better
study.optimize(objective, n_trials=50)
```
```python
# Memory-efficient attention via xFormers (reduces peak memory for long sequences)
from xformers.ops import memory_efficient_attention

# q, k, v: [batch, seq_len, num_heads, head_dim]
attn_out = memory_efficient_attention(q, k, v)
```
```python
# Mixed-precision training step with automatic loss scaling
scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    outputs = model(inputs)
    loss = criterion(outputs)
scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)          # unscales gradients, then takes the optimizer step
scaler.update()                 # adjusts the scale factor for the next iteration
```
EquiAFold (ICLR 2024):
SE3-Diffusion (NeurIPS 2023):
Higher-order equivariant representations:
Extending the spherical-harmonic order to $l = 4$ models multi-atom interactions more faithfully:

$$\phi(x) = \sum_{l=0}^{4} \sum_{m=-l}^{l} c_{l,m}\, Y_l^m(x)$$
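For concreteness, one way to obtain such higher-order features in practice is the e3nn library; the snippet below is a sketch (assuming e3nn is installed) that expands direction vectors in real spherical harmonics up to $l = 4$:

```python
import torch
from e3nn import o3

# Sketch: spherical-harmonic expansion of (relative) direction vectors up to l = 4
vectors = torch.randn(8, 3)                 # e.g. relative positions x_i - x_j

# one coefficient block per order l = 0..4; concatenated size is (4+1)^2 = 25
sh = o3.spherical_harmonics(list(range(5)), vectors, normalize=True)
print(sh.shape)                             # torch.Size([8, 25])
```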
Geometrically invariant loss functions:
A curvature-based regularization term is proposed:

$$\mathcal{L}_{curv} = \sum_i \| \nabla^2 E(x_i) \|^2$$

which makes the learned energy surface better behaved physically.
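One way to realize such a penalty with autograd, assuming a scalar energy model (the `energy_fn` below is a placeholder), is to take the Frobenius norm of the Hessian computed by `torch.func.hessian`:

```python
import torch
from torch.func import hessian

def energy_fn(x):
    # placeholder energy: any smooth scalar function of a conformation [N, 3]
    return (x.pow(2).sum(-1).sqrt() - 1.0).pow(2).sum()

def curvature_loss(coords):
    # sum over conformations of || second derivatives of E ||_F^2
    loss = 0.0
    for x in coords:                   # x: one conformation [N, 3]
        H = hessian(energy_fn)(x)      # [N, 3, N, 3]
        loss = loss + H.pow(2).sum()
    return loss

coords = torch.randn(2, 5, 3)          # batch of 2 conformations, 5 atoms each
print(curvature_loss(coords))
```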
Practical advice: use the OpenFold codebase (github.com/aqlaboratory/openfold) as the base framework and extend it with the equivariant attention modules above. When training on an RTX 4090, a batch_size of 8-16 is recommended, with the learning rate cosine-annealed from 1e-4 down to 1e-6.
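A minimal scheduler setup matching that learning-rate schedule (the model, epoch count, and training-loop body are placeholders) might look like this:

```python
import torch

model = EquiFoldNet()                # e.g. the sketch network above, or any nn.Module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 100                     # placeholder
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_epochs, eta_min=1e-6
)

for epoch in range(num_epochs):
    # ... one training epoch with batch_size 8-16 ...
    scheduler.step()                 # lr follows a cosine curve from 1e-4 down to 1e-6
```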