在大模型训练中,数据准备是最关键的步骤之一。俗话说 “Garbage in, garbage out”,数据的质量直接决定了模型的性能上限。本文将详细介绍如何为 LLaMA-Factory 准备训练数据,包括文本数据和多模态数据的处理方法。
LLaMA-Factory 支持两种主要的数据格式:
instruction
、input
、output
字段[
{
"instruction": "请解释什么是机器学习",
"input": "",
"output": "机器学习是人工智能的一个分支,它使计算机能够从数据中学习并改进性能,而无需明确编程。",
"system": "你是一个专业的AI助手",
"history": [
["你好", "你好!有什么可以帮助你的吗?"]
]
}
]
[
{
"conversations": [
{
"from": "human",
"value": "什么是深度学习?"
},
{
"from": "gpt",
"value": "深度学习是机器学习的一个子领域,使用多层神经网络来学习数据的复杂模式。"
}
],
"system": "你是一个专业的AI助手"
}
]
[
{"text": "这是一段用于预训练的文本内容"},
{"text": "另一段文本内容"}
]
[
{
"conversations": [
{
"from": "human",
"value": "描述这张图片"
},
{
"from": "gpt",
"value": "这是一只可爱的小猫,正在阳光下睡觉。"
}
],
"images": [
"path/to/cat_image.jpg"
]
}
]
标记的数量必须与 images
列表中的图片数量一致[
{
"conversations": [
{
"from": "human",
"value": "
},
{
"from": "gpt",
"value": "视频展示了一个人在公园里慢跑的场景。"
}
],
"videos": [
"path/to/running_video.mp4"
]
}
]
[
{
"conversations": [
{
"from": "human",
"value": "
},
{
"from": "gpt",
"value": "音频内容是:今天天气真好,适合出去散步。"
}
],
"audios": [
"path/to/audio_file.wav"
]
}
]
import os
import json
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
# 设置镜像加速
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
def convert_pokemon_dataset(output_dir="pokemon_sharegpt"):
os.makedirs(output_dir, exist_ok=True)
images_dir = os.path.join(output_dir, "images")
os.makedirs(images_dir, exist_ok=True)
ds = load_dataset("BUAADreamer/pokemon-gpt4-1k", split="train")
sharegpt_data = []
for idx, item in enumerate(tqdm(ds, desc="处理数据")):
entry = {
"conversations": [],
"system": "You are a helpful assistant.",
"images": [f"images/pokemon_{idx:04d}.png"]
}
# 保存图片
if "images" in item and item["images"]:
image = item["images"][0]
if isinstance(image, Image.Image):
image.save(os.path.join(images_dir, f"pokemon_{idx:04d}.png"))
# 转换对话
if "messages" in item:
for i, msg in enumerate(item["messages"]):
role = msg.get("role", "")
content = msg.get("content", "")
if role == "user":
from_role = "human"
if i == 0:
content += ""
elif role == "assistant":
from_role = "gpt"
else:
continue
entry["conversations"].append({
"from": from_role,
"value": content
})
sharegpt_data.append(entry)
output_path = os.path.join(output_dir, "pokemon_sharegpt.json")
with open(output_path, "w", encoding="utf-8") as f:
json.dump(sharegpt_data, f, ensure_ascii=False, indent=2)
print(f"✅ 数据转换完成,已保存至: {output_path}")
if __name__ == "__main__":
convert_pokemon_dataset()
代码运行成功后,会在pokemon_sharegpt
文件夹下生成一个images
文件夹,以及一个pokemon_sharegpt.json
的文件,json内容如下
{
"conversations": [
{
"from": "human",
"value": "Provide caption for the image in one sentence. Be detailed but precise."
},
{
"from": "gpt",
"value": "A cheerful Bulbasaur ready for its next Pokémon adventure."
}
],
"system": "You are a helpful assistant.",
"images": [
"images/pokemon_0000.png"
]
}
对应图片如下:
dataset_info.json
是 LLaMA-Factory 用于管理和加载数据集的核心配置文件。所有自定义数据集都必须在这个文件中注册后才能使用。该文件应放置在 dataset_dir
目录下(后续可以放置在 pokemon_sharegpt
文件夹下)。
{
"数据集名称1": {
// 数据集配置
},
"数据集名称2": {
// 数据集配置
}
}
参数名 | 类型 | 必填 | 说明 | 示例 |
---|---|---|---|---|
file_name | string | 是* | 本地数据集文件路径 | "data/train.json" |
hf_hub_url | string | 否 | Hugging Face Hub 仓库名称 | "tatsu-lab/alpaca" |
ms_hub_url | string | 否 | ModelScope Hub 仓库名称 | "modelscope/alpaca" |
formatting | string | 否 | 数据格式,默认为 alpaca | "alpaca" 或 "sharegpt" |
ranking | boolean | 否 | 是否为偏好数据集 | true 或 false |
subset | string | 否 | 数据集子集名称 | "train" |
split | string | 否 | 数据集分割,默认为 train | "train" , "test" , "validation" |
num_samples | integer | 否 | 限制使用的样本数量 | 10000 |
注:如果指定了
hf_hub_url
或ms_hub_url
,则可以不填file_name
columns
用于将数据集中的字段名映射到 LLaMA-Factory 标准字段名:
标准字段 | 说明 | Alpaca 默认值 | ShareGPT 默认值 |
---|---|---|---|
prompt | 用户指令 | "instruction" |
- |
query | 用户输入(补充) | "input" |
- |
response | 模型回复 | "output" |
- |
messages | 对话消息列表 | - | "conversations" |
system | 系统提示词 | "system" |
"system" |
history | 历史对话 | "history" |
- |
images | 图片路径列表 | "images" |
"images" |
videos | 视频路径列表 | "videos" |
"videos" |
audios | 音频路径列表 | "audios" |
"audios" |
tools | 工具描述 | "tools" |
"tools" |
chosen | 偏好数据集中的优选回复 | "chosen" |
"chosen" |
rejected | 偏好数据集中的劣选回复 | "rejected" |
"rejected" |
kto_tag | KTO 训练的标签 | "kto_tag" |
"kto_tag" |
tags
用于定义 ShareGPT 格式中的标签名称:
参数 | 说明 | 默认值 |
---|---|---|
role_tag | 角色标识字段名 | "from" |
content_tag | 内容字段名 | "value" |
user_tag | 用户角色标识 | "human" |
assistant_tag | 助手角色标识 | "gpt" |
observation_tag | 工具返回结果标识 | "observation" |
function_tag | 函数调用标识 | "function_call" |
system_tag | 系统消息标识 | "system" |
{
"my_text_dataset": {
"file_name": "data/my_dataset.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"system": "system_prompt"
}
}
}
{
"pokemon_multimodal": {
"file_name": "pokemon_sharegpt/pokemon_sharegpt.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"system": "system",
"images": "image_paths"
}
}
}
{
"advanced_dataset": {
"hf_hub_url": "BUAADreamer/pokemon-gpt4-1k",
"ms_hub_url": "modelscope/pokemon-gpt4",
"file_name": "data/pokemon.json",
"formatting": "sharegpt",
"ranking": false,
"subset": "default",
"split": "train",
"num_samples": 5000,
"columns": {
"messages": "chats",
"system": "sys_prompt",
"images": "img_list",
"tools": "functions"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
}
}
{
"preference_dataset": {
"file_name": "data/preferences.json",
"formatting": "alpaca",
"ranking": true,
"columns": {
"prompt": "question",
"query": "context",
"chosen": "better_answer",
"rejected": "worse_answer"
}
}
}
def validate_dataset(json_path):
"""验证数据集格式是否正确"""
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
issues = []
for idx, item in enumerate(data):
if 'conversations' not in item:
issues.append(f"第 {idx} 条缺少 conversations 字段")
if len(item['conversations']) % 2 != 0:
issues.append(f"第 {idx} 条对话轮次为奇数")
if 'images' in item:
for img_path in item['images']:
if not os.path.exists(img_path):
issues.append(f"第 {idx} 条的图片不存在: {img_path}")
return issues
、
、
插入媒体位置def process_large_dataset(dataset, batch_size=100):
total = len(dataset)
for i in range(0, total, batch_size):
batch = dataset[i:i+batch_size]
process_batch(batch)
if i % 1000 == 0:
save_checkpoint(i)
问题 | 解决方案 |
---|---|
图片加载失败 | 检查路径正确性,确保图片格式支持 |
内存溢出 | 使用流式处理,避免一次性加载所有数据 |
格式不匹配 | 使用验证脚本检查结构 |
编码错误 | 统一使用 UTF-8 编码 |
准备高质量的训练数据是大模型成功的关键。本文介绍了:
尽可能保证: