AlphaFold3 data_transforms 模块的
squeeze_features 函数的作用去除 蛋白质特征张量中不必要的单维度(singleton dimensions)和重复维度,以使其适配 AlphaFold3 预期的输入格式。
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
for k in [
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1)
else:
protein[k] = np.squeeze(protein[k], axis=-1)
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
return protein
protein
(一个 包含蛋白质特征的字典)作为输入。aatype
转换为索引表示。(N, ..., 1)
的单维度。seq_length
和 num_alignments
的实际数值。aatype
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
aatype
(氨基酸类型)通常是 one-hot 编码torch.argmax(..., dim=-1)
获取 索引aatype
的数据表示,使其直接存储氨基酸索引,而不是 one-hot 矩阵。for k in [
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1] # 获取最后一维的大小
if isinstance(final_dim, int) and final_dim == 1:
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1) # 去掉单维度
else:
protein[k] = np.squeeze(protein[k], axis=-1)
protein
特征字段,检查它们是否存在。final_dim
为 1
,说明这个维度是无意义的单维度,需要去除:
torch.Tensor
),使用 torch.squeeze(dim=-1)
。np.squeeze(axis=-1)
。seq_length
和 num_alignments
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
seq_length
和 num_alignments
可能是 列表或张量,但它们的数值其实是一个单独的整数,因此需要转换成 标量值。
1️⃣ 转换 aatype
: 从 one-hot 编码 转换成 索引表示。
2️⃣ 移除无用的单维度: 让 msa
, resolution
, deletion_matrix
等数据符合 AlphaFold3 预期格式。
3️⃣ 转换 seq_length
和 num_alignments
为标量: 确保它们不会以张量形式存在,而是整数。
最终作用:保证输入数据的维度符合 AlphaFold3 训练时的输入要求,提高数据处理效率。