返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

开源 LLM 在 text-to-sql 任务的 baseline 效果介绍

[复制链接]
链载Ai 显示全部楼层 发表于 2 小时前 |阅读模式 打印 上一主题 下一主题

本文是对基于开源项目 DB-GPT-Hub 利用开源的 LLM 在 spider 数据集上的 text-to-sql 任务的 baseline 效果介绍。其中开源 LLM 包括 Llama2-7B-Chat、Llama2-13B-Chat、CodeLlama-7B-Instruct、CodeLlama-13B-Instruct、Baichuan2-7B-Chat、Baichuan2-13B-Chat、Qwen-7B-Chat、Qwen-14B-Chat、ChatGLM3-6b,对 spider 数据集进行基于 LoRA 和 QLoRA 的训练,在 spider 官方的评估集上评估其执行准确率。

  • 项目地址:https://github.com/eosphoros-ai/DB-GPT-Hub

具体实验结果和对应的各个模型如上表,其中 method 中的 base 是指未经训练直接用模型本身进行预测评估,lora,qlora 是指大模型基于 LoRA 和 QLoRA 方式的训练方式。EX 中的 easy、 medium、 hard、 extra 分别为评估集中四个难度等级数据的准确率,all 为在所有难度等级上的整体执行准确率。上述训练集均只采用了 Spider 官方的训练数据,且只训练 8 个 epoch。

总体而言,采用 LoRA 和 QLoRA 方法进行微调后的模型相比 base 模型有显著提升。尤其是 CodeLlama-13B-Instruct 模型采用 LoRA 微调后在该任务上效果最好,在四个难度上的执行准确率都有大幅度提高,整体执行准确率达到 0.746 ,是目前效果最佳的模型。随着模型规模的增加,效果也有所提升。例如 Llama2 从 7B 到 13B ,准确率有了约 4-5% 的绝对提升。CodeLlama 从 7B 到 13B 提升更加明显,整体准确率从 0.149 上升到 0.539。国产模型方面, Qwen 的表现效果最好,在 14B 级别时效果整体准确率效果可以达到 0.701。


01

微调训练


随着模型规模的扩大和采用 LoRA、QLoRA 等方法的提示学习,可以有效提升开源 LLM 在文本到 SQL 转换任务上的效果。

详细的训练参数以 Qwen-7B-Chat 进行 LoRA 训练为例,如下:

CUDA_VISIBLE_DEVICES=0pythondbgpt_hub/train/sft_train.py\--model_name_or_path/home/model_files/Qwen-7B-Chat\--do_train\--datasetexample_text2sql_train\--max_source_length2048\--max_target_length512\--templatechatml\--finetuning_typelora\--lora_rank64\--lora_alpha32\--lora_targetc_attn\--output_dirdbgpt_hub/output/adapter/qwen-7b-2048_epoch8_lora\--overwrite_cache\--overwrite_output_dir\--per_device_train_batch_size1\--gradient_accumulation_steps16\--lr_scheduler_typecosine_with_restarts\--logging_steps500\--save_steps2000\--learning_rate2e-4\--num_train_epochs8\--plot_loss\--bf16

如果是使用 QLoRA 方法训练的话,同样以 Qwen 模型为例,训练参数如下所示:(主要设置参数量化精度 quantization_bit 为 4)

CUDA_VISIBLE_DEVICES=0pythondbgpt_hub/train/sft_train.py\--model_name_or_path/home/model_files/Qwen-14B-Chat\--do_train\--datasetexample_text2sql_train\--max_source_length2048\--max_target_length512\--templatechatml\--quantization_bit4\--finetuning_typelora\--lora_rank64\--lora_alpha32\--lora_targetc_attn\--output_dirdbgpt_hub/output/adapter/qwen-14b-2048_epoch8_qlora\--overwrite_cache\--overwrite_output_dir\--per_device_train_batch_size1\--gradient_accumulation_steps16\--lr_scheduler_typecosine_with_restarts\--logging_steps500\--save_steps2000\--learning_rate2e-4\--num_train_epochs8\--plot_loss\--bf16

同时,DB-GPT-Hub 项目还发布了 pip 包,用来降低 Text2SQL 训练的门槛, 除了通过仓库中提供的脚本的方式进行微调之外,还可以使用项目提供的 Python 包进行微调。

安装方式。直接采用 pip 安装即可:

pipinstalldbgpt_hub

使用方式。微调代码相关如下:

from dbgpt_hub.data_process import preprocess_sft_datafrom dbgpt_hub.train import start_sftfrom dbgpt_hub.predict import start_predictfrom dbgpt_hub.eval import start_evaluate
data_folder = "dbgpt_hub/data"data_info = [{"data_source": "spider","train_file": ["train_spider.json", "train_others.json"],"dev_file": ["dev.json"],"tables_file": "tables.json","db_id_name": "db_id","is_multiple_turn": False,"train_output": "spider_train.json","dev_output": "spider_dev.json",}]
train_args = {"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf","do_train": True,"dataset": "example_text2sql_train","max_source_length": 2048,"max_target_length": 512,"finetuning_type": "lora","lora_target": "q_proj,v_proj","template": "llama2","lora_rank": 64,"lora_alpha": 32,"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora","overwrite_cache": True,"overwrite_output_dir": True,"per_device_train_batch_size": 1,"gradient_accumulation_steps": 16,"lr_scheduler_type": "cosine_with_restarts","logging_steps": 50,"save_steps": 2000,"learning_rate": 2e-4,"num_train_epochs": 8,"plot_loss": True,"bf16": True,}
predict_args = {"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf","template": "llama2","finetuning_type": "lora","checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora","predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json","predict_out_dir": "dbgpt_hub/output/","predicted_out_filename": "pred_sql.sql",}
evaluate_args ={"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql","gold": "./dbgpt_hub/data/eval_data/gold.txt","gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt","db": "./dbgpt_hub/data/spider/database","table": "./dbgpt_hub/data/eval_data/tables.json","table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json","etype": "exec","plug_value": True,"keep_distict": False,"progress_bar_for_each_datapoint": False,"natsql": False,}
preprocess_sft_data(data_folder = data_folder,data_info = data_info)
start_sft(train_args)start_predict(predict_args)start_evaluate(evaluate_args)



02

得分展示



为了进一步展示 DB-GPT-Hub 项目取得的模型基础实验进展,项目提供了查看所有的模型基线得分以及具体的单个数据集上的实验得分、单个模型的实验得分等。

比如:查看所有的模型基线得分:

fromdbgpt_hub.baselineimportshow_scoresshow_scores()

显示结果如下所示:会默认按照平均精度降序输出,目前最高的平均精度为 codellama-13b-instruct 模型使用 lora 方法在 spider 数据集上训练,EX 为 0.746。

比如:还可以查看在数据集 spider 上的所有实验基线得分:

fromdbgpt_hub.baselineimportshow_scoreshow_score(dataset="spider")

比如:查看在数据集 spider,模型 llama2-7b-chat 的实验基线得分:

fromdbgpt_hub.baselineimportshow_scoreshow_score(dataset="spider",model="llama2-7b-chat")

比如:查看在数据集 spider,模型为 llama2-7b-chat,微调方法为 lora 的实验得分:

fromdbgpt_hub.baselineimportshow_scoreshow_score(dataset="spider",model="llama2-7b-chat",method="lora")

fromdbgpt_hub.baselineimportshow_scoreshow_score(dataset="spider",model="llama2-7b-chat",method="lora",prompt="alpaca")

最后,DB-GPT-Hub 项目目前 star 800+,欢迎关注和共建~

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ