171 lines
7.0 KiB
Python
171 lines
7.0 KiB
Python
# -*- coding: utf-8 -*-
|
||
import os
|
||
import sys
|
||
import logging
|
||
import json
|
||
import re
|
||
from http import HTTPStatus
|
||
import dashscope
|
||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
import numpy as np
|
||
from dotenv import load_dotenv
|
||
import threading
|
||
|
||
# 1. --- Make the bundled MCP SDK importable ---
# Directory that contains this script; other file paths below are built from it.
MCP_DIR = os.path.dirname(os.path.abspath(__file__))
# Source tree of the MCP Python SDK, expected at <MCP_DIR>/python-sdk/src.
SDK_PATH = os.path.join(MCP_DIR, 'python-sdk', 'src')

# Prepend the SDK (only once) so it is searched before any other entry on
# sys.path; slice assignment at index 0 is equivalent to sys.path.insert(0, ...).
if SDK_PATH not in sys.path:
    sys.path[:0] = [SDK_PATH]

from mcp.server.fastmcp import FastMCP  # noqa: E402 -- must follow the sys.path setup
|
||
|
||
# 2. --- Logging, environment and knowledge-base configuration ---

# All log records go to a UTF-8 log file next to this script.
LOG_FILE_PATH = os.path.join(MCP_DIR, 'mcpserver_parallel.log')
logging.basicConfig(
    filename=LOG_FILE_PATH,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    encoding='utf-8',
)

# Load the .env file that sits next to this script, then read credentials
# from the environment.
env_path = os.path.join(MCP_DIR, '.env')
load_dotenv(dotenv_path=env_path)

DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
DASHSCOPE_APP_ID = os.getenv("DASHSCOPE_APP_ID")
dashscope.api_key = DASHSCOPE_API_KEY

# Both values are required: abort start-up (logging the reason) if either is missing.
if not DASHSCOPE_API_KEY or not DASHSCOPE_APP_ID:
    error_msg = "错误:请确保 MCP/.env 文件中已正确配置 DASHSCOPE_API_KEY 和 DASHSCOPE_APP_ID。"
    logging.error(error_msg)
    sys.exit(error_msg)

# Directories searched by the local RAG pipeline. They can be overridden with
# RIMWORLD_KNOWLEDGE_BASE_PATHS (entries separated by os.pathsep, i.e. ';' on
# Windows and ':' elsewhere; may also be set in .env). Without the override the
# original default Steam install location is used.
_DEFAULT_KNOWLEDGE_BASE_PATH = r"C:\Steam\steamapps\common\RimWorld\Data"
KNOWLEDGE_BASE_PATHS = [
    entry.strip()
    for entry in os.getenv("RIMWORLD_KNOWLEDGE_BASE_PATHS", "").split(os.pathsep)
    if entry.strip()
] or [_DEFAULT_KNOWLEDGE_BASE_PATH]

for _kb_path in KNOWLEDGE_BASE_PATHS:
    if not os.path.isdir(_kb_path):
        # os.walk silently yields nothing for a missing directory, so surface the
        # misconfiguration in the log instead of just returning no local results.
        logging.warning(f"知识库目录不存在: {_kb_path}")
|
||
|
||
# 3. --- Helper functions ---

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(3))
def get_embedding(text: str):
    """Return the DashScope ``text-embedding-v2`` vector for *text*.

    Returns None when the API answers with a non-OK HTTP status. Exceptions
    raised by the call itself are retried by the tenacity decorator: at most
    3 attempts, with randomized exponential back-off between 1 and 60 seconds.

    NOTE(review): this helper appears unused elsewhere in this file.
    """
    reply = dashscope.TextEmbedding.call(model='text-embedding-v2', input=text)
    if reply.status_code != HTTPStatus.OK:
        return None
    first_item = reply.output['embeddings'][0]
    return first_item['embedding']
|
||
|
||
def find_files_with_keyword(base_paths: list[str], keywords: list[str]) -> list[str]:
    """Return paths of files whose *file name* contains any of *keywords*.

    Every directory in *base_paths* is walked recursively with ``os.walk``
    (a missing directory simply yields nothing). Matching is a
    case-insensitive substring test on the file name only; directory names
    are not considered.

    Empty or blank keywords are ignored, because an empty substring would
    match every file in the tree. The result is de-duplicated and sorted, so
    it is deterministic across runs (a plain ``list(set)`` is not, because
    string hashing is randomized per process).

    Args:
        base_paths: Root directories to search.
        keywords: Substrings to look for in file names.

    Returns:
        Sorted list of matching file paths (``os.path.join(root, name)``).
    """
    # Lower-case once instead of once per (file, keyword) pair.
    needles = [keyword.lower() for keyword in keywords if keyword and keyword.strip()]
    if not needles:
        return []

    found_files = set()
    for base_path in base_paths:
        for root, _, files in os.walk(base_path):
            for file in files:
                lowered_name = file.lower()
                if any(needle in lowered_name for needle in needles):
                    found_files.add(os.path.join(root, file))
    return sorted(found_files)
|
||
|
||
def rerank_files(question, file_paths, top_n=5):
    """Order *file_paths* by relevance to *question* using DashScope ``gte-rerank``.

    The first 2000 characters of each file (read as UTF-8) are sent as the
    rerank documents. Files that cannot be opened or decoded are skipped and
    logged.

    Args:
        question: The user's query.
        file_paths: Candidate file paths.
        top_n: Maximum number of paths to return.

    Returns:
        Up to *top_n* paths in the order ranked by the service. If the rerank
        call fails (non-OK status or an exception), the readable paths are
        returned in their input order, truncated to *top_n*. Returns [] when
        no file could be read.
    """
    documents, valid_paths = [], []
    for path in file_paths:
        try:
            with open(path, 'r', encoding='utf-8') as f:
                documents.append(f.read(2000))
        except (OSError, UnicodeDecodeError) as e:
            # Best effort: an unreadable file must not abort the whole query.
            logging.warning(f"跳过无法读取的文件 {path}: {e}")
            continue
        valid_paths.append(path)

    if not documents:
        return []

    fallback = valid_paths[:top_n]
    try:
        response = dashscope.TextReRank.call(
            model='gte-rerank',
            query=question,
            documents=documents,
            # Never ask for more results than there are documents.
            top_n=min(top_n, len(documents)),
        )
    except Exception as e:
        # Network/SDK failure: degrade to input order instead of losing all
        # local context (the caller would otherwise report a whole-pipeline error).
        logging.error(f"重排序请求异常: {e}", exc_info=True)
        return fallback

    if response.status_code != HTTPStatus.OK:
        logging.error(f"重排序请求失败: {response.message}")
        return fallback
    return [valid_paths[r['index']] for r in response.output['results']]
|
||
|
||
def _find_braced_block(lines, found_line_index):
    """Return the brace-delimited type declaration around ``lines[found_line_index]``.

    Scans upward from *found_line_index* for the nearest line containing
    'class ', 'struct ' or 'enum '. It then scans downward from that line,
    counting '{' and '}', until the opened braces balance again. This is a
    plain text heuristic: braces inside strings or comments are counted too.
    Returns "" when no declaration line or no balanced end is found.
    """
    start_index = next(
        (i for i in range(found_line_index, -1, -1)
         if 'class ' in lines[i] or 'struct ' in lines[i] or 'enum ' in lines[i]),
        -1,
    )
    if start_index == -1:
        return ""

    depth = 0
    seen_open = False  # True once any '{' has appeared since start_index
    for i in range(start_index, len(lines)):
        line = lines[i]
        depth += line.count('{') - line.count('}')
        seen_open = seen_open or '{' in line
        if seen_open and depth == 0:
            return "\n".join(lines[start_index:i + 1])
    return ""


def extract_full_code_block(file_path, keyword):
    """Extract the code/data that mentions *keyword* from *file_path*.

    The keyword is searched as a whole word, case-insensitively. The file
    extension (also compared case-insensitively) decides the result:

    * ``.cs`` / ``.txt``: the enclosing class/struct/enum declaration,
      delimited by balanced braces.
    * ``.xml``: the whole file.
    * anything else: "".

    Returns "" when the keyword does not occur, or when nothing can be
    extracted. Errors are logged and also yield "" (best effort).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')

        pattern = re.compile(r'\b' + re.escape(keyword) + r'\b', re.IGNORECASE)
        found_line_index = next(
            (i for i, line in enumerate(lines) if pattern.search(line)), -1
        )
        if found_line_index == -1:
            return ""

        lowered_path = file_path.lower()
        if lowered_path.endswith(('.cs', '.txt')):
            return _find_braced_block(lines, found_line_index)
        if lowered_path.endswith('.xml'):
            return "\n".join(lines)
        # Unsupported file type: return a str consistently (previously None).
        return ""
    except Exception as e:
        logging.error(f"提取代码时出错 {file_path}: {e}")
        return ""
|
||
|
||
# Punctuation stripped from the fallback keyword (ASCII and full-width forms).
_KEYWORD_EDGE_PUNCTUATION = "?？!！.。,，;；:：'\"“”‘’()（）[]【】"


def find_keywords_in_question(question: str) -> list[str]:
    """Extract file-name search keywords from *question*.

    Primary rule: every identifier-like word that starts with an upper-case
    ASCII letter (e.g. ``ThingDef``, ``Pawn``). Duplicates are removed while
    keeping first-occurrence order, so the result is deterministic.

    ``re.ASCII`` is required so ``\\b`` treats CJK characters as non-word
    characters. Without it, "请问ThingDef是什么" yields no match, because
    Unicode ``\\b`` sees no boundary between 问 and T.

    Fallback: the last whitespace-separated token, stripped of surrounding
    punctuation (so "wood?" becomes "wood"). Returns [] when nothing usable
    remains; an empty keyword would otherwise match every file name.
    """
    capitalized = re.findall(r'\b[A-Z][a-zA-Z0-9_]*\b', question, re.ASCII)
    if capitalized:
        return list(dict.fromkeys(capitalized))

    tokens = question.split()
    if not tokens:
        return []
    last_token = tokens[-1].strip(_KEYWORD_EDGE_PUNCTUATION)
    return [last_token] if last_token else []
|
||
|
||
# 4. --- Tasks run in parallel threads ---

def get_cloud_response(question, result_container):
    """Ask the DashScope cloud application about *question*.

    Used as a ``threading.Thread`` target by ``get_context``. It stores its
    outcome under ``result_container['cloud']``:

    * the application's reply text on success;
    * otherwise a human-readable failure or exception message, which is also
      logged.

    Any ``Exception`` raised by the call is caught here, so the thread does
    not die with an unlogged traceback.
    """
    logging.info("开始请求云端智能体...")
    try:
        reply = dashscope.Application.call(app_id=DASHSCOPE_APP_ID, api_key=DASHSCOPE_API_KEY, prompt=question)
        if reply.status_code != HTTPStatus.OK:
            failure_text = f"云端智能体请求失败: {reply.message}"
            result_container['cloud'] = failure_text
            logging.error(failure_text)
            return
        result_container['cloud'] = reply.output.text
        logging.info("成功获取云端智能体回复。")
    except Exception as exc:
        result_container['cloud'] = f"云端智能体请求异常: {exc}"
        logging.error(f"云端智能体请求异常: {exc}", exc_info=True)
|
||
|
||
def get_local_rag_response(question, result_container):
    """Build local code/data context for *question* from the knowledge base.

    Used as a ``threading.Thread`` target by ``get_context``. The pipeline:

    1. Extract keywords from the question.
    2. Find files under ``KNOWLEDGE_BASE_PATHS`` whose names contain a keyword.
    3. Rerank those files against the question.
    4. For each ranked file, extract the block for the first keyword that
       occurs in the file name and yields non-empty code.

    The outcome (extracted blocks, or a status or error message) is stored in
    ``result_container['local']``. Exceptions are caught and logged.
    """
    logging.info("开始本地RAG流程...")
    try:
        keywords = find_keywords_in_question(question)
        found_files = find_files_with_keyword(KNOWLEDGE_BASE_PATHS, keywords)
        if not found_files:
            result_container['local'] = "本地RAG:未找到相关文件。"
            return

        reranked_files = rerank_files(question, found_files)

        context_blocks = []
        for file_path in reranked_files:
            base_name = os.path.basename(file_path)
            # Match on the file name only, exactly like find_files_with_keyword.
            # Matching the full path would let a directory name pick the keyword
            # used for extraction (and so the wrong code block).
            lowered_name = base_name.lower()
            for keyword in keywords:
                if not keyword or keyword.lower() not in lowered_name:
                    continue
                code = extract_full_code_block(file_path, keyword)
                if code:
                    header = f"\n--- 本地文件: {base_name} ---\n"
                    context_blocks.append(header + code)
                    break  # one block per file is enough

        result_container['local'] = "\n".join(context_blocks) if context_blocks else "本地RAG:找到文件但未能提取代码。"
        logging.info("本地RAG流程完成。")
    except Exception as e:
        result_container['local'] = f"本地RAG流程异常: {e}"
        logging.error(f"本地RAG流程异常: {e}", exc_info=True)
|
||
|
||
# 5. --- MCP server ---
# The single FastMCP application for this process; the tool defined below is
# registered on it through the ``@mcp.tool()`` decorator.
_SERVER_NAME = "rimworld-knowledge-base"
mcp = FastMCP(name=_SERVER_NAME)
|
||
|
||
# MCP tool: fetch the cloud analysis and the local code context concurrently,
# then return both as one text. The Chinese docstring below is kept verbatim
# on purpose: FastMCP uses a tool function's docstring as the tool description
# sent to clients, so it is runtime text rather than a plain comment.
@mcp.tool()
def get_context(question: str) -> str:
    """并行获取云端分析和本地代码,并组合输出。"""
    # Shared by both workers; each writes only its own key ('cloud' / 'local').
    shared_results = {}

    workers = [
        threading.Thread(target=task, args=(question, shared_results))
        for task in (get_cloud_response, get_local_rag_response)
    ]
    for worker in workers:
        worker.start()
    # Block until both the cloud request and the local RAG pipeline finish.
    for worker in workers:
        worker.join()

    # Fall back to placeholder messages if a worker stored nothing.
    cloud_part = shared_results.get('cloud', "未能获取云端回复。")
    local_part = shared_results.get('local', "未能获取本地代码。")

    return (
        f"--- 云端智能体分析 ---\n\n{cloud_part}\n\n"
        f"====================\n\n"
        f"--- 本地完整代码参考 ---\n{local_part}"
    )
|
||
|
||
# 6. --- Start the server ---
if __name__ == "__main__":
    logging.info("启动并行模式MCP服务器...")
    try:
        mcp.run()
    except Exception:
        # Record the traceback in the log file before propagating the error.
        logging.exception("并行模式MCP服务器异常退出。")
        raise
    finally:
        # Log the shutdown even when run() exits via KeyboardInterrupt or an error.
        logging.info("并行模式MCP服务器已停止。")