Introduction: GraphRAG has been getting a lot of attention in the community recently. After trying it out firsthand, the author found several points worth discussing and improving. This article focuses on how to modify GraphRAG so that it supports a custom LLM.
01
Why introduce a knowledge graph into RAG?
Compared with traditional RAG approaches, GraphRAG performs better on global questions that require reasoning over an entire corpus.
02
The GraphRAG modification plan
Right now GraphRAG feels more like a demo product: out of the box there is little you can do to integrate it with a real business scenario, so customization is definitely required.
In this article I will first walk through how to modify GraphRAG to support a custom LLM. I have also open-sourced my modified GraphRAG code on GitHub, and anyone interested is welcome to build on it together...
03
Environment setup
3.1 Install dependencies
git clone git@github.com:microsoft/graphrag.git
# Install pipx first
brew install pipx
pipx ensurepath
sudo pipx ensurepath --global  # optional, to allow pipx actions in global scope. See "Global installation" section below.
# Install poetry
pipx install poetry
poetry completions zsh > ~/.zfunc/_poetry
mkdir $ZSH_CUSTOM/plugins/poetry
poetry completions zsh > $ZSH_CUSTOM/plugins/poetry/_poetry
poetry install
3.2 Project structure
vector_stores directory: contains the vector database implementations. If you want a custom vector store, this is the directory to implement it in (a rough sketch follows below).
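As an illustration of that extension point, here is a minimal sketch of a custom in-memory vector store. The method names connect, load_documents and similarity_search_by_vector are assumptions modeled on the shape of the built-in stores, not the exact BaseVectorStore interface, so check graphrag/vector_stores in your checkout before reusing them.
# A minimal sketch of a custom vector store. Method names and signatures are
# illustrative assumptions; align them with graphrag/vector_stores/base.py.
from typing import Any

class MyVectorStore:
    """Hypothetical in-memory vector store used only to show the extension point."""

    def __init__(self, collection_name: str, **kwargs: Any):
        self.collection_name = collection_name
        self._documents: list[dict] = []

    def connect(self, **kwargs: Any) -> None:
        # A real implementation would open a client/connection here.
        pass

    def load_documents(self, documents: list[dict], overwrite: bool = True) -> None:
        if overwrite:
            self._documents = list(documents)
        else:
            self._documents.extend(documents)

    def similarity_search_by_vector(self, query_embedding: list[float], k: int = 10) -> list[dict]:
        # Naive cosine similarity over the stored vectors.
        def cosine(a: list[float], b: list[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            na = sum(x * x for x in a) ** 0.5
            nb = sum(x * x for x in b) ** 0.5
            return dot / (na * nb) if na and nb else 0.0

        scored = sorted(
            self._documents,
            key=lambda d: cosine(d["vector"], query_embedding),
            reverse=True,
        )
        return scored[:k]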
3.3 Run & debug the project
mkdir -p ./ragtest/input
# You can swap in any other document here -- a smaller one is more efficient and lets us verify our changes faster
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt > ./ragtest/input/book.txt
Initialize the project:
python -m graphrag.index --init --root ./ragtest
Index the documents:
python -m graphrag.index --root ./ragtest
Run a local query:
python -m graphrag.query \
--root ./ragtest \
--method local \
"Who is Scrooge, and what are his main relationships?"
Next, fill in the concrete command-line arguments in your debug configuration, and don't forget to set the working directory.
04
Adding Qwen (通义千问) support to GraphRAG
4.1 What needs to change
4.2 Supporting a Qwen configuration type
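The completion class shown later checks for LLMType.QwenChat, so the configuration enum needs corresponding Qwen members. A minimal sketch of that change follows; the member names QwenChat and QwenEmbedding and their string values are assumptions for illustration and only have to match what the loaders register.
# Sketch: extending graphrag.config's LLMType enum with Qwen variants.
# Existing members such as StaticResponse stay unchanged.
from enum import Enum

class LLMType(str, Enum):
    # ... existing OpenAI / AzureOpenAI members ...
    StaticResponse = "static_response"
    QwenChat = "qwen_chat"
    QwenEmbedding = "qwen_embedding"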
4.3 Using Qwen for indexing
def _load_qwen_llm(
    on_error: ErrorHandlerFn,
    cache: LLMCache,
    config: dict[str, Any],
    azure=False,
):
    log.info(f"Loading Qwen completion LLM with config {config}")
    return QwenCompletionLLM(config)


def _load_qwen_embeddings_llm(
    on_error: ErrorHandlerFn,
    cache: LLMCache,
    config: dict[str, Any],
    azure=False,
):
    log.info(f"Loading Qwen embeddings LLM with config {config}")
    return DashscopeEmbeddingsLLM(config)
By keeping these loaders compatible with the original ones, the indexing stage can now run entirely on Qwen.
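For these loaders to be picked up, they also need to be registered wherever GraphRAG maps an LLMType to its loader on the index side. A rough sketch of that registration is below, assuming a loaders table keyed by LLMType as in graphrag/index/llm/load_llm.py; the "load"/"chat" layout mirrors how the OpenAI loaders are registered and should be verified against your checkout.
# Sketch: registering the Qwen loaders next to the built-in ones.
loaders = {
    # ... existing OpenAIChat / AzureOpenAIChat / embedding entries ...
    LLMType.QwenChat: {
        "load": _load_qwen_llm,
        "chat": True,
    },
    LLMType.QwenEmbedding: {
        "load": _load_qwen_embeddings_llm,
        "chat": False,
    },
}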
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import asyncio
import json
import logging
from http import HTTPStatus
from typing import Unpack, List, Dict

import dashscope
import regex as re

from graphrag.config import LLMType
from graphrag.llm import LLMOutput
from graphrag.llm.base import BaseLLM
from graphrag.llm.base.base_llm import TIn, TOut
from graphrag.llm.types import (
    CompletionInput,
    CompletionOutput,
    LLMInput,
)

log = logging.getLogger(__name__)


class QwenCompletionLLM(
    BaseLLM[
        CompletionInput,
        CompletionOutput,
    ]
):
    def __init__(self, llm_config: dict = None):
        log.info(f"llm_config: {llm_config}")
        self.llm_config = llm_config or {}
        self.api_key = self.llm_config.get("api_key", "")
        self.model = self.llm_config.get("model", dashscope.Generation.Models.qwen_turbo)
        # self.chat_mode = self.llm_config.get("chat_mode", False)
        self.llm_type = llm_config.get("type", LLMType.StaticResponse)
        self.chat_mode = (llm_config.get("type", LLMType.StaticResponse) == LLMType.QwenChat)

    async def _execute_llm(
        self,
        input: CompletionInput,
        **kwargs: Unpack[LLMInput],
    ) -> CompletionOutput:
        log.info(f"input: {input}")
        log.info(f"kwargs: {kwargs}")
        variables = kwargs.get("variables", {})
        # Replace prompt placeholders via string substitution
        formatted_input = replace_placeholders(input, variables)

        if self.chat_mode:
            history = kwargs.get("history", [])
            messages = [
                *history,
                {"role": "user", "content": formatted_input},
            ]
            response = self.call_with_messages(messages)
        else:
            response = self.call_with_prompt(formatted_input)

        if response.status_code == HTTPStatus.OK:
            if self.chat_mode:
                return response.output["choices"][0]["message"]["content"]
            else:
                return response.output["text"]
        else:
            raise Exception(f"Error {response.code}: {response.message}")

    def call_with_prompt(self, query: str):
        print("call_with_prompt {}".format(query))
        response = dashscope.Generation.call(
            model=self.model,
            prompt=query,
            api_key=self.api_key
        )
        return response

    def call_with_messages(self, messages: list[dict[str, str]]):
        print("call_with_messages {}".format(messages))
        response = dashscope.Generation.call(
            model=self.model,
            messages=messages,
            api_key=self.api_key,
            result_format='message',
        )
        return response

    # Main entry point: invoke the LLM and parse JSON from its output
    async def _invoke_json(self, input: TIn, **kwargs) -> LLMOutput[TOut]:
        try:
            output = await self._execute_llm(input, **kwargs)
        except Exception as e:
            print(f"Error executing LLM: {e}")
            return LLMOutput[TOut](output=None, json=None)

        # Parse the JSON content out of the raw output
        extracted_jsons = extract_json_strings(output)
        if len(extracted_jsons) > 0:
            json_data = extracted_jsons[0]
        else:
            json_data = None

        try:
            output_str = json.dumps(json_data)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON: {e}")
            output_str = None

        return LLMOutput[TOut](
            output=output_str,
            json=json_data
        )
def replace_placeholders(input_str, variables):
    for key, value in variables.items():
        placeholder = "{" + key + "}"
        input_str = input_str.replace(placeholder, value)
    return input_str
def preprocess_input(input_str):
    # Preprocess the input string: escape special characters
    return input_str.replace('<', '&lt;').replace('>', '&gt;')
def extract_json_strings(input_string: str) -> List[Dict]:
    # Regular expression that matches (possibly nested) JSON objects
    json_pattern = re.compile(r'(\{(?:[^{}]|(?R))*\})')
    # Find all JSON-looking substrings
    matches = json_pattern.findall(input_string)
    json_objects = []
    for match in matches:
        try:
            # Try to parse each candidate substring
            json_object = json.loads(match)
            json_objects.append(json_object)
        except json.JSONDecodeError:
            # Ignore substrings that are not valid JSON
            log.warning(f"Invalid JSON string: {match}")
    return json_objects
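To make the JSON-salvaging behavior of extract_json_strings concrete, here is a small usage example; the sample model response is invented for illustration.
# Hypothetical example: pulling structured JSON out of a chatty model reply.
raw_output = 'Sure, here is the result: {"title": "Scrooge", "type": "person"} Hope that helps!'
parsed = extract_json_strings(raw_output)
print(parsed)  # [{'title': 'Scrooge', 'type': 'person'}]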
Now implement the corresponding embedding model:
"""TheEmbeddingsLLMclass."""
importlogging
log=logging.getLogger(__name__)
fromtypingimportUnpack
fromgraphrag.llm.baseimportBaseLLM
fromgraphrag.llm.typesimport(
EmbeddingInput,
EmbeddingOutput,
LLMInput,
)
fromhttpimportHTTPStatus
importdashscope
importlogging
log=logging.getLogger(__name__)
classQwenEmbeddingsLLM(BaseLLM[EmbeddingInput,EmbeddingOutput]):
"""Atext-embeddinggeneratorLLMusingDashscope'sAPI."""
def__init__(self,llm_config:dict=None):
log.info(f"llm_config:{llm_config}")
self.llm_config=llm_configor{}
self.api_key=self.llm_config.get("api_key","")
self.model=self.llm_config.get("model",dashscope.TextEmbedding.Models.text_embedding_v1)
asyncdef_execute_llm(
self,input:EmbeddingInput,**kwargs:Unpack[LLMInput]
)->EmbeddingOutput:
log.info(f"input:{input}")
response=dashscope.TextEmbedding.call(
model=self.model,
input=input,
api_key=self.api_key
)
ifresponse.status_code==HTTPStatus.OK:
res=[embedding["embedding"]forembeddinginresponse.output["embeddings"]]
returnres
else:
raiseException(f"Error{response.code}:{response.message}")4.4使用 Qwen 进行 Query
Compared with the index side, the query side also supports streaming output:
import asyncio
import logging
from http import HTTPStatus
from typing import Any

import dashscope
from tenacity import (
    Retrying,
    RetryError,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from graphrag.query.llm.base import BaseLLMCallback, BaseLLM
from graphrag.query.progress import StatusReporter, ConsoleStatusReporter

log = logging.getLogger(__name__)


class DashscopeGenerationLLM(BaseLLM):
    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 10,
        request_timeout: float = 180.0,
        retry_error_types: tuple[type[BaseException]] = (Exception,),
        reporter: StatusReporter = ConsoleStatusReporter(),
    ):
        self.api_key = api_key
        self.model = model or dashscope.Generation.Models.qwen_turbo
        self.max_retries = max_retries
        self.request_timeout = request_timeout
        self.retry_error_types = retry_error_types
        self._reporter = reporter

    def generate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    return self._generate(
                        messages=messages,
                        streaming=streaming,
                        callbacks=callbacks,
                        **kwargs,
                    )
        except RetryError as e:
            self._reporter.error(
                message="Error at generate()", details={self.__class__.__name__: str(e)}
            )
            return ""
        else:
            return ""

    async def agenerate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    return await asyncio.to_thread(
                        self._generate,
                        messages=messages,
                        streaming=streaming,
                        callbacks=callbacks,
                        **kwargs,
                    )
        except RetryError as e:
            self._reporter.error(f"Error at agenerate(): {e}")
            return ""
        else:
            return ""

    def _generate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        if isinstance(messages, list):
            response = dashscope.Generation.call(
                model=self.model,
                messages=messages,
                api_key=self.api_key,
                stream=streaming,
                incremental_output=streaming,
                timeout=self.request_timeout,
                result_format='message',
                **kwargs,
            )
        else:
            response = dashscope.Generation.call(
                model=self.model,
                prompt=messages,
                api_key=self.api_key,
                stream=streaming,
                incremental_output=streaming,
                timeout=self.request_timeout,
                **kwargs,
            )
        # if response.status_code != HTTPStatus.OK:
        #     raise Exception(f"Error {response.code}: {response.message}")
        if streaming:
            full_response = ""
            for chunk in response:
                if chunk.status_code != HTTPStatus.OK:
                    raise Exception(f"Error {chunk.code}: {chunk.message}")
                decoded_chunk = chunk.output.choices[0]['message']['content']
                full_response += decoded_chunk
                if callbacks:
                    for callback in callbacks:
                        callback.on_llm_new_token(decoded_chunk)
            return full_response
        else:
            if isinstance(messages, list):
                return response.output["choices"][0]["message"]["content"]
            else:
                return response.output["text"]
Implement the query-side embedding object:
import asyncio
import logging
from typing import Any

import dashscope
from tenacity import (
    Retrying,
    RetryError,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.progress import StatusReporter, ConsoleStatusReporter

log = logging.getLogger(__name__)


class DashscopeEmbedding(BaseTextEmbedding):
    def __init__(
        self,
        api_key: str | None = None,
        model: str = dashscope.TextEmbedding.Models.text_embedding_v1,
        max_retries: int = 10,
        retry_error_types: tuple[type[BaseException]] = (Exception,),
        reporter: StatusReporter = ConsoleStatusReporter(),
    ):
        self.api_key = api_key
        self.model = model
        self.max_retries = max_retries
        self.retry_error_types = retry_error_types
        self._reporter = reporter

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        try:
            embedding = self._embed_with_retry(text, **kwargs)
            return embedding
        except Exception as e:
            self._reporter.error(
                message="Error embedding text",
                details={self.__class__.__name__: str(e)},
            )
            return []

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        try:
            embedding = await asyncio.to_thread(self._embed_with_retry, text, **kwargs)
            return embedding
        except Exception as e:
            self._reporter.error(
                message="Error embedding text asynchronously",
                details={self.__class__.__name__: str(e)},
            )
            return []

    def _embed_with_retry(self, text: str, **kwargs: Any) -> list[float]:
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    response = dashscope.TextEmbedding.call(
                        model=self.model,
                        input=text,
                        api_key=self.api_key,
                        **kwargs,
                    )
                    if response.status_code == 200:
                        embedding = response.output["embeddings"][0]["embedding"]
                        return embedding
                    else:
                        raise Exception(f"Error {response.code}: {response.message}")
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return []
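Before wiring them into the query engine, the two query-side classes above can be smoke-tested on their own with something like the following; the API key is a placeholder and the prompt is made up.
# Standalone smoke test of the query-side classes defined above.
# "sk-your-dashscope-key" is a placeholder, not a real key.
llm = DashscopeGenerationLLM(api_key="sk-your-dashscope-key")
embedder = DashscopeEmbedding(api_key="sk-your-dashscope-key")

answer = llm.generate("Give a one-sentence summary of what a knowledge graph is.")
vector = embedder.embed("knowledge graph")

print(answer)
print(len(vector))  # dimensionality of the text_embedding_v1 vector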
Run a query to see the result:
4.5 Some key points in the project
4.6 What to do when you hit errors
05
The core steps of GraphRAG
Community detection is performed with the Leiden algorithm, producing a hierarchical community structure. The Leiden algorithm helps organize large amounts of textual information into meaningful groups, making it much easier to understand and work with.
For high-level communities, the summaries of their sub-communities are reused recursively; a minimal sketch of the detection step follows below.
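The sketch below shows roughly what this hierarchical detection step looks like in code, using hierarchical_leiden from graspologic (the library GraphRAG relies on for clustering); the toy graph and the max_cluster_size value are illustrative assumptions, not GraphRAG's actual defaults.
# A minimal sketch of hierarchical community detection with the Leiden algorithm.
# The toy entity graph and parameters are made up for illustration.
import networkx as nx
from graspologic.partition import hierarchical_leiden

# Toy entity graph: nodes are entities, edges are co-occurrence relationships.
graph = nx.Graph()
graph.add_edges_from([
    ("Scrooge", "Marley"),
    ("Scrooge", "Bob Cratchit"),
    ("Bob Cratchit", "Tiny Tim"),
    ("Scrooge", "Fred"),
])

# One entry per (node, level) with the community the node belongs to.
partitions = hierarchical_leiden(graph, max_cluster_size=10)

# Group nodes by (level, community) to recover the hierarchical structure
# that community summaries are later generated from.
communities: dict[tuple[int, int], list[str]] = {}
for p in partitions:
    communities.setdefault((p.level, p.cluster), []).append(p.node)

for (level, cluster), nodes in sorted(communities.items()):
    print(f"level={level} community={cluster}: {nodes}")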
06
Summary
Integrating GraphRAG with business systems