Hands-On SFT Fine-Tuning of Gemma

1. Setting up the runtime environment:

Virtual environment requirements:

  • Python 3.10 or later
  • PyTorch 1.12 or later; 2.0 or later is recommended
  • CUDA 11.4 or later is recommended
  • transformers>=4.38.0
    Be sure to use the environment above, otherwise the code will not run. If your Python version is below 3.10, you additionally need to run pip install bitsandbytes. A quick version check is sketched below.
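A minimal sanity check for the requirements above (a sketch; the packaging module is assumed to be available, as it ships as a dependency of transformers):

import sys
import torch
import transformers
from packaging import version

# Verify the interpreter and library versions required by this tutorial.
assert sys.version_info >= (3, 10), "Python 3.10+ is required"
assert version.parse(torch.__version__.split("+")[0]) >= version.parse("1.12")
assert version.parse(transformers.__version__) >= version.parse("4.38.0")
print("CUDA available:", torch.cuda.is_available(), "| CUDA:", torch.version.cuda)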

Gemma model link and download:

The model repo can be downloaded directly (7b-it is used as the example; on a low-spec server the 2b model is recommended for demo experiments):

from modelscope import snapshot_download
model_dir = snapshot_download("AI-ModelScope/gemma-7b-it")
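Once downloaded, the checkpoint can be loaded from model_dir with the standard transformers API (a minimal sketch; torch_dtype and device_map are assumptions, adjust them to your hardware):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16,  # assumption: bf16-capable GPU; fall back to float16
    device_map="auto",           # assumption: accelerate is installed
)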

2. SFT Fine-Tuning

SFTTrainer is a subclass of transformers.Trainer that adds logic for handling a PeftConfig.
Different needs call for different training strategies; the PeftConfig path is sketched first, followed by several scenario-specific examples:
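Passing a PeftConfig looks roughly like this (a hedged sketch; the LoRA hyperparameters are illustrative assumptions, not tuned values):

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Illustrative LoRA settings; tune r/lora_alpha/lora_dropout for your task.
peft_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
)
trainer.train()

With peft_config supplied, SFTTrainer wraps the base model with LoRA adapters, so only the adapter weights are trained.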

2.1 Continued pre-training on a dataset, fine-tuning on the entire sequence

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()

Note: dataset_text_field="text". The dataset_text_field argument indicates which field of the dataset contains the text used as model input. It lets the datasets library automatically build a ConstantLengthDataset from that field, simplifying data preparation.
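Packing can also be requested explicitly, reusing the model and dataset from the example above (a sketch, assuming the packing flag of the trl version you have installed):

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    packing=True,  # concatenate short examples into fixed-length sequences
)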

2.2 Fine-tuning on the responses only

A response template must be set: response_template = " ### Answer:"

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

trainer.train()
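To see what the collator actually does, inspect the labels of one collated example (illustrative only; it reuses the objects defined above):

# Tokens before " ### Answer:" receive the label -100, so the loss
# is computed only on the answer tokens.
sample_text = formatting_prompts_func(dataset[:1])[0]
batch = collator([tokenizer(sample_text)])
print(batch["labels"][0])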

2.3 Fine-tuning on a conversational dataset

Both an instruction template and a response template must be set:

instruction_template = "### Human:"

response_template = "### Assistant:"

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

collator = DataCollatorForCompletionOnlyLM(
    instruction_template="### Human:",
    response_template="### Assistant:",
    tokenizer=tokenizer,
    mlm=False,
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
)
trainer.train()
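With both templates set, the collator masks not only the prompt before the first response but every human turn in a multi-turn conversation, so the loss is computed only on the assistant's replies.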
