AF3 block_delete_msa function explained

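The block_delete_msa function in AlphaFold3's data_transforms module deletes several blocks of contiguous sequences (rows) from the MSA (multiple sequence alignment) to simulate missing or noisy MSA data.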

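  • It is presumably used during training as data augmentation, to keep the model from relying too heavily on the MSA.
  • It is not used during inference, because the complete MSA is important for prediction accuracy.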
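The index arithmetic behind block deletion can be tried on a toy MSA. The sketch below is a simplified illustration, not the library code: `kept_rows` is a hypothetical helper, and the block starts are passed in explicitly instead of being sampled at random, so the result is deterministic.

```python
import torch


def kept_rows(num_seq: int, starts: torch.Tensor, block_len: int) -> torch.Tensor:
    """Hypothetical helper: MSA row indices that survive block deletion.

    Each block covers rows [start, start + block_len), clipped to
    [1, num_seq - 1] so that row 0 (the query sequence) is never deleted.
    """
    # Expand every start into a contiguous run of row indices.
    blocks = starts[:, None] + torch.arange(block_len)
    blocks = torch.clip(blocks, 1, num_seq - 1)
    # Overlapping or clipped blocks can repeat rows, so deduplicate.
    deleted = torch.unique(blocks.reshape(-1))
    # Mark deleted rows and return the indices of the remaining ones.
    mask = torch.ones(num_seq, dtype=torch.bool)
    mask[deleted] = False
    return torch.nonzero(mask).squeeze(-1)


# 10 MSA rows, two blocks of 3 rows starting at rows 2 and 8.
print(kept_rows(10, torch.tensor([2, 8]), 3).tolist())  # [0, 1, 5, 6, 7]
```

Rows 2, 3, 4, 8 and 9 are dropped. The second block would reach row 10, but clipping folds it back onto row 9, so the query row and the untouched rows 1, 5, 6 and 7 remain.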

Source code:

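import torch

# Not used in inference
# curry1 is a helper decorator from the same data_transforms module (not shown):
# block_delete_msa(config) returns a function that is later applied to `protein`.
@curry1
def block_delete_msa(protein, config):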
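    num_seq = protein["msa"].shape[0]
    # Rows per deleted block: floor(num_seq * msa_fraction_per_block)
    block_num_seq = torch.floor(
        torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
        * config.msa_fraction_per_block
    ).to(torch.int32)

    # A zero-length block would delete nothing, so return the input unchanged
    if int(block_num_seq) == 0:
        return protein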

    # Number of blocks: uniform over {0, ..., num_blocks} if randomized, else fixed
    if config.randomize_num_blocks:
        nb = int(torch.randint(
            low=0,
            high=config.num_blocks + 1,
            size=(1,),
            device=protein["msa"].device,
        )[0])
    else:
        nb = config.num_blocks

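    # Sample block start rows from [1, num_seq); row 0 is the query sequence
    del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
    # Expand each start into block_num_seq consecutive row indices
    # (arange is created on the same device to avoid a CPU/GPU mismatch)
    del_blocks = del_block_starts[:, None] + torch.arange(
        start=0, end=int(block_num_seq), device=protein["msa"].device
    )
    # Clip to [1, num_seq - 1]: indices stay valid and row 0 is never deleted
    del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
    # Flatten and deduplicate, since overlapping blocks can share rows
    del_indices = torch.unique(torch.reshape(del_blocks, [-1]))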

    # Make sure we keep the original sequence
    combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).lon
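The excerpt above breaks off at the step that removes the deleted rows. The sketch below is a self-contained approximation of the whole transform, not the original code: it takes the config values as plain arguments, uses a hypothetical `MSA_KEYS` tuple in place of the module's real list of MSA feature names, returns a new dict instead of modifying `protein`, and finishes with a count-based set difference, which is one plausible reading of what the `combined = torch.cat(...)` line sets up.

```python
import torch

# Hypothetical per-sequence MSA feature names; the real transform would use
# the module's own list of MSA feature keys.
MSA_KEYS = ("msa", "deletion_matrix")


def block_delete_msa_sketch(protein: dict, msa_fraction_per_block: float,
                            num_blocks: int, randomize_num_blocks: bool) -> dict:
    msa = protein["msa"]
    num_seq = msa.shape[0]
    device = msa.device
    # floor(num_seq * fraction); int() floors non-negative floats
    block_len = int(num_seq * msa_fraction_per_block)
    if block_len == 0 or num_seq < 2:
        return protein  # nothing can be deleted without touching row 0

    # Number of blocks: uniform over {0, ..., num_blocks} or fixed.
    if randomize_num_blocks:
        nb = int(torch.randint(0, num_blocks + 1, (1,)).item())
    else:
        nb = num_blocks

    # Block starts in [1, num_seq), so the query row 0 is never a start.
    starts = torch.randint(1, num_seq, (nb,), device=device)
    blocks = starts[:, None] + torch.arange(block_len, device=device)
    blocks = torch.clip(blocks, 1, num_seq - 1)
    del_indices = torch.unique(blocks.reshape(-1))

    # Set difference by counting: each row index appears once in arange(num_seq)
    # and a second time if it is in del_indices, so count == 1 means "keep".
    combined = torch.cat((torch.arange(num_seq, device=device), del_indices))
    uniques, counts = combined.unique(return_counts=True)
    keep_indices = uniques[counts == 1]  # sorted, always contains 0

    out = dict(protein)  # shallow copy; the input dict is left untouched
    for k in MSA_KEYS:
        if k in out:
            out[k] = out[k][keep_indices]  # select rows along the sequence axis
    return out


# Usage: a toy 10-row MSA; the surviving rows vary with the random seed.
torch.manual_seed(0)
toy = {"msa": torch.arange(20).reshape(10, 2),
       "deletion_matrix": torch.zeros(10, 2)}
print(block_delete_msa_sketch(toy, 0.2, 2, False)["msa"].shape)
```

Because the block starts are random, each call can drop a different set of rows; only row 0 is guaranteed to survive. Counting occurrences with `unique(return_counts=True)` keeps the set difference inside the tensor API instead of looping over indices in Python.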
