- 代码:https://github.com/AkariAsai/self-rag
- 论文:https://arxiv.org/abs/2310.11511
现有问题:LLM的事实不准确性
LLM经常产生幻觉,特别是在长尾情况下,它们的知识变得过时,并且缺乏归因。
检索增强生成是否是万能药?
传统RAG可以无差别地检索和合并一定数量的检索段落,无论检索是否必要或段落是否相关,都可能导致无用的生成。
1. Self-RAG?
自我反思检索增强生成(Self-RAG),本质是微调两个大模型。一个是评估大模型(critic model),另一个是生成大模型(generator model)。微调的内容不是领域知识,而是作为RAG应用所应具备的技能,比如什么时候去检索、生成内容是否有幻觉、如何确保生成内容的真实可靠可用。我分别从模型训练和推理去介绍Self-RAG。
Critic模型训练
学习面对各种各样的query时,1.是否需要检索,2.检索知识是否和Query相关(准确性),3.基于检索生成的内容是否真的来自检索知识还是大模型自己yy的(事实支持性),以及4.检索生成的内容是否真的对用户Query有帮助(有用性)。训练critic模型的目标函数是最大化似然度, 其中 是数据集, 是自省标记(reflection tokens)。从公式可以看到,根据x和y来计算r的条件生成概率。而自省标记r包括4种: 粗体的文本表示最理想的自省标记。x、y、d分别表示输入、输出和相关文档段落。
Generator模型训练
有了Critic模型生成自省tokens能力作为基础,进一步构建增强(Augmented)训练数据,下面是构建流程,其中 表示critic模型, 是检索模块, 是相关文档段落。
- 用Critic模型判断x是否需要检索,并预测 [Retrieve] token值,并把值拼接到x后面;如果是Yes再通过检索模块找出 K 个最相关的文档段落集合 ;
- 对于每个段落,Critic模型会进一步评估段落和x是否相关,并预测 [IsREL] token值;如果段落是相关的,Critic模型又会进一步评估,段落是否能支持模型的生成,并预测 [IsSUP] token值;最后把这两个token值拼接在检索生成内容y后面
- 当整个y生成出来后,再预测y的 [IsUSE] token值,并把值拼接到y后,
下面是增强数据的样例 备注:文本chunk之间用 <p></p> 包住。
以此,生成整个数据集 ,并基于次数据集进行生成器模型训,目标函数即求x预测【y和r】的条件生成概率的最大对数似然估计 因为纪要预测y,也要预测自省标记r,因此需要将r扩进词表中。
汇总
整个Self-RAG的训练过程伪代码如下:
2. 推理流程
大致推理流程如上,我们展开描述一下:
如果需要检索,假设检索出的知识片段集合为 ,对于每个 ,
注意,在每个时间步都用LLM进行并行推理输出 个不同的 候选集,并且记录他们的得分 然后进行都进行Beam Search(设置Beam大小为),如下简图所示, 最终获取 个最优的候选片段序列 。 分数是 的加权,权重可以认为调整。而对应的自省token的得分也比较简单,看A.3附录即可。
- 基于三个自省标记( [IsREL]、[IsSUP]、[IsUSE] )的预测结果,对 进行打分
注意:Critic模型是不参与Self-RAG的推理,但它在训练阶段的作用是至关重要的。它确保了Generator模型能够学习到如何生成高质量的输出,并在需要时进行有效的自我评估和批判。
实践
论文也开源了微调模型,可以下载一个GGUF版本,并使用llama.cpp进行推理。先安装,
pip install llama_cpp_python pip install huggingface-hub
然后下载模型
huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir ./model --local-dir-use-symlinks False
给一个简单的运行示例
from llama_cpp import Llama
# 定义模型参数和生成参数 MODEL_KWARGS = { "logits_all": True, "n_ctx": 2048, "n_gpu_layers": 200 } GENERATE_KWARGS = { "temperature": 0.0, "top_p": 1.0, "max_tokens": 1024, "logprobs": 1000 }
# 初始化模型 llm = Llama(model_path="selfrag_llama2_7b.q4_k_m.gguf", **MODEL_KWARGS)
# 格式化Prompt函数 def format_prompt(query, paragraph=None): """ 格式化查询为模型所需的prompt格式。 :param query: 输入的问题或指令。 :param paragraph: 可选的,与查询相关的段落信息,用于检索。 :return: 格式化后的prompt字符串。 """ prompt = "### Instruction:\n{0}\n\n### Response:\n".format(query) if paragraph: prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph) return prompt
# 测试问题 queries = [ "撰写一首表达对老师的感激之情的短诗", "简述一下人工智能在医疗领域的应用" ]
# 测试并打印结果 for query in queries: prompt = format_prompt(query) result = llm(prompt, **GENERATE_KWARGS) # 提取并打印生成的文本 generated_text = result["choices"][0]["text"] print("\nResponse:\n{0}".format(generated_text)) # 如果需要,打印详细信息 # print(result["choices"][0])
|