返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

如何在 Keras 中使用 LoRA 微调 Gemma 模型

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 09:34 |阅读模式 打印 上一主题 下一主题

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1.2em;font-weight: bold;display: table;margin: 2em auto 1em;padding-right: 1em;padding-left: 1em;border-bottom: 2px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">在 Keras 中使用 LoRA 微调 Gemma 模型

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">大型语言模型(LLM)如 Gemma 已被证明在多种自然语言处理(NLP)任务上有效。LLM首先通过自监督方式在大量文本语料上进行预训练。预训练帮助 LLM 学习通用知识,例如单词之间的统计关系。然后,可以使用特定领域的数据对 LLM 进行微调,以执行下游任务(如情感分析)。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">LLM 的大小极大(参数数量达到数百万)。对于大多数应用来说,不需要进行完全微调(更新模型中的所有参数),因为典型的微调数据集相对于预训练数据集要小得多。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">低秩适应(LoRA)是一种微调技术,通过冻结模型的权重并在模型中插入较少数量的新权重,大大减少了下游任务的可训练参数数量。这使得使用 LoRA 进行训练更快、更节省内存,并且生成的模型权重更小(几百MB),同时保持了模型输出的质量。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">本教程将引导您使用 KerasNLP 对 Gemma 2B 模型进行 LoRA 微调,使用的是 Databricks Dolly 15k数据集。该数据集包含15,000个高质量的人类生成的提示/响应对,专门用于微调 LLM。

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1.2em;font-weight: bold;display: table;margin: 4em auto 2em;padding-right: 0.2em;padding-left: 0.2em;background: rgb(15, 76, 129);color: rgb(255, 255, 255);">设置

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 1.1em;font-weight: bold;margin-top: 2em;margin-right: 8px;margin-bottom: 0.75em;padding-left: 8px;border-left: 3px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">Gemma 设置

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">要完成本教程,您首先需要按照 Gemma 设置的说明完成设置。Gemma 设置说明将向您展示如何进行以下操作:

ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;margin: 1.5em 8px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">Gemma 模型由 Kaggle 托管。要使用 Gemma,请在 Kaggle 上请求访问权限:

    ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;padding-left: 1em;list-style: circle;color: rgb(63, 63, 63);" class="list-paddingleft-1">
  • •在 kaggle.com登录或注册。

  • •打开 Gemma 模型卡片并选择“请求访问权限”。

  • •完成同意表格并接受条款和条件。

安装依赖

安装 Keras 、 KerasNLP 和其他依赖。

#安装最新的Keras3。更多信息查看https://keras.io/getting_started/。

!pipinstall-q-Ukeras-nlp
!pipinstall-q-Ukeras>=3

选择一个后端

Keras 是一个高级的、多框架的深度学习API,设计上注重简单性和易用性。Keras 3 允许您选择后端:TensorFlow、JAX或 PyTorch。这三个后端对于本教程都适用。

在本教程中,我们使用 JAX 作为后端。

importos

os.environ["KERAS_BACKEND"]="jax"#或者"tensorflow"、"torch"。
#在使用JAX后端时,避免内存碎片化。
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

导入包

导入 Keras 和 KerasNLP。

importkeras
importkeras_nlp

加载数据集

预处理数据是微调模型的重要步骤,尤其是当使用大型语言模型时。本教程使用的是1000个训练示例的子集,以便更快地执行。如果想要获得更高质量的微调效果,建议使用更多的训练数据。

importjson
data=[]
withopen('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl')asfile:
forlineinfile:
features=json.loads(line)
#过滤掉带有上下文的示例,以保持简单。
iffeatures["context"]:
continue
#将整个示例格式化为单个字符串。
template="Instruction:\n{instruction}\n\nResponse:\n{response}"
data.append(template.format(**features))

#仅使用1000个训练示例,以保持快速。
data=data[:1000]

加载模型

KerasNLP 提供了许多流行模型架构的实现。在本教程中,您将使用GemmaCausalLM创建一个模型,这是一个用于因果语言建模的端到端 Gemma 模型。因果语言模型基于前面的令牌预测下一个令牌。

使用from_preset方法创建模型:

gemma_lm=keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()
Preprocessor:"gemma_causal_lm_preprocessor"
Tokenizer (type)Vocab #
gemma_tokenizer (GemmaTokenizer)256,000
Model:"gemma_causal_lm"
Layer (type)Output ShapeParam #Connected to
padding_mask (InputLayer)(None, None)0-
token_ids (InputLayer)(None, None)0-
gemma_backbone (GemmaBackbone)(None, None, 2048)2,506,172,416padding_mask[0][0], token_ids[0][0]
token_embedding (ReversibleEmbedding)(None, None, 256000)524,288,000gemma_backbone[0][0]
Totalparams:2,506,172,416(9.34GB)
Trainableparams:2,506,172,416(9.34GB)
Non-trainableparams:0(0.00B)

from_preset方法从预设的架构和权重中实例化模型。在上述代码中,字符串 "gemma_2b_en" 指定了预设的架构 —— 一个拥有 20 亿参数的 Gemma 模型。

注意:Gemma 也提供了一个有 70 亿参数的模型。要在 Colab 中运行更大的模型,您需要访问付费计划中提供的高级 GPU。或者,您可以在 Kaggle 或 Google Cloud 上对 Gemma 7B 模型进行分布式调优。

在微调之前的推理

在本节中,您将用各种提示查询模型,以查看其如何响应。

欧洲旅行提示

查询模型以获取关于欧洲旅行应做些什么的建议。

prompt=template.format(
instruction="WhatshouldIdoonatriptoEurope?",
response="",
)
print(gemma_lm.generate(prompt,max_length=256))
Instruction:
WhatshouldIdoonatriptoEurope?

Response:
1.TakeatriptoEurope.
2.TakeatriptoEurope.
3.TakeatriptoEurope.
4.TakeatriptoEurope.
5.TakeatriptoEurope.
6.TakeatriptoEurope.
7.TakeatriptoEurope.
8.TakeatriptoEurope.
9.TakeatriptoEurope.
10.TakeatriptoEurope.
11.TakeatriptoEurope.
12.TakeatriptoEurope.
13.TakeatriptoEurope.
14.TakeatriptoEurope.
15.TakeatriptoEurope.
16.TakeatriptoEurope.
17.TakeatriptoEurope.
18.TakeatriptoEurope.
19.TakeatriptoEurope.
20.TakeatriptoEurope.
21.TakeatriptoEurope.
22.TakeatriptoEurope.
23.TakeatriptoEurope.
24.TakeatriptoEurope.
25.Takeatripto

该模型只是重复打印“Take a trip to Europe”。

ELI5 光合作用提示

提示模型用 5 岁儿童能够理解的简单术语解释光合作用。

prompt=template.format(
instruction="Explaintheprocessofphotosynthesisinawaythatachildcouldunderstand.",
response="",
)
print(gemma_lm.generate(prompt,max_length=256))
Instruction:
Explaintheprocessofphotosynthesisinawaythatachildcouldunderstand.

Response:
Photosynthesisistheprocessbywhichplantsusetheenergyfromthesuntoconvertwaterandcarbondioxideintooxygenandglucose.Theprocessbeginswiththeabsorptionoflightenergybychlorophyllmoleculesintheleavesofplants.Theenergyfromthelightisusedtosplitwatermoleculesintohydrogenandoxygen.Theoxygenisreleasedintotheatmosphere,whilethehydrogenisusedtomakeglucose.Theglucoseisthenusedbytheplanttomakeenergyandgrow.

Explanation:
Photosynthesisistheprocessbywhichplantsusetheenergyfromthesuntoconvertwaterandcarbondioxideintooxygenandglucose.Theprocessbeginswiththeabsorptionoflightenergybychlorophyllmoleculesintheleavesofplants.Theenergyfromthelightisusedtosplitwatermoleculesintohydrogenandoxygen.Theoxygenisreleasedintotheatmosphere,whilethehydrogenisusedtomakeglucose.Theglucoseisthenusedbytheplanttomakeenergyandgrow.

Explanation:

Photosynthesisistheprocessbywhichplantsusetheenergyfromthesuntoconvertwaterandcarbondioxideintooxygenandglucose.Theprocessbeginswiththeabsorptionoflightenergybychlorophyllmoleculesintheleavesofplants.Theenergyfrom

回答中包含对儿童来说可能不容易理解的单词,例如叶绿素、葡萄糖等。

LoRA 微调

要从模型中获得更好的响应,可以使用 Databricks Dolly 15k 数据集通过低秩适应(LoRA)对模型进行微调。

LoRA 秩决定了添加到 LLM 原始权重中的可训练矩阵的维度。它控制着微调调整的表达性和精度。

更高的秩意味着可以进行更详细的更改,但也意味着有更多的可训练参数。较低的秩意味着计算开销较小,但可能导致适应性不够精确。

本教程使用的 LoRA 秩为 4。在实践中,从相对较小的秩开始(例如 4、8、16)是计算上高效的试验方法。使用这个秩训练您的模型,并评估在您的任务上的性能改进。逐渐增加后续试验的秩,看看是否能进一步提高性能。

#为模型启用LoRA并将LoRA秩设置为4。
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Preprocessor:"gemma_causal_lm_preprocessor"
Tokenizer (type)Vocab #
gemma_tokenizer (GemmaTokenizer)256,000
Model:"gemma_causal_lm"
Layer (type)Output ShapeParam #Connected to
padding_mask (InputLayer)(None, None)0-
token_ids (InputLayer)(None, None)0-
gemma_backbone (GemmaBackbone)(None, None, 2048)2,507,536,384padding_mask[0][0], token_ids[0][0]
token_embedding (ReversibleEmbedding)(None, None, 256000)524,288,000gemma_backbone[0][0]
Totalparams:2,507,536,384(9.34GB)
Trainableparams:1,363,968(5.20MB)
Non-trainableparams:2,506,172,416(9.34GB)

请注意,启用 LoRA 会显着减少可训练参数的数量(从 25 亿减少到 130 万)。

#将输入序列长度限制为512(以控制内存使用)。
gemma_lm.preprocessor.sequence_length=512
#使用AdamW(transformer模型的常见优化器)。
optimizer=keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
#从衰减(decay)中排除layernorm和偏置项。
optimizer.exclude_from_weight_decay(var_names=["bias","scale"])

gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data,epochs=1,batch_size=1)

微调之后的推理

微调后,模型的响应会遵循提示中提供的指令。

欧洲旅行提示

prompt=template.format(
instruction="WhatshouldIdoonatriptoEurope?",
response="",
)
print(gemma_lm.generate(prompt,max_length=256))
Instruction:
WhatshouldIdoonatriptoEurope?

Response:
YoushouldplantoseethemostfamoussightsinEurope.TheEiffelTower,theAcropolis,andtheColosseumarejustafew.Youshouldalsoplanonseeingasmanycountriesaspossible.TherearesomanyamazingplacesinEurope,itisashametonotseethemall.

AdditionalInformation:
Europeisaveryinterestingplacetovisitformanyreasons,notleastofwhichisthattherearesomanydifferentplacestosee.

微调后的模型现在可以推荐在欧洲访问的地方了。

ELI5 光合作用提示

prompt=template.format(
instruction="Explaintheprocessofphotosynthesisinawaythatachildcouldunderstand.",
response="",
)
print(gemma_lm.generate(prompt,max_length=256))
Instruction:
Explaintheprocessofphotosynthesisinawaythatachildcouldunderstand.

Response:
Photosynthesisisaprocessinwhichplantsandphotosyntheticorganisms(suchasalgae,cyanobacteria,andsomebacteriaandarchaea)uselightenergytoconvertwaterandcarbondioxideintosugarandreleaseoxygen.Thisprocessrequireschlorophyll,water,carbondioxide,andenergy.Thechlorophyllcapturesthelightenergyandusesittopowerareactionthatconvertsthecarbonfromcarbondioxideintoorganicmolecules(suchassugar)thatcanbeusedforenergy.Theprocessalsogeneratesoxygenasaby-product.

该模型现在用简单的术语解释了光合作用。

请注意,出于演示目的,本教程仅在数据集的小子集上对模型进行了一次迭代(epoch)的微调,并且使用了较低的 LoRA 秩值。要从微调后的模型中获得更好的响应,您可以尝试:

  1. 1.增加微调数据集的大小。

  2. 2.增加训练步骤(迭代次数)。

  3. 3.设置更高的 LoRA 秩。

  4. 4.修改超参数值,如学习率(learning_rate)和权重衰减(weight_decay)。


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ