为GraphRAG快速构建http api

15 天前(已编辑)
/ , , ,
35

阅读此文章之前，你可能需要先阅读以下文章，以便更好地理解上下文。

为GraphRAG快速构建http api

之前使用 GraphRAG 生成了自己的知识库，但默认只能通过命令行查询，官方也没有提供 HTTP API，所以我用 FastAPI 写了一个 API 封装（wrapper）。

https://microsoft.github.io/graphrag/posts/query/notebooks/global_search_nb/

参考官方提供的 notebook（Global Search 示例），修改成如下代码：

import json
import os

import pandas as pd
import tiktoken
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from graphrag.query.indexer_adapters import read_indexer_entities, read_indexer_reports
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch

# Request body accepted by the search endpoint: a single "query" string field.
class QueryRequest(BaseModel):
    query: str


# The ASGI application; routes are registered on it further down.
app = FastAPI()

# Load data and initialize components.
#
# Directory holding the GraphRAG indexer's parquet output. It can be overridden
# with the GRAPHRAG_INPUT_DIR environment variable; when unset, the original
# relative path is used.
INPUT_DIR = os.environ.get("GRAPHRAG_INPUT_DIR", "./inputs/parquets")
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
# Community level passed to both graphrag indexer adapters below.
COMMUNITY_LEVEL = 2


def _read_table(table_name):
    """Read ``<INPUT_DIR>/<table_name>.parquet`` into a pandas DataFrame.

    Raises whatever ``pd.read_parquet`` raises (e.g. ``FileNotFoundError``)
    when the file is missing or unreadable.
    """
    return pd.read_parquet(os.path.join(INPUT_DIR, f"{table_name}.parquet"))


entity_df = _read_table(ENTITY_TABLE)
report_df = _read_table(COMMUNITY_REPORT_TABLE)
entity_embedding_df = _read_table(ENTITY_EMBEDDING_TABLE)

reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)

# Tokenizer shared by the context builder and the search engine.
token_encoder = tiktoken.get_encoding("cl100k_base")

# Assembles the community-report context for global search from the
# reports and entities loaded above.
context_builder = GlobalCommunityContext(
    token_encoder=token_encoder,
    entities=entities,
    community_reports=reports,
)

# LLM connection settings are read from the environment, in this order;
# a missing variable raises KeyError at import time.
api_key, api_base, llm_model = (
    os.environ[var_name]
    for var_name in ("GRAPHRAG_API_KEY", "GRAPHRAG_API_BASE", "GRAPHRAG_LLM_MODEL")
)

# OpenAI-compatible chat client used for both the map and reduce steps.
llm = ChatOpenAI(
    model=llm_model,
    api_key=api_key,
    api_base=api_base,
    api_type=OpenaiApiType.OpenAI,
    max_retries=20,
)

# Token budget for the report context. The context builder and the search
# engine are both given this same value.
_MAX_DATA_TOKENS = 12_000

# Options forwarded to the context builder when building each search context.
context_builder_params = dict(
    # presumably False = use full report text rather than short summaries
    # (per graphrag's global-search notebook) — verify for your version
    use_community_summary=False,
    shuffle_data=True,
    include_community_rank=True,
    min_community_rank=0,
    community_rank_name="rank",
    include_community_weight=True,
    community_weight_name="occurrence weight",
    normalize_community_weight=True,
    max_tokens=_MAX_DATA_TOKENS,
    context_name="Reports",
)

# Generation settings for the map step; JSON output is requested.
map_llm_params = dict(
    max_tokens=1000,
    temperature=0.0,
    response_format={"type": "json_object"},
)

# Generation settings for the reduce step (free-form answer).
reduce_llm_params = dict(
    max_tokens=2000,
    temperature=0.0,
)

search_engine = GlobalSearch(
    llm=llm,
    token_encoder=token_encoder,
    context_builder=context_builder,
    context_builder_params=context_builder_params,
    max_data_tokens=_MAX_DATA_TOKENS,
    map_llm_params=map_llm_params,
    reduce_llm_params=reduce_llm_params,
    json_mode=True,
    allow_general_knowledge=False,
    concurrent_coroutines=32,
    response_type="multiple paragraphs",
)

def _context_to_records(context_data):
    """Return the "reports" part of a search context in a JSON-safe form.

    ``context_data`` is presumably a dict of pandas DataFrames keyed by the
    lower-cased context name ("Reports" -> "reports") — verify against the
    installed graphrag version. A DataFrame is converted to a list of
    row dicts; anything else under that key is returned unchanged, and a
    missing key (or a non-dict context) yields ``None``.
    """
    reports = context_data.get("reports") if isinstance(context_data, dict) else None
    if isinstance(reports, pd.DataFrame):
        # Round-trip through pandas' JSON writer: it turns numpy scalars into
        # plain Python numbers and NaN into null. FastAPI's encoder does not
        # handle numpy scalars, and Starlette's JSONResponse rejects NaN.
        return json.loads(reports.to_json(orient="records", force_ascii=False))
    return reports


@app.post("/search/")
async def perform_search(request: QueryRequest):
    """Run a GraphRAG global search for ``request.query``.

    Returns a JSON object with the generated answer (``response``), the
    number of LLM calls (``llm_calls``), the prompt tokens consumed
    (``prompt_tokens``) and the community reports used as context
    (``context_data``, a list of row objects).

    Any failure is reported to the client as HTTP 500 with the exception
    text as ``detail``.
    """
    try:
        result = await search_engine.asearch(request.query)
        return {
            "response": result.response,
            "llm_calls": result.llm_calls,
            "prompt_tokens": result.prompt_tokens,
            "context_data": _context_to_records(result.context_data),
        }
    except Exception as e:
        # Keep the original exception as the cause so it is not lost for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e

# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, port=8000, host="0.0.0.0")

主要需要根据自己的需求修改 LLM 部分（通过环境变量配置 API Key、API 地址和模型名称）：

api_key = os.environ["GRAPHRAG_API_KEY"]
api_base = os.environ["GRAPHRAG_API_BASE"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]

修改 GraphRAG 的数据部分，使其指向 parquet 文件所在的目录：

INPUT_DIR = "./inputs/parquets"

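安装所需的依赖包（fastapi、uvicorn、pandas、tiktoken、graphrag 等）后直接运行即可。调用示例如下（注意路径末尾的斜杠：请求 /search 会被 FastAPI 以 307 重定向到 /search/，而 curl 默认不跟随重定向）：

curl -X POST -H "Content-Type: application/json" -d '{"query":"主人公有谁"}' http://127.0.0.1:8000/search/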
  • 返回一个 JSON 对象，包含以下字段：
    • "response"：搜索结果的回答文本。
    • "llm_calls"：LLM 调用次数。
    • "prompt_tokens"：提示词（prompt）消耗的 token 数量。
    • "context_data"：用于生成回答的上下文内容（社区报告）。
  • Loading...
  • Loading...
  • Loading...
  • Loading...
  • Loading...