返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

小白学RAG:FlagEmbedding重排序

[复制链接]
链载Ai 显示全部楼层 发表于 13 小时前 |阅读模式 打印 上一主题 下一主题

RAG模型已经取得了显著的进展,但其性能仍然受到排序质量的限制。在实践中,我们发现重排序技术能够有效地改善排序的效果,从而进一步提升RAG模型在问答任务中的表现。

重排序的作用

与传统的嵌入模型不同,重排序器(reranker)直接以问题和文档作为输入,并输出相似度,而不是嵌入。通过将查询和文段输入到重排序器中,您可以获得相关性分数。重排序器基于交叉熵损失进行优化,因此相关性分数不限于特定范围。

unsetunsetFlagEmbeddingunsetunset

https://github.com/FlagOpen/FlagEmbedding

在FlagEmbedding中,重点放在了检索增强的语言模型上,目前包括以下项目:

  • 语言模型微调(Fine-tuning of LM)
  • 嵌入模型(Embedding Model)
  • 重排序模型(Reranker Model)
pipinstall-UFlagEmbedding

如下是一个rerank的例子,获取相关性分数(较高的分数表示更高的相关性):

fromFlagEmbeddingimportFlagReranker

#初始化重排序器,您可以选择是否启用混合精度以加快计算速度
reranker=FlagReranker('BAAI/bge-reranker-large',use_fp16=True)

#单个查询和文段的相关性分数
score=reranker.compute_score(['query','passage'])
print(score)

#批量查询和文段的相关性分数
scores=reranker.compute_score([['whatispanda?','hi'],['whatispanda?','Thegiantpanda(Ailuropodamelanoleuca),sometimescalledapandabearorsimplypanda,isabearspeciesendemictoChina.']])
print(scores)

也可以直接通过transformers使用:

importtorch
fromtransformersimportAutoModelForSequenceClassification,AutoTokenizer

tokenizer=AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model=AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()

pairs=[['whatispanda?','hi'],['whatispanda?','Thegiantpanda(Ailuropodamelanoleuca),sometimescalledapandabearorsimplypanda,isabearspeciesendemictoChina.']]
withtorch.no_grad():
inputs=tokenizer(pairs,padding=True,truncation=True,return_tensors='pt',max_length=512)
scores=model(**inputs,return_dict=True).logits.view(-1,).float()
print(scores)

unsetunsetReRank原理unsetunset

跨编码器(Cross-encoder)采用全注意力机制对输入文本对进行完整关注,相比嵌入模型(例如双编码器),其精度更高,但也更耗时。因此,它可以用于重新排列嵌入模型返回的前k个文档,从而提高排序的准确性。

跨编码器的引入为多语言文本检索任务带来了更加准确和高效的解决方案。通过在全文本对上进行完整关注,跨编码器能够更好地理解语义信息,从而提高了排序的精度。

基础ReRank模型

基础的ReRank模型本质是一个序列分类模型,旨在对输入的文本对进行分类。它基于预训练的Transformer架构,通常是BERT、RoBERTa或类似的模型,其中编码器部分对输入文本进行编码,而解码器部分则用于分类任务。

fromFlagEmbeddingimportFlagReranker
reranker=FlagReranker('BAAI/bge-reranker-v2-m3',use_fp16=True)#Settinguse_fp16toTruespeedsupcomputationwithaslightperformancedegradation

score=reranker.compute_score(['query','passage'])
print(score)#-5.65234375

#Youcanmapthescoresinto0-1byset"normalize=True",whichwillapplysigmoidfunctiontothescore
score=reranker.compute_score(['query','passage'],normalize=True)
print(score)#0.003497010252573502

scores=reranker.compute_score([['whatispanda?','hi'],['whatispanda?','Thegiantpanda(Ailuropodamelanoleuca),sometimescalledapandabearorsimplypanda,isabearspeciesendemictoChina.']])
print(scores)#[-8.1875,5.26171875]

#Youcanmapthescoresinto0-1byset"normalize=True",whichwillapplysigmoidfunctiontothescore
scores=reranker.compute_score([['whatispanda?','hi'],['whatispanda?','Thegiantpanda(Ailuropodamelanoleuca),sometimescalledapandabearorsimplypanda,isabearspeciesendemictoChina.']],normalize=True)
print(scores)#[0.00027803096387751553,0.9948403768236574]

LLM-based reranker

LLM-based reranker的思路,其主要思想是利用预训练的语言模型来对查询和文档进行编码,并根据编码结果生成相应的分数,以评估它们之间的相关性。

预设定的提示词如下:

Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.

fromFlagEmbeddingimportFlagLLMReranker
reranker=FlagLLMReranker('BAAI/bge-reranker-v2-gemma',use_fp16=True)#Settinguse_fp16toTruespeedsupcomputationwithaslightperformancedegradation
#reranker=FlagLLMReranker('BAAI/bge-reranker-v2-gemma',use_bf16=True)#Youcanalsosetuse_bf16=Truetospeedupcomputationwithaslightperformancedegradation

score=reranker.compute_score(['query','passage'])
print(score)

scores=reranker.compute_score([['whatispanda?','hi'],['whatispanda?','Thegiantpanda(Ailuropodamelanoleuca),sometimescalledapandabearorsimplypanda,isabearspeciesendemictoChina.']])
print(scores)

LLM-based layerwise reranker

based layerwise reranker是一种基于语言模型的重排序器,它允许您选择特定层的输出来计算分数,以加速推断过程并适应多语言环境。

fromFlagEmbeddingimportLayerWiseFlagLLMReranker
reranker=LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise',use_fp16=True)#Settinguse_fp16toTruespeedsupcomputationwithaslightperformancedegradation
#reranker=LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise',use_bf16=True)#Youcanalsosetuse_bf16=Truetospeedupcomputationwithaslightperformancedegradation

score=reranker.compute_score(['query','passage'],cutoff_layers=[28])#Adjusting'cutoff_layers'topickwhichlayersareusedforcomputingthescore.
print(score)

scores=reranker.compute_score([['whatispanda?','hi'],['whatispanda?','Thegiantpanda(Ailuropodamelanoleuca),sometimescalledapandabearorsimplypanda,isabearspeciesendemictoChina.']],cutoff_layers=[28])
print(scores)

unsetunset模型精度对比unsetunset



回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ