返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

100行代码演示LoRA fine-tuning!

[复制链接]
链载Ai 显示全部楼层 发表于 10 小时前 |阅读模式 打印 上一主题 下一主题

ingFang SC", Cambria, Cochin, Georgia, Times, "Times New Roman", serif;display: table;border-bottom: 2px solid rgb(15, 76, 129);color: rgb(63, 63, 63);visibility: visible;">引言

LoRA fine-tuning,冻结模型参数,引入低阶矩阵,更新一小部分权重来适应特定任务。通过减少训练参数的数量,降低了计算成本和内存需求,同时,保留了预训练模型的知识和泛化能力。

ingFang SC", Cambria, Cochin, Georgia, Times, "Times New Roman", serif;border-left: 3px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">LoRA方案中,提到的rank和alpha是什么?

之前的文章解释了rank是什么,这里再重新回顾下,同时,也解释下alpha超参数的含义。

在LoRA(Low-Rank Adaptation)fine-tuning中,alpha(α)和rank(秩)是超参数,共同决定了fine-tuning过程中,权重更新的效率和效果。

Alpha(α)的作用是控制LoRA权重更新的规模,是一个缩放因子,用于调整LoRA权重(W_A和W_B)在前向传播中的影响程度。通过调整alpha的值,可以控制fine-tuning过程中,对原始预训练模型权重的修改程度。如果alpha设置得较大,那么LoRA权重的变化对模型的影响也会更大;反之则反。

Rank(秩)在LoRA中指的是在LoRA权重矩阵分解中使用的秩。秩决定了LoRA权重矩阵分解中两个矩阵(W_A和W_B)的维度。较低的秩意味着更少的参数需要更新,从而减少了计算资源的需求和内存占用。然而,如果秩太低,可能无法充分捕捉到fine-tuning任务所需的特征变化,从而影响模型性能。因此,选择合适的rank也非常重要,实际项目中,需要不断验证调整,得到较好的结果。

Alpha和rank之间的关系是相互影响的,在实际操作中,调整alpha和rank的相对大小可以改变fine-tuning的效果。例如,如果rank保持不变,增加alpha会增强fine-tuning的效果,而减少alpha则会减弱这些效果。

LoRA的实现中,原始的模型权重保持不变,而是通过在前向传播过程中引入W_A和W_B的乘积来模拟权重的更新。这个过程可以表示为:


h=x@(W_A@W_B)*α


其中,h是模型的输出,x是输入,W_A和W_B是LoRA权重,α是缩放因子。通过这种方式,α直接影响了LoRA权重更新,对模型输出的贡献程度。如果α设置得较大,那么W_A和W_B的乘积对输出的影响就会更大,从而在fine-tuning过程中更强烈地调整模型以适应特定任务。相反,如果α设置得较小,那么权重更新的影响就会减弱,模型的变化就会更保守。

ingFang SC", Cambria, Cochin, Georgia, Times, "Times New Roman", serif;border-left: 3px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">利用peft,transformer等库,实现LoRA微调

#step1安装依赖!pip install trl transformers accelerate datasets bitsandbytes einops torch huggingface-hub git+https://github.com/huggingface/peft.git
from datasets import load_datasetfrom random import randrangeimport torchfrom transformers import AutoTokenizer, AutoModelForSeq2SeqLM,TrainingArguments,pipelinefrom peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLMfrom trl import SFTTrainer
# 训练数据使用:https://huggingface.co/datasets/samsumdataset = load_dataset("samsum")
model_name = "google/flan-t5-small"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
#Makes training faster but a little less accurate model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
#setting padding instructions for tokenizertokenizer.pad_token = tokenizer.eos_tokentokenizer.padding_side="right"
def prompt_instruction_format(sample):#格式化promptreturn f"""### Instruction:Use the Task below and the Input given to write the Response:
### Task:Summarize the Input
### Input:{sample['dialogue']}
### Response:{sample['summary']}"""
# Create the trainertrainingArgs = TrainingArguments(output_dir='output',num_train_epochs=1,per_device_train_batch_size=4,save_strategy="epoch",learning_rate=2e-4)
peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM",)
trainer = SFTTrainer(model=model,train_dataset=dataset['train'],eval_dataset = dataset['test'],peft_config=peft_config,tokenizer=tokenizer,packing=True,formatting_func=prompt_instruction_format,args=trainingArgs,)
trainer.train()

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">训练过程建议使用google colab,加载数据集等操作,速度飞快。

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ