返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

我们在这里将直接使用 LLaMA-Factory[3] 训练框架来直接完成监督微调部分工作

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 09:31 |阅读模式 打印 上一主题 下一主题

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1.2em;font-weight: bold;display: table;margin: 2em auto 1em;padding-right: 1em;padding-left: 1em;border-bottom: 2px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">Google Gemma 2B 微调实战(IT科技新闻标题生成)

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;border-left: none;padding: 1em;border-radius: 8px;color: rgba(0, 0, 0, 0.5);background: rgb(247, 247, 247);margin: 2em 8px;">

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1em;letter-spacing: 0.1em;color: rgb(80, 80, 80);">本文我将使用 Google 的 Gemma-2b 模型来微调一个基于IT科技新闻正文来生成对应标题的模型。并且我将介绍如何使用高度集成的训练框架来进行快速微调。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1.2em;font-weight: bold;display: table;margin: 4em auto 2em;padding-right: 0.2em;padding-left: 0.2em;background: rgb(15, 76, 129);color: rgb(255, 255, 255);">开始前

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">为了尽可能简化整个流程,我将使用linux-cn 数据集[1]作为本次训练任务的训练数据。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">模型选择使用Gemma-2b[2],在目前这个任务中 2b 级别的参数模型已经完全能满足当前的需求,当然你也可以尝试使用 7b 的模型。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">我们在这里将直接使用LLaMA-Factory[3]训练框架来直接完成监督微调部分工作。当然该框架不仅支持监督微调(SFT)也支持预训练(PT)、奖励模型(RM)以及 PPO/DPO 的训练。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1.2em;font-weight: bold;display: table;margin: 4em auto 2em;padding-right: 0.2em;padding-left: 0.2em;background: rgb(15, 76, 129);color: rgb(255, 255, 255);">数据整理

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">linux-cn 数据集本身已经进行了数据的清洗和格式化,这一步我们只需要把我们需要的字段提取出后来后根据一定格式转换为 LLaMA-Factory 监督微调格式即可。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">在本任务中,我们只需要数据集中的“title”和“content”两个字段即可。而 LLaMA-Factory 监督微调格式是如下格式的json文件。

[
{
"instruction":"Whatarethethreeprimarycolors?",
"input":"",
"output":"Thethreeprimarycolorsarered,blue,andyellow.Thesecolorsarecalledprimarybecausetheycannotbecreatedbymixingothercolorsandallothercolorscanbemadebycombiningtheminvariousproportions.Intheadditivecolorsystem,usedforlight,theprimarycolorsarered,green,andblue(RGB).",
},
...
]

因为我们选择使用的是预训练模型,所以我们还需要指定一个 prompt template。指定 prompt template 的一个好处是你如果希望同时训练多个不同类型的任务,这样可以保证不同任务之间不会相互干扰。

完整代码如下:

importjson

result=[]

prompt_template="""Generateatitleforthearticle:

{content}

---
Title:
"""
withopen('archve.jsonl','r')asf:
forlineinf:
p=json.loads(line)
result.append({
"instruction":prompt_template.replace("{content}",p['content']),
"input":"",
"output":p['title']
})

withopen('itnews_data.json','w')asf:
json.dump(result,f,ensure_ascii=False,indent=4)

完成这一步后,我们就可以开始训练我们的模型了。但往往耗费时间最长以及最头疼的也是数据收集和数据整理这一部分。

模型微调

首先你需要保证 LLaMA-Factory 框架已经在你本地已经 ready 了。即你已经下载了该项目并且已经进行了项目的安装。

具体如何安装你可以查看该项目的 README,本文不再过多赘述。

首先我们需要将数据集移动到框架的data目录中,然后在dataset_info.json中添加我们自定义的数据集。

以下是本文实例所添加的数据集信息:

"itnews":{
"file_name":"itnews_data.json",
},

当然不同类型的任务该框架会有不同的数据集格式要求,你可以参考项目中dataset_info.json的README[4]

然后我们只需要执行如下命令就可以开始微调了,本文是在单张A100(80G)上进行的微调。

CUDA_VISIBLE_DEVICES=0pythonsrc/train_bash.py\
--stagesft\
--do_trainTrue\
--model_name_or_pathgoogle/gemma-2b\
--finetuning_typelora\
--templatedefault\
--datasetitnews\
--use_unsloth\
--cutoff_len8192\
--learning_rate5e-05\
--num_train_epochs10.0\
--max_samples10000\
--per_device_train_batch_size4\
--per_device_eval_batch_size4\
--gradient_accumulation_steps4\
--lr_scheduler_typecosine\
--max_grad_norm1.0\
--logging_steps10\
--save_steps100\
--eval_steps100\
--evaluation_strategysteps\
--warmup_steps0\
--output_dirsaves/Gemma-2B/lora/train_v1\
--bf16True\
--lora_rank8\
--lora_dropout0.1\
--lora_targetq_proj,v_proj\
--val_size0.1\
--load_best_model_at_endTrue\
--plot_lossTrue\
--report_to"tensorboard"

在这里我需要对其中的几个参数进行简短的介绍:

--stage即任务类型,在这里我们本文做的是监督微调所以是 sft,如果是其他任务你需要指定不同的类型。

--dataset即数据集,这里的名称就是我们在dataset_info.json文件中指定的数据集名称。

--use_unsloth这是一个训练加速器,官方宣称在 Gemma 7b 上拥有 2.4x 的加速,并且节省超一半的显存。在使用这个之前你需要按照官方文档[5]进行安装。

--cutoff_len文本令牌化后输入到模型的截止长度,因为本文使用的 Gemma 2b 模型,它的最大长度是 8192 ,所以在这里我设置的是 8192。但请记住更长的上下文也需要更多的 GPU 显存!

--max_samples设置数据集加载的最大条数。本参数主要用作调试目的时非常好用,尤其是在你不确定cutoff_lenbatch_size的时候,你可以加载很小的一部分数据进行测试,然后查看你显存的使用情况。

--learning_rate--num_train_epochs学习率和训练周期,这是一个经验值,一般通过查看模型的 loss 来调整,当然在 LLM 模型训练中,本参数主要以模型是否符合任务需求而决定,也就是说完美的 loss 可能并不满足需求。

--per_device_train_batch_size--per_device_eval_batch_size--gradient_accumulation_steps这三个参数需要根据你的显存大小以及是否使用多个GPU等条件进行不同的调整。

--output_dir模型保存的目录。

更多的参数解释可以查看项目说明[6],以及transformers Trainer的说明[7]

模型使用

在这里我们可以直接使用transformers来执行。

fromtransformersimportAutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig

peft_model_id="checkpoint-2000"
model=AutoModelForCausalLM.from_pretrained(peft_model_id,device_map="cuda")

tokenizer=AutoTokenizer.from_pretrained("google/gemma-2b")

input_text="""
Generateatitleforthearticle:

{content}

---
Title:
"""#固定格式
encoding=tokenizer(input_text,return_tensors="pt").to("cuda")

outputs=model.generate(**encoding,max_length=8192,temperature=0.2,do_sample=True)
generated_ids=outputs[:,encoding.input_ids.shape[1]:]
generated_texts=tokenizer.batch_decode(generated_ids,skip_special_tokens=True)
print(generated_texts[0])

我通过使用我自己的一篇差不多 5000 tokens 关于微服务的文章[8]进行测试,并且这篇文章没有出现在数据集中。

在使用相同 prompt 的情况下的输出:

gemma-2b-it

>微服务架构
概述
微服务架构的定义
微服务架构的定义
微服务架构的定义
微服务架构的定义
微服务架构的定义
微服务架构的定义
...

lora

>微服务架构的优势

通过简单的测试,不难发现模型在微调后,其返回格式上更加稳定,并且更加符合我们的要求。

总结

如果你不想训练,但又希望尝试本文中的模型,你可以在huggingface上搜索gemma-2b-technology-news-title-generation-lora[9],找到从100-2200 steps 的所有 checkpoint。

本文使用了一种相对简单的方式来训练符合自己需求的模型。在真实的企业场景中往往还涉及如何生成符合需求的数据集,集群训练,模型的AB测试,企业级部署等问题。我会在未来的文章中和大家分享。


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ