返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

聊聊GLM-4-9B开源模型的微调loss计算

[复制链接]
链载Ai 显示全部楼层 发表于 2 小时前 |阅读模式 打印 上一主题 下一主题

概述

Github官方地址:GLM-4[1]

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

微调

微调格式:

[
{
"messages":[
{
"role":"system",
"content":"<systemprompttext>",
"tools":[
{
"name":"<toolname>",
"args":{
"<argname>":"<argvalue>"
}
}
]
},
{
"role":"user",
"content":"<userprompttext>"
},
{
"role":"assistant",
"content":"<assistantresponsetext>"
},
{
"role":"user",
"content":"<userprompttext>"
},
{
"role":"assistant",
"content":"<assistantresponsetext>"
},
{
"role":"observation",
"content":"<observationprompttext>"
},
{
"role":"assistant",
"content":"<assistantresponseobservation>"
},
{
"role":"user",
"content":"<userprompttext>"
},
{
"role":"assistant",
"content":"<assistantresponsetext>"
}
]
}
]

微调源码地址:finetune.py[2]

Loss计算代码:

defprocess_batch(
batch:Mapping[str,Sequence],
tokenizerreTrainedTokenizer,
max_input_length:int,
max_output_length:int,
)->dict[str,list]:
batched_conv=batch['messages']
batched_input_ids=[]
batched_labels=[]
#batched_conv是一个数组
#conv是数组内的单个message
forconvinbatched_conv:
input_ids=[151331,151333]
loss_masks=[False,False]
#conv是数组内的单个message
#message是单个rolejson对象
formessageinconv:
message=process_message(message)
#设置mask掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
loss_mask_val=Falseifmessage['role']in('system','user','observation')elseTrue
#获取input文本的数字表示(ids)
new_input_ids=tokenizer.apply_chat_template([message],tokenize=True,return_dict=False)[0][2:]
#计算整句的mask
new_loss_masks=[loss_mask_val]*len(new_input_ids)
#拼接message中的每段json
input_ids+=new_input_ids
#拼接message中每段json对应的mask
loss_masks+=new_loss_masks
#追加结尾的tokenid
input_ids.append(tokenizer.eos_token_id)
loss_masks=[False,*loss_masks]
labels=[]
forinput_id,maskinzip(input_ids,loss_masks):
ifmask:
#添加到label,计算loss
labels.append(input_id)
else:
#-100不处理,即ignore_index
labels.append(-100)
max_length=max_input_length+max_output_length+1
#截断
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return{'input_ids':batched_input_ids,'labels':batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

tokenizer,model=load_tokenizer_and_model(model_dir,peft_config=ft_config.peft_config)
data_manager=DataManager(data_dir,ft_config.data_config)
#数据集拆分遍历
train_dataset=data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:',train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md[3]

Reference
[1]

GLM-4: https://github.com/THUDM/GLM-4

[2]

finetune.py: https://github.com/THUDM/GLM-4/blob/main/finetune_demo/finetune.py

[3]

dataset_format.md: https://github.com/InternLM/xtuner/blob/main/docs/zh_cn/user_guides/dataset_format.md


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ