|
导语:最近 GraphRAG 在社区很火,作者亲自体验后,发现了一些可以探讨和改进的地方,本文主要介绍了如何改造 GraphRAG 以支持自定义的 LLM。 传统的 RAG 在处理复杂问题时往往表现不理想,主要是传统 RAG 未能有效捕捉实体间的复杂关系和层次结构,且通常只检索固定数量的最相关文本块:- 缺少事情之间关系的理解:当需要关联不同信息以提供综合见解时,传统 RAG 很难将这些点连接起来。
- ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;">缺乏整体视角:当要求 RAG 全面理解大型数据集甚至单个大型文档的整体语义概念时,缺乏宏观视角,例如,当给它一本小说并问它“这本书的主旨是什么”时,十有八九会给不出靠谱的答案。
这个问题在我们上一篇文章《为什么说知识图谱 + RAG > 传统 RAG?》也有详细分析,感兴趣可以点击上面的链接查看。微软的 GraphRAG 通过引入知识图谱来解决传统 RAG 的局限性,在索引数据集时,GraphRAG 提取实体和实体间的关系,构建知识图谱,这让 GraphRAG 能够更全面地理解文档的语义,捕捉实体间的复杂关联,从而在处理复杂查询时表现出色。
设计的理念很不错,但是真的去体验使用的时候,发现几个问题:- 强依赖于 OpenAI 或 Azure 的服务。对于国内用户来说,OpenAI 的 key 还是需要国外银行卡,Azure 的 API 申请也比较繁琐,还有国外的云一般都是绑定信用卡,可能不小心用超了,上次体验 AWS 的产品,忘了删除了,后面发现扣了我快 1000 块钱,我只是体验下产品而已...
ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;color: rgb(86, 86, 86);font-size: 15px;letter-spacing: 1px;">GraphRAG 目前更像是一个 Demo 产品,想和业务结合现在也没什么可以操作的地方,肯定是需要自定义的。 想着能不能让GraphRAG 集成到业务中,准备对 GraphRAG 做一些改造,主要从以下几个方向进行:- 支持自定义 LLM,OpenAI 也比较贵,换成一些更便宜的模型。我首先选择了自家的 Qwen 模型,大家可以在我的基础上扩展其他模型的支持。Qwen 默认给 50W 的 Token 使用量,够玩一段时间的,而且可以用更便宜的 turbo 模型;
这篇文章我会首先介绍下如何改造 GraphRAG 以支持自定义的 LLM,同时我把修改 GraphRAG 的代码也开源在 GitHub 上了,也欢迎感兴趣的朋友共同建设... 因为我们是修改 GraphRAG 的代码,就不从 pip 进行安装了,另外对版本有一定的要求:
gitclonegit@github.com:microsoft/graphrag.git
#先安装pipx brewinstallpipx pipxensurepath sudopipxensurepath--global#optionaltoallowpipxactionsinglobalscope.See"Globalinstallation"sectionbelow.
#安装poetry pipxinstallpoetry poetrycompletionszsh>~/.zfunc/_poetry mkdir$ZSH_CUSTOM/plugins/poetry poetrycompletionszsh>$ZSH_CUSTOM/plugins/poetry/_poetry
另外在 PyCharm 中安装下 BigData 的文件预览插件,可以看到 index 过程中的文件结构类型:graphrag 是 GraphRAG 项目的核心包,包含了所有的关键代码逻辑。下面有几个重要的子目录,每个目录负责不同的功能模块:不同于 pip 的包安装,这里我们要在 pycharm 里面配置下如何从代码的形式运行项目的内容,官方入门给的几个案例,我们通过代码的形式运行:mkdir-p./ragtest/input
#这一步可以随便替换成一些其他的文档,小一点的,这样效率比较开,可以更快的验证下我们的改造结果 curlhttps://www.gutenberg.org/cache/epub/24022/pg24022.txt>./ragtest/input/book.txt
初始化项目: python-mgraphrag.index--init--root./ragtest 对文档进行索引: python-mgraphrag.index--root./ragtest
进行本地查询: python-mgraphrag.query\ --root./ragtest\ --methodlocal\ "WhoisScrooge,andwhatarehismainrelationships?"
如果直接运行上面的命令,会发现无法运行,让我们配置一下:在上述配置完成之后,你就可以 debug 项目,一步步了解项目中内部的各种细节了,模块的入口类在包下的 __main__.py 文件中。1、项目中默认支持的 LLM 类型是没有通义千问的,因此在枚举类型上要支持通义千问;2、在进行 index 的时候,会有一步 load_llm 的操作,我们在配置文件中定义的千问类型,在 load_llm 中实现,兼容下原本的接口。3、在查询的时候,默认使用 OpenAI 的客户端,判断下配置文件的类型,如果是 qwen 的类型,使用我们自己的千问实现。项目中的 index 和 query 的 llm 是两套不同的视线,我觉得其实可以合并在一起的,不过为了先走通,就是在 index 和 query 都实现了一遍。核心是在 llm 目录下新增了一个 qwen 的包;在 query 的 llm/qwen 目录下新增了 qwen 的问答实现。在 config 的 enums 中增加下千问的几个枚举,不然直接在配置文件中写 qwen 会报类型无法转换错误。在 index 的时候,执行逻辑会走到 load_llm,在加载 llm 的部分,支持下 QwenLLM 的实现。然后实现对应的方法和类,我再给出我们的 QwenCompletionLLM 以及def_load_qwen_llm( on_error:ErrorHandlerFn, cache LMCache, config:dict[str,Any], azure=False, ): log.info(f"LoadingQwencompletionLLMwithconfig{config}") returnQwenCompletionLLM(config)
def_load_qwen_embeddings_llm( on_error:ErrorHandlerFn, cache LMCache, config:dict[str,Any], azure=False, ): log.info(f"LoadingQwenembeddingsLLMwithconfig{config}") returnDashscopeEmbeddingsLLM(config);
通过兼容原本的方法,到这里索引部分就可以通过 Qwen 完全进行使用了。
#Copyright(c)2024MicrosoftCorporation. #LicensedundertheMITLicense
importasyncio importjson importlogging fromhttpimportHTTPStatus fromtypingimportUnpack,List,Dict
importdashscope importregexasre
fromgraphrag.configimportLLMType fromgraphrag.llmimportLLMOutput fromgraphrag.llm.baseimportBaseLLM fromgraphrag.llm.base.base_llmimportTIn,TOut fromgraphrag.llm.typesimport( CompletionInput, CompletionOutput, LLMInput, )
log=logging.getLogger(__name__)
classQwenCompletionLLM( BaseLLM[ CompletionInput, CompletionOutput, ] ): def__init__(self,llm_config:dict=None): log.info(f"llm_config:{llm_config}") self.llm_config=llm_configor{} self.api_key=self.llm_config.get("api_key","") self.model=self.llm_config.get("model",dashscope.Generation.Models.qwen_turbo) #self.chat_mode=self.llm_config.get("chat_mode",False) self.llm_type=llm_config.get("type",LLMType.StaticResponse) self.chat_mode=(llm_config.get("type",LLMType.StaticResponse)==LLMType.QwenChat)
asyncdef_execute_llm( self, input:CompletionInput, **kwargs:Unpack[LLMInput], )->CompletionOutput: log.info(f"input:{input}") log.info(f"kwargs:{kwargs}")
variables=kwargs.get("variables",{})
#使用字符串替换功能替换占位符 formatted_input=replace_placeholders(input,variables)
ifself.chat_mode: history=kwargs.get("history",[]) messages=[ *history, {"role":"user","content":formatted_input}, ] response=self.call_with_messages(messages) else: response=self.call_with_prompt(formatted_input)
ifresponse.status_code==HTTPStatus.OK: ifself.chat_mode: returnresponse.output["choices"][0]["message"]["content"] else: returnresponse.output["text"] else: raiseException(f"Error{response.code}:{response.message}")
defcall_with_prompt(self,query:str): print("call_with_prompt{}".format(query)) response=dashscope.Generation.call( model=self.model, prompt=query, api_key=self.api_key ) returnresponse
defcall_with_messages(self,messages:list[dict[str,str]]): print("call_with_messages{}".format(messages)) response=dashscope.Generation.call( model=self.model, messages=messages, api_key=self.api_key, result_format='message', ) returnresponse
#主函数 asyncdef_invoke_json(self,input:TIn,**kwargs)->LLMOutput[TOut]: try: output=awaitself._execute_llm(input,**kwargs) exceptExceptionase: print(f"ErrorexecutingLLM:{e}") returnLLMOutput[TOut](output=None,json=None)
#解析output的内容 extracted_jsons=extract_json_strings(output)
iflen(extracted_jsons)>0: json_data=extracted_jsons[0] else: json_data=None
try: output_str=json.dumps(json_data) except(TypeError,ValueError)ase: print(f"ErrorserializingJSON:{e}") output_str=None
returnLLMOutput[TOut]( output=output_str, json=json_data )
defreplace_placeholders(input_str,variables): forkey,valueinvariables.items(): placeholder="{"+key+"}" input_str=input_str.replace(placeholder,value) returninput_str
defpreprocess_input(input_str): #预处理输入字符串,移除或转义特殊字符 returninput_str.replace('<','<').replace('>','>')
defextract_json_strings(input_string:str)->List[Dict]: #正则表达式模式,用于匹配JSON对象 json_pattern=re.compile(r'(\{(?:[^{}]|(?R))*\})')
#查找所有匹配的JSON子字符串 matches=json_pattern.findall(input_string)
json_objects=[] formatchinmatches: try: #尝试解析JSON子字符串 json_object=json.loads(match) json_objects.append(json_object) exceptjson.JSONDecodeError: #如果解析失败,忽略此子字符串 log.warning(f"InvalidJSONstring:{match}") pass
returnjson_objects
实现下对应的 Embeding 模型; """TheEmbeddingsLLMclass.""" importlogging
log=logging.getLogger(__name__)
fromtypingimportUnpack fromgraphrag.llm.baseimportBaseLLM fromgraphrag.llm.typesimport( EmbeddingInput, EmbeddingOutput, LLMInput, )
fromhttpimportHTTPStatus importdashscope importlogging
log=logging.getLogger(__name__)
classQwenEmbeddingsLLM(BaseLLM[EmbeddingInput,EmbeddingOutput]): """Atext-embeddinggeneratorLLMusingDashscope'sAPI."""
def__init__(self,llm_config:dict=None): log.info(f"llm_config:{llm_config}") self.llm_config=llm_configor{} self.api_key=self.llm_config.get("api_key","") self.model=self.llm_config.get("model",dashscope.TextEmbedding.Models.text_embedding_v1)
asyncdef_execute_llm( self,input:EmbeddingInput,**kwargs:Unpack[LLMInput] )->EmbeddingOutput: log.info(f"input:{input}")
response=dashscope.TextEmbedding.call( model=self.model, input=input, api_key=self.api_key )
ifresponse.status_code==HTTPStatus.OK: res=[embedding["embedding"]forembeddinginresponse.output["embeddings"]] returnres else: raiseException(f"Error{response.code}:{response.message}") 通过刚才我们配置的运行方式,配置下 Qwen,运行下,然后可以通过 BigData Viwer 看到里面的内容:在 indexing-engine.log 里面也可以看到详细的内容在 GraphRAG 中,query 和 index 用的是不同的 BaseLLM 抽象类,并且在 Query 这里默认用的 OpenAIEmbeding,这里我们也修改一下。 importasyncio importlogging fromhttpimportHTTPStatus fromtypingimportAny
importdashscope fromtenacityimport( Retrying, RetryError, retry_if_exception_type, stop_after_attempt, wait_exponential_jitter, )
fromgraphrag.query.llm.baseimportBaseLLMCallback,BaseLLM fromgraphrag.query.progressimportStatusReporter,ConsoleStatusReporter
log=logging.getLogger(__name__)
classDashscopeGenerationLLM(BaseLLM): def__init__( self, api_key:str|None=None, model:str|None=None, max_retries:int=10, request_timeout:float=180.0, retry_error_types:tuple[type[BaseException]]=(Exception,), reporter:StatusReporter=ConsoleStatusReporter(), ): self.api_key=api_key self.model=modelordashscope.Generation.Models.qwen_turbo self.max_retries=max_retries self.request_timeout=request_timeout self.retry_error_types=retry_error_types self._reporter=reporter
defgenerate( self, messages:str|list[str], streaming:bool=False, callbacks:list[BaseLLMCallback]|None=None, **kwargs:Any, )->str: try: retryer=Retrying( stop=stop_after_attempt(self.max_retries), wait=wait_exponential_jitter(max=10), reraise=True, retry=retry_if_exception_type(self.retry_error_types), ) forattemptinretryer: withattempt: returnself._generate( messages=messages, streaming=streaming, callbacks=callbacks, **kwargs, ) exceptRetryErrorase: self._reporter.error( message="Erroratgenerate()",details={self.__class__.__name__:str(e)} ) return"" else: return""
asyncdefagenerate( self, messages:str|list[str], streaming:bool=False, callbacks:list[BaseLLMCallback]|None=None, **kwargs:Any, )->str: try: retryer=Retrying( stop=stop_after_attempt(self.max_retries), wait=wait_exponential_jitter(max=10), reraise=True, retry=retry_if_exception_type(self.retry_error_types), ) forattemptinretryer: withattempt: returnawaitasyncio.to_thread( self._generate, messages=messages, streaming=streaming, callbacks=callbacks, **kwargs, ) exceptRetryErrorase: self._reporter.error(f"Erroratagenerate():{e}") return"" else: return""
def_generate( self, messages:str|list[str], streaming:bool=False, callbacks:list[BaseLLMCallback]|None=None, **kwargs:Any, )->str: ifisinstance(messages,list): response=dashscope.Generation.call( model=self.model, messages=messages, api_key=self.api_key, stream=streaming, incremental_output=streaming, timeout=self.request_timeout, result_format='message', **kwargs, ) else: response=dashscope.Generation.call( model=self.model, prompt=messages, api_key=self.api_key, stream=streaming, incremental_output=streaming, timeout=self.request_timeout, **kwargs, )
#ifresponse.status_code!=HTTPStatus.OK: #raiseException(f"Error{response.code}:{response.message}")
ifstreaming: full_response="" forchunkinresponse: ifchunk.status_code!=HTTPStatus.OK: raiseException(f"Error{chunk.code}:{chunk.message}")
decoded_chunk=chunk.output.choices[0]['message']['content'] full_response+=decoded_chunk ifcallbacks: forcallbackincallbacks: callback.on_llm_new_token(decoded_chunk) returnfull_response else: ifisinstance(messages,list): returnresponse.output["choices"][0]["message"]["content"] else: returnresponse.output["text"]
实现 Query 的 Embedding 对象: importasyncio importlogging fromtypingimportAny
importdashscope fromtenacityimport( Retrying, RetryError, retry_if_exception_type, stop_after_attempt, wait_exponential_jitter, )
fromgraphrag.query.llm.baseimportBaseTextEmbedding fromgraphrag.query.progressimportStatusReporter,ConsoleStatusReporter
log=logging.getLogger(__name__)
classDashscopeEmbedding(BaseTextEmbedding):
def__init__( self, api_key:str|None=None, model:str=dashscope.TextEmbedding.Models.text_embedding_v1, max_retries:int=10, retry_error_types:tuple[type[BaseException]]=(Exception,), reporter:StatusReporter=ConsoleStatusReporter(), ): self.api_key=api_key self.model=model self.max_retries=max_retries self.retry_error_types=retry_error_types self._reporter=reporter
defembed(self,text:str,**kwargs:Any)->list[float]: try: embedding=self._embed_with_retry(text,**kwargs) returnembedding exceptExceptionase: self._reporter.error( message="Errorembeddingtext", details={self.__class__.__name__:str(e)}, ) return[]
asyncdefaembed(self,text:str,**kwargs:Any)->list[float]: try: embedding=awaitasyncio.to_thread(self._embed_with_retry,text,**kwargs) returnembedding exceptExceptionase: self._reporter.error( message="Errorembeddingtextasynchronously", details={self.__class__.__name__:str(e)}, ) return[]
def_embed_with_retry(self,text:str,**kwargs:Any)->list[float]: try: retryer=Retrying( stop=stop_after_attempt(self.max_retries), wait=wait_exponential_jitter(max=10), reraise=True, retry=retry_if_exception_type(self.retry_error_types), ) forattemptinretryer: withattempt: response=dashscope.TextEmbedding.call( model=self.model, input=text, api_key=self.api_key, **kwargs, ) ifresponse.status_code==200: embedding=response.output["embeddings"][0]["embedding"] returnembedding else: raiseException(f"Error{response.code}:{response.message}") exceptRetryErrorase: self._reporter.error( message="Erroratembed_with_retry()", details={self.__class__.__name__:str(e)}, ) return[]
运行下 Query 的效果: 
 上面已经提到了可以配置 pycharm 来配置 debug 断点; 在执行错误的时候,在 output 下面会有对应的执行详细信息,根据错误信息,可以在对应的地方加上断点查看错误的原因是什么。参考论文《From Local to Global: A Graph RAG Approach to Query-Focused Summarization》的描述:https://arxiv.org/pdf/2404.16130GraphRAG 的主要 Pipeline 步骤说明:1. 文本分块 (Source Documents → Text Chunks),将源文档分割成较小的文本块,每块大约 600 个 token。块之间有 100 个 token 的重叠,以保持上下文连贯性;2. 元素实例提取 (Text Chunks → Element Instances):使用 LLM 从每个文本块中提取实体、关系和声明。实体包括名称、类型和描述。关系包括源实体、目标实体和描述。使用多轮"gleaning"技术来提高提取质量;- "Gleaning" 是一种迭代式的信息提取方法。初始提取: LLM 首先对文本块进行一次实体和关系提取。评估:LLM 被要求评估是否所有实体都被提取出来了。迭代提取::如果 LLM 认为有遗漏,它会被提示进行额外的"gleaning"轮次,尝试提取之前可能遗漏的实体。多轮进行:这个过程可以重复多次,直到达到预设的最大轮次或 LLM 认为没有更多实体可提取。
create_final_entities.parquetcreate_final_nodes.parquetcreate_final_relationships.parquet3. 元素摘要生成 (Element Instances → Element Summaries):将相同元素的多个实例合并成单一的描述性文本块。4. 图社区检测 (Element Summaries → Graph Communities):将实体作为节点,关系作为边构建无向加权图。5. 社区摘要生成 (Graph Communities → Community Summaries):6. 查询回答 (Community Summaries → Community Answers → Global Answer):- Community Summaries(社区摘要):预先生成的,包含了图中每个社区(即相关实体群组)的概要信息。它们存储了关于每个主题领域的关键信息,通过问题找到一些相关的主题(社区摘要);
- Community Answers(社区回答):当收到用户查询时,系统会并行处理每个社区摘要,对每个社区摘要,系统会生成一个针对用户问题的部分答案,系统还会给每个部分答案评分,表示其对回答问题的相关性;
- Global Answer:系统会收集所有有用的部分答案(过滤掉评分为0的答案),然后,它会按照相关性评分对这些答案进行排序。最后,系统会综合这些部分答案,生成一个全面、连贯的最终答案。
这里主要想和大家分享下如何定制下 GraphRAG 支持千问模型,方便更多的同学体验下 GraphRAG,当然这还只是第一步,GraphRAG 还不能直接应用到真实的场景中。官网上有一些架构的设计理念和过程,也可以参考学习,看到 GraphRAG 的社区热度也挺高的,估计很快就可以作为一个相对成熟的方案引入到实际的系统中。为了方便学习,我把上述的改动代码上传到了代码仓库,感兴趣的同学可以试一下,也可以继续进行定制和优化...代码仓库地址:https://code.alibaba-inc.com/aihehe.ah/biz_graphrag接下来尝试下怎么支持自定义的向量存储以及调研下能否和业务集成,比如: |