为GraphRAG快速构建http api
之前使用GraphRAG生成了自己的知识库,但是默认只能使用命令行查询,官方没做http api,所以自己用fastapi写了一个API wrapper。
https://microsoft.github.io/graphrag/posts/query/notebooks/global_search_nb/
根据官方提供的notebook内容,修改成如下
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import pandas as pd
import tiktoken
from graphrag.query.indexer_adapters import read_indexer_entities, read_indexer_reports
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
app = FastAPI()

# Define the request model
class QueryRequest(BaseModel):
    """Request body for the /search/ endpoint.

    Attributes:
        query: the natural-language question to run through GraphRAG
            global search.
    """
    # The user's search query text.
    query: str
# Load data and initialize components.
# Everything below runs once at import time; it requires the parquet files
# under INPUT_DIR and the GRAPHRAG_* environment variables to be present,
# otherwise importing this module raises (FileNotFoundError / KeyError).

# Directory holding the GraphRAG index output (parquet files).
INPUT_DIR = "./inputs/parquets"
# Table names produced by the GraphRAG indexing pipeline.
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
# Community hierarchy level to load reports/entities from.
# NOTE(review): level semantics come from GraphRAG itself — confirm 2 is
# appropriate for the index that was built.
COMMUNITY_LEVEL = 2

entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")

# Convert the raw dataframes into GraphRAG's in-memory report/entity objects.
reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)

# Tokenizer used by GraphRAG for token budgeting during context building.
token_encoder = tiktoken.get_encoding("cl100k_base")

context_builder = GlobalCommunityContext(
    community_reports=reports,
    entities=entities,
    token_encoder=token_encoder,
)

# LLM configuration comes from the environment; os.environ[...] raises
# KeyError on a missing variable, so the service fails fast at startup
# instead of at request time.
api_key = os.environ["GRAPHRAG_API_KEY"]
api_base = os.environ["GRAPHRAG_API_BASE"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]

llm = ChatOpenAI(
    api_base=api_base,
    api_key=api_key,
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,  # plain OpenAI-compatible endpoint (not Azure)
    max_retries=20,
)

# Parameters controlling how community reports are assembled into context.
# NOTE(review): values taken from the official GraphRAG global-search
# notebook; tune max_tokens to the model's context window.
context_builder_params = {
    "use_community_summary": False,  # use full report content, not summaries
    "shuffle_data": True,
    "include_community_rank": True,
    "min_community_rank": 0,
    "community_rank_name": "rank",
    "include_community_weight": True,
    "community_weight_name": "occurrence weight",
    "normalize_community_weight": True,
    "max_tokens": 12_000,
    "context_name": "Reports",
}

# Sampling parameters for the map stage (JSON-mode per-batch answers).
map_llm_params = {
    "max_tokens": 1000,
    "temperature": 0.0,
    "response_format": {"type": "json_object"},
}

# Sampling parameters for the reduce stage (final combined answer).
reduce_llm_params = {
    "max_tokens": 2000,
    "temperature": 0.0,
}

search_engine = GlobalSearch(
    llm=llm,
    context_builder=context_builder,
    token_encoder=token_encoder,
    max_data_tokens=12_000,
    map_llm_params=map_llm_params,
    reduce_llm_params=reduce_llm_params,
    allow_general_knowledge=False,  # restrict answers to the indexed data
    json_mode=True,
    context_builder_params=context_builder_params,
    concurrent_coroutines=32,
    response_type="multiple paragraphs",
)
@app.post("/search/")
async def perform_search(request: QueryRequest):
try:
result = await search_engine.asearch(request.query)
return {
"response": result.response,
"llm_calls": result.llm_calls,
"prompt_tokens": result.prompt_tokens,
"context_data": result.context_data["reports"]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Entry point: start a development server when executed as a script.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces so the API is reachable from other hosts.
    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)
主要根据自己需求修改llm部分
api_key = os.environ["GRAPHRAG_API_KEY"]
api_base = os.environ["GRAPHRAG_API_BASE"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
修改graphrag数据库部分,指向parquets文件所在目录
INPUT_DIR = "./inputs/parquets"
安装需要的package后直接运行就可以使用
curl -X POST -H "Content-Type: application/json" -d '{"query":"主人公有谁"}' http://127.0.0.1:8000/search/
（注意 URL 末尾的斜杠：路由定义为 "/search/"，POST 到 "/search" 会返回 307 重定向，curl 默认不会跟随。）
- 返回一个字典，包含以下键值对：
  - `"response"`：搜索结果的响应文本。
  - `"llm_calls"`：LLM 调用的次数。
  - `"prompt_tokens"`：使用的 prompt token 数量。
  - `"context_data"`：用于构建 LLM 响应的上下文内容。