【权重小技巧(2)】模型权重文件总结: .bin、.safetensors、.pt的保存、加载方法一览

.bin 文件通常是通过 torch.save(state_dict) 来保存的,因此需要使用 torch.load() 进行加载。这是因为 .bin 文件存储的是模型权重的 state_dict(即参数的字典),而不是完整的模型对象。你需要先加载 state_dict,然后将其加载到模型中,类似于 .pt.pth 文件。

常见的权重文件格式(如 .bin.safetensors.pt 等)的保存和加载方法:

文件类型 保存方法 加载方法 描述
.bin torch.save(model.state_dict(), 'model.bin') state_dict = torch.load('model.bin', map_location='cpu')
model.load_state_dict(state_dict)
.bin 文件通常存储 state_dict,仅保存模型的权重。需要先加载 state_dict,然后用 model.load_state_dict 加载到模型。
.pt / .pth torch.save(model.state_dict(), 'model.pth') state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)
.pt / .pth 文件是 PyTorch 中常见的保存权重格式,保存的是模型的 state_dict。加载方式和 .bin 文件相同。
.pt (完整模型) torch.save(model, 'model.pth') model = torch.load('model.pth', map_location='cpu') 保存的是整个模型对象(包括结构和权重)。加载时会直接返回模型,无需再用 load_state_dict
.safetensors from safetensors import save_file
save_file(tensors, 'model.safetensors')
from safetensors import load_file
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict)
.safetensors 是一种新型格式,专门用于安全和高效地保存模型权重。保存和加载都需要 safetensors 库。
Hugging Face .bin model.save_pretrained('path_to_model') from transformers import AutoModel
model = AutoModel.from_pretrained('path_to_model')
Hugging Face 的 .bin 通常用于保存预训练模型的权重,带有额外的配置文件(如 config.json)。AutoModel 会自动处理加载。
.h5 (Keras/TensorFlow) model.save_weights('model.h5') model.load_weights('model.h5') Keras 和 TensorFlow 使用 .h5 文件格式保存模型权重。适用于 TensorFlow/Keras 环境。

详细解释:

  1. .bin 文件

    • .bin 文件通常存储模型的 state_dict,也就是模型参数的字典。你需要先用 torch.load() 读取这个字典,再用 model.load_state_dict() 将它加载到模型中。
    • 典型保存代码:torch.save(model.state_dict(), 'model.bin')
    • 加载代码:
      state_dict = torch.load('model.bin', map_location='cpu')
      model.load_state_dict(state_dict)
      
  2. .pt / .pth 文件

    • .pt.pth 都是 PyTorch 官方推荐的扩展名,用于保存模型的 state_dict 或完整模型。对于 state_dict,它的加载方式与 .bin 类似。

    • 典型保存代码:torch.save(model.state_dict(), 'model.pth')

    • 加载代码:

      state_dict = torch.load('model.pth', map_location='cpu')
      model.load_state_dict(state_dict)
      
    • 保存完整模型对象的代码:torch.save(model, 'model.pth')

    • 加载完整模型代码:

      model = torch.load('model.pth', map_location='cpu')
      
  3. .safetensors 文件

    • .safetensors 是一种新型格式,设计用于比 PyTorch 的 .pt 文件更加安全、确定性强(防止任意代码执行的漏洞)且内存高效。它适合大型模型的存储。
    • 典型保存代码:
      from safetensors import save_file
      save_file(tensors, 'model.safetensors')
      
    • 加载代码:
      from safetensors import load_file
      state_dict = load_file('model.safetensors')
      model.load_state_dict(state_dict)
      
  4. Hugging Face .bin

    • Hugging Face 通常保存预训练模型的权重为 .bin 文件,同时保存模型配置文件 config.json,并提供 save_pretrainedfrom_pretrained 方法简化保存和加载过程。
    • 典型保存代码:model.save_pretrained('path_to_model')
    • 加载代码:
      from transformers import AutoModel
      model = AutoModel.from_pretrained('path_to_model')
      

总结

  • .bin / .pt / .pth:这些格式大多数情况下保存的是 state_dict,需要通过 model.load_state_dict() 将其加载到模型中。
  • 完整模型 (.pt):保存的是整个模型对象,包含模型架构和权重,加载时直接得到模型实例。
  • .safetensors:专为安全和效率设计,需要 safetensors 库处理加载和保存。

你可能感兴趣的:(编程学习,模型部署,人工智能,pytorch,python,机器学习,AI)