返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

手把手教程,改造 GraphRAG 支持自定义 LLM

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 10:47 |阅读模式 打印 上一主题 下一主题

导语:最近 GraphRAG 在社区很火,作者亲自体验后,发现了一些可以探讨和改进的地方,本文主要介绍了如何改造 GraphRAG 以支持自定义的 LLM。

01

为什么在 RAG 中引入知识图谱?

传统的 RAG 在处理复杂问题时往往表现不理想,主要是传统 RAG 未能有效捕捉实体间的复杂关系和层次结构,且通常只检索固定数量的最相关文本块:

  • 缺少事情之间关系的理解:当需要关联不同信息以提供综合见解时,传统 RAG 很难将这些点连接起来。
  • ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;">缺乏整体视角:当要求 RAG 全面理解大型数据集甚至单个大型文档的整体语义概念时,缺乏宏观视角,例如,当给它一本小说并问它“这本书的主旨是什么”时,十有八九会给不出靠谱的答案。
这个问题在我们上一篇文章《为什么说知识图谱 + RAG > 传统 RAG?》也有详细分析,感兴趣可以点击上面的链接查看。
微软的 GraphRAG 通过引入知识图谱来解决传统 RAG 的局限性,在索引数据集时,GraphRAG 提取实体和实体间的关系,构建知识图谱,这让 GraphRAG 能够更全面地理解文档的语义,捕捉实体间的复杂关联,从而在处理复杂查询时表现出色。
  • 这种方法适合处理需要对整个数据集进行综合理解的问题,如“数据集中的主要主题是什么?”这类问题;
  • 相比传统的 RAG 方法,Graph RAG 在处理全局性问题时表现出更好;

02

GraphRAG 改造计划

设计的理念很不错,但是真的去体验使用的时候,发现几个问题:

  1. 强依赖于 OpenAI 或 Azure 的服务。对于国内用户来说,OpenAI 的 key 还是需要国外银行卡,Azure 的 API 申请也比较繁琐,还有国外的云一般都是绑定信用卡,可能不小心用超了,上次体验 AWS 的产品,忘了删除了,后面发现扣了我快 1000 块钱,我只是体验下产品而已...
  2. ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;color: rgb(86, 86, 86);font-size: 15px;letter-spacing: 1px;">GraphRAG 目前更像是一个 Demo 产品,想和业务结合现在也没什么可以操作的地方,肯定是需要自定义的。

想着能不能让GraphRAG 集成到业务中,准备对 GraphRAG 做一些改造,主要从以下几个方向进行:
  1. 支持自定义 LLM,OpenAI 也比较贵,换成一些更便宜的模型。我首先选择了自家的 Qwen 模型,大家可以在我的基础上扩展其他模型的支持。Qwen 默认给 50W 的 Token 使用量,够玩一段时间的,而且可以用更便宜的 turbo 模型;
  2. 支持自定义向量数据库,方便线上使用;
  3. 引入一些业务属性,看看如何能和业务结合在一起;
  4. 优化下使用体验,实现生成的知识图谱可视化。

这篇文章我会首先介绍下如何改造 GraphRAG 以支持自定义的 LLM,同时我把修改 GraphRAG 的代码也开源在 GitHub 上了,也欢迎感兴趣的朋友共同建设...

03

环境准备

3.1安装依赖

因为我们是修改 GraphRAG 的代码,就不从 pip 进行安装了,另外对版本有一定的要求:
  • Python 3.10 ~ 3.12版本
gitclonegit@github.com:microsoft/graphrag.git
安装 poetry:
#先安装pipx
brewinstallpipx
pipxensurepath
sudopipxensurepath--global#optionaltoallowpipxactionsinglobalscope.See"Globalinstallation"sectionbelow.

#安装poetry
pipxinstallpoetry
poetrycompletionszsh>~/.zfunc/_poetry
mkdir$ZSH_CUSTOM/plugins/poetry
poetrycompletionszsh>$ZSH_CUSTOM/plugins/poetry/_poetry

在 graphrag仓库下安装依赖:
poetryinstall
另外在 PyCharm 中安装下 BigData 的文件预览插件,可以看到 index 过程中的文件结构类型:

3.2 项目结构

graphrag 是 GraphRAG 项目的核心包,包含了所有的关键代码逻辑。下面有几个重要的子目录,每个目录负责不同的功能模块:
  • config 目录:存储 GraphRAG 配置后的对象,在 GraphRAG 启动时,会读取配置文件,并将配置解析为 config 目录下的各种对象;
  • index 目录:核心包,所有索引相关的核心逻辑;
  • query 目录:核心包,查询相关的类和逻辑,当用户提交查询请求时,query 目录下的代码会负责解析查询、检索知识图谱、生成回答等一系列操作;
  • model 目录:核心领域模型,如文本、文档、主题、关系等,GraphRAG 中的核心概念和数据结构,其他模块都围绕着这些模型进行操作和处理;
  • llm 目录:支持的 LLM 的实现。如果要自定义集成通义千问,就需要在这个目录下进行实现;
  • vector_stores 目录:包含向量数据库的实现。如果要自定义向量存储,需要在这个目录下进行实现。

3.3运行& Debug 项目

不同于 pip 的包安装,这里我们要在 pycharm 里面配置下如何从代码的形式运行项目的内容,官方入门给的几个案例,我们通过代码的形式运行:
mkdir-p./ragtest/input

#这一步可以随便替换成一些其他的文档,小一点的,这样效率比较开,可以更快的验证下我们的改造结果
curlhttps://www.gutenberg.org/cache/epub/24022/pg24022.txt>./ragtest/input/book.txt

初始化项目:

python-mgraphrag.index--init--root./ragtest

对文档进行索引:

python-mgraphrag.index--root./ragtest

进行本地查询:

python-mgraphrag.query\
--root./ragtest\
--methodlocal\
"WhoisScrooge,andwhatarehismainrelationships?"
如果直接运行上面的命令,会发现无法运行,让我们配置一下:
  • 运行方式选择模块运行;
  • 模块后面参考上述官方的命令,给出的具体模块;
  • 接下来填具体的参数,还有工作目录不要忘了。

在上述配置完成之后,你就可以 debug 项目,一步步了解项目中内部的各种细节了,模块的入口类在包下的 __main__.py 文件中。

04

GraphRAG 支持通义千问

4.1修改的内容

1、项目中默认支持的 LLM 类型是没有通义千问的,因此在枚举类型上要支持通义千问;
2、在进行 index 的时候,会有一步 load_llm 的操作,我们在配置文件中定义的千问类型,在 load_llm 中实现,兼容下原本的接口。
3、在查询的时候,默认使用 OpenAI 的客户端,判断下配置文件的类型,如果是 qwen 的类型,使用我们自己的千问实现。
项目中的 index 和 query 的 llm 是两套不同的视线,我觉得其实可以合并在一起的,不过为了先走通,就是在 index 和 query 都实现了一遍。
核心是在 llm 目录下新增了一个 qwen 的包;在 query 的 llm/qwen 目录下新增了 qwen 的问答实现。

4.2 支持 Qwen 类型的配置

在 config 的 enums 中增加下千问的几个枚举,不然直接在配置文件中写 qwen 会报类型无法转换错误。

4.3使用 Qwen 进行 Index

在 index 的时候,执行逻辑会走到 load_llm,在加载 llm 的部分,支持下 QwenLLM 的实现。
然后实现对应的方法和类,我再给出我们的 QwenCompletionLLM 以及
def_load_qwen_llm(
on_error:ErrorHandlerFn,
cacheLMCache,
config:dict[str,Any],
azure=False,
):
log.info(f"LoadingQwencompletionLLMwithconfig{config}")
returnQwenCompletionLLM(config)

def_load_qwen_embeddings_llm(
on_error:ErrorHandlerFn,
cacheLMCache,
config:dict[str,Any],
azure=False,
):
log.info(f"LoadingQwenembeddingsLLMwithconfig{config}")
returnDashscopeEmbeddingsLLM(config);

通过兼容原本的方法,到这里索引部分就可以通过 Qwen 完全进行使用了。


#Copyright(c)2024MicrosoftCorporation.
#LicensedundertheMITLicense

importasyncio
importjson
importlogging
fromhttpimportHTTPStatus
fromtypingimportUnpack,List,Dict

importdashscope
importregexasre

fromgraphrag.configimportLLMType
fromgraphrag.llmimportLLMOutput
fromgraphrag.llm.baseimportBaseLLM
fromgraphrag.llm.base.base_llmimportTIn,TOut
fromgraphrag.llm.typesimport(
CompletionInput,
CompletionOutput,
LLMInput,
)

log=logging.getLogger(__name__)


classQwenCompletionLLM(
BaseLLM[
CompletionInput,
CompletionOutput,
]
):
def__init__(self,llm_config:dict=None):
log.info(f"llm_config:{llm_config}")
self.llm_config=llm_configor{}
self.api_key=self.llm_config.get("api_key","")
self.model=self.llm_config.get("model",dashscope.Generation.Models.qwen_turbo)
#self.chat_mode=self.llm_config.get("chat_mode",False)
self.llm_type=llm_config.get("type",LLMType.StaticResponse)
self.chat_mode=(llm_config.get("type",LLMType.StaticResponse)==LLMType.QwenChat)

asyncdef_execute_llm(
self,
input:CompletionInput,
**kwargs:Unpack[LLMInput],
)->CompletionOutput:
log.info(f"input:{input}")
log.info(f"kwargs:{kwargs}")

variables=kwargs.get("variables",{})

#使用字符串替换功能替换占位符
formatted_input=replace_placeholders(input,variables)

ifself.chat_mode:
history=kwargs.get("history",[])
messages=[
*history,
{"role":"user","content":formatted_input},
]
response=self.call_with_messages(messages)
else:
response=self.call_with_prompt(formatted_input)

ifresponse.status_code==HTTPStatus.OK:
ifself.chat_mode:
returnresponse.output["choices"][0]["message"]["content"]
else:
returnresponse.output["text"]
else:
raiseException(f"Error{response.code}:{response.message}")

defcall_with_prompt(self,query:str):
print("call_with_prompt{}".format(query))
response=dashscope.Generation.call(
model=self.model,
prompt=query,
api_key=self.api_key
)
returnresponse

defcall_with_messages(self,messages:list[dict[str,str]]):
print("call_with_messages{}".format(messages))
response=dashscope.Generation.call(
model=self.model,
messages=messages,
api_key=self.api_key,
result_format='message',
)
returnresponse

#主函数
asyncdef_invoke_json(self,input:TIn,**kwargs)->LLMOutput[TOut]:
try:
output=awaitself._execute_llm(input,**kwargs)
exceptExceptionase:
print(f"ErrorexecutingLLM:{e}")
returnLLMOutput[TOut](output=None,json=None)

#解析output的内容
extracted_jsons=extract_json_strings(output)

iflen(extracted_jsons)>0:
json_data=extracted_jsons[0]
else:
json_data=None

try:
output_str=json.dumps(json_data)
except(TypeError,ValueError)ase:
print(f"ErrorserializingJSON:{e}")
output_str=None

returnLLMOutput[TOut](
output=output_str,
json=json_data
)


defreplace_placeholders(input_str,variables):
forkey,valueinvariables.items():
placeholder="{"+key+"}"
input_str=input_str.replace(placeholder,value)
returninput_str


defpreprocess_input(input_str):
#预处理输入字符串,移除或转义特殊字符
returninput_str.replace('<','<').replace('>','>')


defextract_json_strings(input_string:str)->List[Dict]:
#正则表达式模式,用于匹配JSON对象
json_pattern=re.compile(r'(\{(?:[^{}]|(?R))*\})')

#查找所有匹配的JSON子字符串
matches=json_pattern.findall(input_string)

json_objects=[]
formatchinmatches:
try:
#尝试解析JSON子字符串
json_object=json.loads(match)
json_objects.append(json_object)
exceptjson.JSONDecodeError:
#如果解析失败,忽略此子字符串
log.warning(f"InvalidJSONstring:{match}")
pass

returnjson_objects


实现下对应的 Embeding 模型;

"""TheEmbeddingsLLMclass."""
importlogging

log=logging.getLogger(__name__)

fromtypingimportUnpack
fromgraphrag.llm.baseimportBaseLLM
fromgraphrag.llm.typesimport(
EmbeddingInput,
EmbeddingOutput,
LLMInput,
)

fromhttpimportHTTPStatus
importdashscope
importlogging

log=logging.getLogger(__name__)


classQwenEmbeddingsLLM(BaseLLM[EmbeddingInput,EmbeddingOutput]):
"""Atext-embeddinggeneratorLLMusingDashscope'sAPI."""

def__init__(self,llm_config:dict=None):
log.info(f"llm_config:{llm_config}")
self.llm_config=llm_configor{}
self.api_key=self.llm_config.get("api_key","")
self.model=self.llm_config.get("model",dashscope.TextEmbedding.Models.text_embedding_v1)

asyncdef_execute_llm(
self,input:EmbeddingInput,**kwargs:Unpack[LLMInput]
)->EmbeddingOutput:
log.info(f"input:{input}")

response=dashscope.TextEmbedding.call(
model=self.model,
input=input,
api_key=self.api_key
)

ifresponse.status_code==HTTPStatus.OK:
res=[embedding["embedding"]forembeddinginresponse.output["embeddings"]]
returnres
else:
raiseException(f"Error{response.code}:{response.message}")
通过刚才我们配置的运行方式,配置下 Qwen,运行下,然后可以通过 BigData Viwer 看到里面的内容:
在 indexing-engine.log 里面也可以看到详细的内容

4.4使用 Qwen 进行 Query

在 GraphRAG 中,query 和 index 用的是不同的 BaseLLM 抽象类,并且在 Query 这里默认用的 OpenAIEmbeding,这里我们也修改一下。
  • query 相比 index 支持了流式的输出内容:

Qwen 的 Query 问答实现:
importasyncio
importlogging
fromhttpimportHTTPStatus
fromtypingimportAny

importdashscope
fromtenacityimport(
Retrying,
RetryError,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)

fromgraphrag.query.llm.baseimportBaseLLMCallback,BaseLLM
fromgraphrag.query.progressimportStatusReporter,ConsoleStatusReporter

log=logging.getLogger(__name__)

classDashscopeGenerationLLM(BaseLLM):
def__init__(
self,
api_key:str|None=None,
model:str|None=None,
max_retries:int=10,
request_timeout:float=180.0,
retry_error_types:tuple[type[BaseException]]=(Exception,),
reporter:StatusReporter=ConsoleStatusReporter(),
):
self.api_key=api_key
self.model=modelordashscope.Generation.Models.qwen_turbo
self.max_retries=max_retries
self.request_timeout=request_timeout
self.retry_error_types=retry_error_types
self._reporter=reporter

defgenerate(
self,
messages:str|list[str],
streaming:bool=False,
callbacks:list[BaseLLMCallback]|None=None,
**kwargs:Any,
)->str:
try:
retryer=Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
forattemptinretryer:
withattempt:
returnself._generate(
messages=messages,
streaming=streaming,
callbacks=callbacks,
**kwargs,
)
exceptRetryErrorase:
self._reporter.error(
message="Erroratgenerate()",details={self.__class__.__name__:str(e)}
)
return""
else:
return""

asyncdefagenerate(
self,
messages:str|list[str],
streaming:bool=False,
callbacks:list[BaseLLMCallback]|None=None,
**kwargs:Any,
)->str:
try:
retryer=Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
forattemptinretryer:
withattempt:
returnawaitasyncio.to_thread(
self._generate,
messages=messages,
streaming=streaming,
callbacks=callbacks,
**kwargs,
)
exceptRetryErrorase:
self._reporter.error(f"Erroratagenerate():{e}")
return""
else:
return""

def_generate(
self,
messages:str|list[str],
streaming:bool=False,
callbacks:list[BaseLLMCallback]|None=None,
**kwargs:Any,
)->str:
ifisinstance(messages,list):
response=dashscope.Generation.call(
model=self.model,
messages=messages,
api_key=self.api_key,
stream=streaming,
incremental_output=streaming,
timeout=self.request_timeout,
result_format='message',
**kwargs,
)
else:
response=dashscope.Generation.call(
model=self.model,
prompt=messages,
api_key=self.api_key,
stream=streaming,
incremental_output=streaming,
timeout=self.request_timeout,
**kwargs,
)

#ifresponse.status_code!=HTTPStatus.OK:
#raiseException(f"Error{response.code}:{response.message}")

ifstreaming:
full_response=""
forchunkinresponse:
ifchunk.status_code!=HTTPStatus.OK:
raiseException(f"Error{chunk.code}:{chunk.message}")

decoded_chunk=chunk.output.choices[0]['message']['content']
full_response+=decoded_chunk
ifcallbacks:
forcallbackincallbacks:
callback.on_llm_new_token(decoded_chunk)
returnfull_response
else:
ifisinstance(messages,list):
returnresponse.output["choices"][0]["message"]["content"]
else:
returnresponse.output["text"]

实现 Query 的 Embedding 对象:

importasyncio
importlogging
fromtypingimportAny

importdashscope
fromtenacityimport(
Retrying,
RetryError,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)

fromgraphrag.query.llm.baseimportBaseTextEmbedding
fromgraphrag.query.progressimportStatusReporter,ConsoleStatusReporter

log=logging.getLogger(__name__)


classDashscopeEmbedding(BaseTextEmbedding):

def__init__(
self,
api_key:str|None=None,
model:str=dashscope.TextEmbedding.Models.text_embedding_v1,
max_retries:int=10,
retry_error_types:tuple[type[BaseException]]=(Exception,),
reporter:StatusReporter=ConsoleStatusReporter(),
):
self.api_key=api_key
self.model=model
self.max_retries=max_retries
self.retry_error_types=retry_error_types
self._reporter=reporter

defembed(self,text:str,**kwargs:Any)->list[float]:
try:
embedding=self._embed_with_retry(text,**kwargs)
returnembedding
exceptExceptionase:
self._reporter.error(
message="Errorembeddingtext",
details={self.__class__.__name__:str(e)},
)
return[]

asyncdefaembed(self,text:str,**kwargs:Any)->list[float]:
try:
embedding=awaitasyncio.to_thread(self._embed_with_retry,text,**kwargs)
returnembedding
exceptExceptionase:
self._reporter.error(
message="Errorembeddingtextasynchronously",
details={self.__class__.__name__:str(e)},
)
return[]

def_embed_with_retry(self,text:str,**kwargs:Any)->list[float]:
try:
retryer=Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
forattemptinretryer:
withattempt:
response=dashscope.TextEmbedding.call(
model=self.model,
input=text,
api_key=self.api_key,
**kwargs,
)
ifresponse.status_code==200:
embedding=response.output["embeddings"][0]["embedding"]
returnembedding
else:
raiseException(f"Error{response.code}:{response.message}")
exceptRetryErrorase:
self._reporter.error(
message="Erroratembed_with_retry()",
details={self.__class__.__name__:str(e)},
)
return[]

运行下 Query 的效果:

可以看到使用的是 Qwen 的模型进行的问答:

4.5项目中的一些关键节点

创建工作流的地方,以及默认的工作流:

4.6遇到错误怎么办

上面已经提到了可以配置 pycharm 来配置 debug 断点;
在执行错误的时候,在 output 下面会有对应的执行详细信息,根据错误信息,可以在对应的地方加上断点查看错误的原因是什么。

05

GraphRAG 的核心步骤

参考论文《From Local to Global: A Graph RAG Approach to Query-Focused Summarization》的描述:https://arxiv.org/pdf/2404.16130
GraphRAG 的主要 Pipeline 步骤说明:
1. 文本分块 (Source Documents → Text Chunks),将源文档分割成较小的文本块,每块大约 600 个 token。块之间有 100 个 token 的重叠,以保持上下文连贯性;
2. 元素实例提取 (Text Chunks → Element Instances):使用 LLM 从每个文本块中提取实体、关系和声明。实体包括名称、类型和描述。关系包括源实体、目标实体和描述。使用多轮"gleaning"技术来提高提取质量;
  • "Gleaning" 是一种迭代式的信息提取方法。初始提取: LLM 首先对文本块进行一次实体和关系提取。评估:LLM 被要求评估是否所有实体都被提取出来了。迭代提取::如果 LLM 认为有遗漏,它会被提示进行额外的"gleaning"轮次,尝试提取之前可能遗漏的实体。多轮进行:这个过程可以重复多次,直到达到预设的最大轮次或 LLM 认为没有更多实体可提取。
create_final_entities.parquet
create_final_nodes.parquet
create_final_relationships.parquet
3. 元素摘要生成 (Element Instances → Element Summaries):将相同元素的多个实例合并成单一的描述性文本块。
4. 图社区检测 (Element Summaries → Graph Communities):将实体作为节点,关系作为边构建无向加权图。
  • 使用 Leiden 算法进行检测,得到层次化的社区结构,Leiden 算法帮助我们把大量的文本信息组织成有意义的群组,使得我们可以更容易地理解和处理这些信息。

5. 社区摘要生成 (Graph Communities → Community Summaries):
  • 为每个社区生成报告式摘要;
  • 对于叶子级社区,直接总结其包含的所有元素;
  • 对于高层社区,递归地利用子社区摘要。

6. 查询回答 (Community Summaries → Community Answers → Global Answer):
  • Community Summaries(社区摘要):预先生成的,包含了图中每个社区(即相关实体群组)的概要信息。它们存储了关于每个主题领域的关键信息,通过问题找到一些相关的主题(社区摘要);
  • Community Answers(社区回答):当收到用户查询时,系统会并行处理每个社区摘要,对每个社区摘要,系统会生成一个针对用户问题的部分答案,系统还会给每个部分答案评分,表示其对回答问题的相关性;
  • Global Answer:系统会收集所有有用的部分答案(过滤掉评分为0的答案),然后,它会按照相关性评分对这些答案进行排序。最后,系统会综合这些部分答案,生成一个全面、连贯的最终答案。

06

小结


这里主要想和大家分享下如何定制下 GraphRAG 支持千问模型,方便更多的同学体验下 GraphRAG,当然这还只是第一步,GraphRAG 还不能直接应用到真实的场景中。
官网上有一些架构的设计理念和过程,也可以参考学习,看到 GraphRAG 的社区热度也挺高的,估计很快就可以作为一个相对成熟的方案引入到实际的系统中。
为了方便学习,我把上述的改动代码上传到了代码仓库,感兴趣的同学可以试一下,也可以继续进行定制和优化...
代码仓库地址:https://code.alibaba-inc.com/aihehe.ah/biz_graphrag
接下来尝试下怎么支持自定义的向量存储以及调研下能否和业务集成,比如:
  • 自定义 VectorStore 实现
  • GraphRAG 可视化过程
  • 业务集成GraphRAG


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ