Deep Dive: Merging a PEFT Adapter into an LLM

When you use a parameter-efficient fine-tuning method such as LoRA during incremental pre-training or supervised fine-tuning, the output is an adapter file, which is essentially a "patch" on top of the base model. So how do you merge this "patch" back into the original model?

Below is a walkthrough of the model-merging code.

All of the related code is available on GitHub:

https://github.com/hjandlm/LLM_Train


Code Walkthrough

  1. Import packages
import argparse
from loguru import logger

import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModel,
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModelForSequenceClassification,
)
  2. Model types
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}

The model type determines which model and tokenizer classes are used for loading.
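For instance, here is a quick illustrative lookup (not part of the script itself): passing --model_type llama selects LlamaForCausalLM and LlamaTokenizer, while any architecture without a dedicated entry can fall back to the generic "auto" loaders.

model_class, tokenizer_class = MODEL_CLASSES["llama"]
print(model_class.__name__, tokenizer_class.__name__)  # LlamaForCausalLM LlamaTokenizer

# Architectures not listed in the dict can use the generic Auto* classes
model_class, tokenizer_class = MODEL_CLASSES["auto"]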

  3. Hyperparameter settings
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', default=None, type=str, required=True)
parser.add_argument('--base_model', default=None, required=True, type=str,
                    help="Base model name or path")
parser.add_argument('--tokenizer_path', default=None, type=str,
                    help="Please specify tokenization path.")
parser.add_argument('--lora_model', default=None, required=True, type=str,
                    help="Please specify LoRA model to be merged.")
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
parser.add_argument('--output_dir', default='./merged', type=str)
args = parser.parse_args()

logger.info(f"merged_args:{args}")

The hyperparameters are the model type, the base model, the tokenizer path, the LoRA model, whether to resize the model's token embeddings, and the output directory.

tokenizer_path is only set after the vocabulary has been extended; otherwise the base model's original tokenizer is used.

resize_emb is likewise only needed after vocabulary extension.
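To make the resize_emb option concrete, here is a minimal sketch of the mismatch it fixes. The checkpoint path and the added tokens below are hypothetical; the point is that vocabulary extension enlarges the tokenizer while the base checkpoint's embedding matrix keeps its original row count.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical base checkpoint and new tokens, for illustration only
tokenizer = AutoTokenizer.from_pretrained("llama-2-7b-bin")
model = AutoModelForCausalLM.from_pretrained("llama-2-7b-bin")

# Vocabulary extension adds tokens to the tokenizer only...
tokenizer.add_tokens(["<domain_token_1>", "<domain_token_2>"])

# ...so len(tokenizer) now exceeds the embedding row count,
# which is exactly the gap that --resize_emb closes before merging.
print(len(tokenizer), model.get_input_embeddings().weight.size(0))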

  4. Load the model and tokenizer
base_model_path = args.base_model
lora_model_path = args.lora_model
output_dir = args.output_dir
peft_config = PeftConfig.from_pretrained(lora_model_path)
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

# Load the model
if peft_config.task_type == "SEQ_CLS":
    logger.info("Loading LoRA for sequence classification model")
    if args.model_type == "chatglm":
        raise ValueError("chatglm does not support sequence classification")
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_path,
        num_labels=1,
        load_in_8bit=False,
        torch_dtype=torch.float32,
        trust_remote_code=True,
        device_map="auto",
    )
else:
    logger.info("Loading LoRA for causal language model")
    base_model = model_class.from_pretrained(
        base_model_path,
        load_in_8bit=False,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )

# Load the tokenizer
if args.tokenizer_path:
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
else:
    tokenizer = tokenizer_class.from_pretrained(base_model_path, trust_remote_code=True)

# Resize the token embeddings if the vocabulary was extended
if args.resize_emb:
    base_model_token_size = base_model.get_input_embeddings().weight.size(0)
    if base_model_token_size != len(tokenizer):
        base_model.resize_token_embeddings(len(tokenizer))
        logger.info(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")

# Initialize the PEFT model (base model + LoRA adapter)
new_model = PeftModel.from_pretrained(
    base_model,
    lora_model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)
  5. Merge the model
new_base_model = new_model.merge_and_unload()
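merge_and_unload folds the LoRA update back into the base weights and returns a plain transformers model, so the merged checkpoint no longer needs PEFT at inference time. Conceptually, each adapted linear layer becomes W' = W + (alpha / r) · B · A. Below is a toy sketch of that idea with random tensors; it is not PEFT's actual implementation.

import torch

# Toy shapes: a 4096x4096 linear layer adapted with LoRA rank r=8
W = torch.randn(4096, 4096)      # frozen base weight
lora_A = torch.randn(8, 4096)    # LoRA down-projection (r x in_features)
lora_B = torch.randn(4096, 8)    # LoRA up-projection (out_features x r)
alpha, r = 16, 8

# Merging bakes the scaled low-rank update into the base weight
W_merged = W + (alpha / r) * (lora_B @ lora_A)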
  6. Save the model and tokenizer
tokenizer.save_pretrained(output_dir)
new_base_model.save_pretrained(output_dir, safe_serialization=False, max_shard_size='10GB')

safe_serialization: whether to save the model weights in the safetensors format for safer serialization. The default is True; setting it to False saves the weights as classic Hugging Face PyTorch .bin files, otherwise they are saved as safetensors.

max_shard_size: the maximum size of each checkpoint shard.
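Once saved, the merged directory behaves like any ordinary Hugging Face checkpoint and can be loaded without PEFT. Here is a quick sanity-check sketch, assuming the default ./merged output directory and an illustrative prompt:

from transformers import AutoModelForCausalLM, AutoTokenizer

# "./merged" is the default --output_dir; adjust to your own path
tokenizer = AutoTokenizer.from_pretrained("./merged", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "./merged",
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)

inputs = tokenizer("Hello, please introduce yourself.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))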

Complete code

import argparse
from loguru import logger

import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModel,
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModelForSequenceClassification,
)

MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default=None, type=str, required=True)
    parser.add_argument('--base_model', default=None, required=True, type=str,
                        help="Base model name or path")
    parser.add_argument('--tokenizer_path', default=None, type=str,
                        help="Please specify tokenization path.")
    parser.add_argument('--lora_model', default=None, required=True, type=str,
                        help="Please specify LoRA model to be merged.")
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--output_dir', default='./merged', type=str)
    args = parser.parse_args()
    logger.info(f"merged_args:{args}")

    base_model_path = args.base_model
    lora_model_path = args.lora_model
    output_dir = args.output_dir
    peft_config = PeftConfig.from_pretrained(lora_model_path)
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    
    # Load the model
    if peft_config.task_type == "SEQ_CLS":
        logger.info("Loading LoRA for sequence classification model")
        if args.model_type == "chatglm":
            raise ValueError("chatglm does not support sequence classification")
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            num_labels=1,
            load_in_8bit=False,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
        )
    else:
        logger.info("Loading LoRA for causal language model")
        base_model = model_class.from_pretrained(
            base_model_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
    
    # Load the tokenizer
    if args.tokenizer_path:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    else:
        tokenizer = tokenizer_class.from_pretrained(base_model_path, trust_remote_code=True)

    # Resize the token embeddings if the vocabulary was extended
    if args.resize_emb:
        base_model_token_size = base_model.get_input_embeddings().weight.size(0)
        if base_model_token_size != len(tokenizer):
            base_model.resize_token_embeddings(len(tokenizer))
            logger.info(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")

    # Initialize the PEFT model (base model + LoRA adapter)
    new_model = PeftModel.from_pretrained(
        base_model,
        lora_model_path,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    new_model.eval()
    
    logger.info(f"Merging with merge_and_unload...")
    new_base_model = new_model.merge_and_unload()

    logger.info("Saving to Hugging Face format...")
    tokenizer.save_pretrained(output_dir)
    new_base_model.save_pretrained(output_dir, safe_serialization=False, max_shard_size='10GB')
    
    logger.info(f"Done! model saved to {output_dir}")

if __name__ == '__main__':
    main()

Run results

With vocabulary extension, run merge_pt.sh:

python merge_peft_adapter.py \
    --model_type llama \
    --base_model llama-2-7b-bin \
    --resize_emb \
    --tokenizer_path pt_lora_model \
    --lora_model pt_lora_model \
    --output_dir ./pt/model

Without vocabulary extension, run merge_sft.sh:

python merge_peft_adapter.py \
    --model_type llama \
    --base_model llama-2-7b-bin \
    --lora_model pt_lora_model \
    --output_dir ./sft/model

Run output:
[screenshot: run output]

With safe_serialization=True, the run output is:

[screenshot: run output with safe_serialization=True]

References

[1] https://github.com/huggingface/peft/tree/b4faffea8ae031e5bd69a76b55418b3650c04c80

[2] https://github.com/shibing624/MedicalGPT/blob/main/merge_peft_adapter.py

