暂存mcp重写
This commit is contained in:
488
MCP/mcpserver_stdio_complete.py
Normal file
488
MCP/mcpserver_stdio_complete.py
Normal file
@@ -0,0 +1,488 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import glob
|
||||
from http import HTTPStatus
|
||||
from collections import defaultdict
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import threading
|
||||
|
||||
import dashscope
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
# 1. --- Configuration and initialization ---

# Path configuration: the log file and the .env file live next to this script.
MCP_DIR = os.path.dirname(os.path.abspath(__file__))
LOG_FILE_PATH = os.path.join(MCP_DIR, 'rimworld_rag.log')

# Logging setup: everything is written to LOG_FILE_PATH as UTF-8.
logging.basicConfig(filename=LOG_FILE_PATH, level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    encoding='utf-8')

# Load environment variables from the .env file beside this script.
load_dotenv(os.path.join(MCP_DIR, '.env'))
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
if not DASHSCOPE_API_KEY:
    logging.warning("Missing DASHSCOPE_API_KEY. Semantic reranking will be disabled.")
else:
    dashscope.api_key = DASHSCOPE_API_KEY

# Data roots.
# The defaults assume the game's "Data" and "dll1.6" directories sit three
# levels above this script. Because that layout is installation-specific,
# either root can be overridden with an environment variable of the same name
# (it may also be set in the .env file loaded above). An unset or empty
# variable falls back to the default path.
RIMWORLD_DATA_ROOT = os.path.abspath(
    os.getenv("RIMWORLD_DATA_ROOT") or os.path.join(MCP_DIR, "..", "..", "..", "Data")
)
RIMWORLD_SOURCE_ROOT = os.path.abspath(
    os.getenv("RIMWORLD_SOURCE_ROOT") or os.path.join(MCP_DIR, "..", "..", "..", "dll1.6")
)

logging.info(f"Data Root: {RIMWORLD_DATA_ROOT}")
logging.info(f"Source Root: {RIMWORLD_SOURCE_ROOT}")
|
||||
|
||||
# 2. --- 核心索引类 (内存版) ---
|
||||
|
||||
class SymbolIndex:
    """In-memory symbol index.

    Stands in for the Lucene/MetadataStore index of the C# project: the first
    call to ``initialize`` walks the source and data roots and records a
    ``symbol -> file path`` mapping.

    C# symbols are keyed by file name without extension (the file name is
    assumed to be the class name). XML defs are keyed as ``xml:<defName>``.
    When several files yield the same symbol, the file scanned last wins;
    directories and files are visited in sorted order so that the winner does
    not depend on the file system's listing order.
    """

    # Aliases accepted by ``search_symbols`` for its ``kind`` filter.
    _KIND_ALIASES = {'cs': 'csharp', 'c#': 'csharp', 'def': 'xml'}
    # Quick defName extraction; a regex is used instead of a full XML parse.
    _DEF_NAME_RE = re.compile(r'<defName>(.*?)</defName>')

    def __init__(self):
        self.symbol_map = {}  # symbol_id -> file_path
        self.files_cache = []  # every scanned file path (sources and XML)
        self.is_initialized = False
        self._lock = threading.Lock()

    def initialize(self):
        """Build the index once; subsequent calls return immediately."""
        with self._lock:
            if self.is_initialized:
                return
            logging.info("正在构建内存索引...")
            start_t = time.time()

            self._scan_source_files()
            self._scan_def_files()

            logging.info(f"索引构建完成,耗时 {time.time() - start_t:.2f}s,收录符号 {len(self.symbol_map)} 个")
            self.is_initialized = True

    def _scan_source_files(self):
        """Index every ``.cs``/``.txt`` file under the C# source root."""
        if not os.path.exists(RIMWORLD_SOURCE_ROOT):
            logging.warning(f"Source root not found: {RIMWORLD_SOURCE_ROOT}")
            return
        for root, dirs, files in os.walk(RIMWORLD_SOURCE_ROOT):
            dirs.sort()  # in-place sort makes os.walk descend deterministically
            for file in sorted(files):
                if file.endswith(('.cs', '.txt')):
                    full_path = os.path.join(root, file)
                    # Simplification: the file name is taken as the class name.
                    symbol = os.path.splitext(file)[0]
                    self.symbol_map[symbol] = full_path
                    self.files_cache.append(full_path)

    def _scan_def_files(self):
        """Index every ``.xml`` file (and its defNames) under the data root."""
        if not os.path.exists(RIMWORLD_DATA_ROOT):
            logging.warning(f"Data root not found: {RIMWORLD_DATA_ROOT}")
            return
        for root, dirs, files in os.walk(RIMWORLD_DATA_ROOT):
            dirs.sort()
            for file in sorted(files):
                if file.endswith('.xml'):
                    full_path = os.path.join(root, file)
                    self.files_cache.append(full_path)
                    try:
                        self._scan_xml_defs(full_path)
                    except Exception as e:
                        # One bad file must not abort the whole index build.
                        logging.warning(f"解析 XML 失败 {full_path}: {e}")

    def _scan_xml_defs(self, path):
        """Register every ``<defName>`` found in *path* as ``xml:<name>``."""
        content = read_file_content(path)
        if not content:
            return
        for raw_name in self._DEF_NAME_RE.findall(content):
            name = raw_name.strip()
            if name:
                self.symbol_map[f"xml:{name}"] = path

    def search_symbols(self, keyword: str, kind: Optional[str] = None) -> List[Tuple[str, str]]:
        """Return ``(symbol, path)`` pairs whose symbol contains *keyword*.

        Matching is a case-insensitive substring test. No result cap is
        applied here so that callers can rank the full candidate set.

        Args:
            keyword: Text to look for; surrounding whitespace is ignored.
            kind: ``"csharp"`` (alias ``"cs"``/``"c#"``) keeps only C#
                symbols, ``"xml"`` (alias ``"def"``) keeps only XML defs.
                ``None`` or any other value applies no filter.
        """
        kw_lower = keyword.strip().lower()
        kind_key = kind.strip().lower() if isinstance(kind, str) else kind
        kind_key = self._KIND_ALIASES.get(kind_key, kind_key)

        results = []
        for sym, path in self.symbol_map.items():
            is_xml = sym.startswith('xml:')
            if kind_key == 'csharp' and is_xml:
                continue
            if kind_key == 'xml' and not is_xml:
                continue
            if kw_lower in sym.lower():
                results.append((sym, path))
        return results


# Global index instance shared by the searcher, the graph querier and the tools.
global_index = SymbolIndex()
|
||||
|
||||
# 3. --- 辅助工具函数 ---
|
||||
|
||||
def read_file_content(path: str) -> str:
    """Read a text file leniently and return its content.

    Encodings are tried in order: ``utf-8-sig`` (plain UTF-8, with any leading
    byte-order mark stripped), ``gbk``, then ``latin-1``.

    Args:
        path: Path of the file to read.

    Returns:
        The decoded text, or ``""`` when the file cannot be opened or read.
        Failures are logged rather than raised.
    """
    # 'utf-8-sig' must come before any plain 'utf-8' attempt: plain UTF-8
    # decodes a BOM-prefixed file successfully and leaves '\ufeff' in the text.
    for enc in ('utf-8-sig', 'gbk', 'latin-1'):
        try:
            with open(path, 'r', encoding=enc) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
        except (OSError, TypeError, ValueError) as e:
            # Missing file, permission problem, invalid path, ...: another
            # encoding cannot help, so stop immediately.
            logging.warning("Cannot read %s: %s", path, e)
            return ""
    return ""
|
||||
|
||||
# One XML token: a comment (groups empty) or a tag. For tags, group 1 is "/"
# for a closing tag, group 2 is the tag name, group 4 is "/" when self-closing.
_XML_TOKEN_RE = re.compile(r'<!--.*?-->|<(/?)([\w.:\-]+)((?:\s[^<>]*?)?)(/?)>', re.DOTALL)


def _enclosing_element_start(content: str, pos: int) -> int:
    """Return the offset of the opening tag of the element containing *pos*, or -1."""
    open_starts = []
    for tok in _XML_TOKEN_RE.finditer(content, 0, pos):
        if tok.group(2) is None or tok.group(4):
            continue  # comment or self-closing tag: nesting depth unchanged
        if tok.group(1):
            if open_starts:
                open_starts.pop()
        else:
            open_starts.append(tok.start())
    return open_starts[-1] if open_starts else -1


def _element_end(content: str, start: int) -> int:
    """Return the offset just past the element whose opening tag is at *start*, or -1."""
    depth = 0
    for tok in _XML_TOKEN_RE.finditer(content, start):
        if tok.group(2) is None:
            continue  # comment
        if tok.group(4):
            if depth == 0:
                return tok.end()  # the element itself is self-closing
            continue
        depth += -1 if tok.group(1) else 1
        if depth <= 0:
            return tok.end()
    return -1


def extract_xml_fragment(file_path: str, def_name: str) -> str:
    """Extract the block of one Def from an XML file.

    Two layouts are recognised:

    1. ``<DefType ...><defName>NAME</defName>...</DefType>`` - the opening tag
       may carry any attributes and ``defName`` does not have to be the first
       child.
    2. ``<DefType defName="NAME" ...>...</DefType>``.

    The closing tag is found by tracking nesting depth, so an inner element
    with the same tag name does not end the fragment early.

    Args:
        file_path: XML file to read.
        def_name: The defName to look for (without the ``xml:`` prefix).

    Returns:
        The text of the matching element. When no block can be isolated, the
        file content is returned instead, cut to its first 100 lines if longer.
    """
    content = read_file_content(file_path)
    try:
        escaped = re.escape(def_name)
        start_pos = -1

        # Strategy 1: locate <defName>NAME</defName>, then its parent element.
        anchor = re.search(r'<defName>\s*' + escaped + r'\s*</defName>', content)
        if anchor:
            start_pos = _enclosing_element_start(content, anchor.start())

        if start_pos == -1:
            # Strategy 2: defName given as an attribute of the opening tag.
            attr_match = re.search(r'<[\w.:\-]+\s[^<>]*?\bdefName="' + escaped + r'"', content)
            if attr_match:
                start_pos = attr_match.start()

        if start_pos != -1:
            end_pos = _element_end(content, start_pos)
            if end_pos != -1:
                return content[start_pos:end_pos]
    except Exception as e:
        logging.error(f"Error extracting XML fragment: {e}")

    # Fallback: return the file content, truncated when it is too long.
    lines = content.split('\n')
    if len(lines) > 100:
        return "\n".join(lines[:100]) + "\n... (XML too long, truncated)"
    return content
|
||||
|
||||
def extract_csharp_fragment(file_path: str, symbol: str, max_lines: int = 500) -> str:
    """Return the C# source that defines *symbol*.

    Simplified implementation: the whole file is returned and no attempt is
    made to isolate the class or method.

    Args:
        file_path: Source file to read.
        symbol: Symbol being looked up. Currently unused; kept so the
            signature parallels ``extract_xml_fragment``.
        max_lines: Maximum number of lines to return. Longer files are cut and
            a trailing marker comment is appended. A value of 0 or less
            disables truncation.

    Returns:
        The file content, possibly truncated; ``""`` if the file is unreadable.
    """
    content = read_file_content(file_path)
    lines = content.split('\n')
    if max_lines > 0 and len(lines) > max_lines:
        return "\n".join(lines[:max_lines]) + "\n\n// ... (File too long, truncated)"
    return content
|
||||
|
||||
# 4. --- 功能实现类 ---
|
||||
|
||||
class RoughSearcher:
    """Keyword pre-filter followed by optional semantic reranking.

    Mirrors the RoughSearcher of the C# project: candidates are collected from
    the global symbol index, ordered by a cheap name heuristic, and - when a
    DashScope API key is configured - the best of them are reranked by the
    ``gte-rerank`` model.

    Recognised ``config`` keys:
        kind: Optional symbol-kind filter passed to the index.
        max_results: Maximum number of results returned (default 10).
    """

    RERANK_POOL_SIZE = 50       # how many top candidates are sent to the reranker
    SNIPPET_CHARS = 1000        # leading file characters included per document
    DEFAULT_MAX_RESULTS = 10

    def __init__(self, config: Dict):
        self.config = config
        global_index.initialize()

    def search(self, query: str) -> List[Dict]:
        """Return up to ``max_results`` formatted results for *query*."""
        limit = self._max_results()
        if limit == 0:
            return []

        # 1. Rough filter on symbol names.
        candidates = self._collect_candidates(query, self.config.get('kind'))
        if not candidates:
            return []

        # 2. Pre-sort so the most name-relevant candidates make the pool.
        pool = self._rank_candidates(candidates, query)[:self.RERANK_POOL_SIZE]

        # 3. Semantic rerank when a key is configured. Files are only read in
        #    this branch because the snippets are needed nowhere else.
        if DASHSCOPE_API_KEY:
            reranked = self._semantic_rerank(query, pool, limit)
            if reranked is not None:
                return reranked
            # Degraded mode: keep the heuristic order with a neutral score.
            return [self._format_result(s, p, 0.5) for s, p in pool[:limit]]

        # No-key mode: the heuristic order is the final order.
        return [self._format_result(s, p, 1.0) for s, p in pool[:limit]]

    def _max_results(self) -> int:
        """Read ``max_results`` from the config as a non-negative int."""
        value = self.config.get('max_results')
        if value is None:
            return self.DEFAULT_MAX_RESULTS
        try:
            return max(int(value), 0)
        except (TypeError, ValueError):
            return self.DEFAULT_MAX_RESULTS

    @staticmethod
    def _collect_candidates(query: str, kind) -> List[Tuple[str, str]]:
        """Find symbols matching the whole query, else any of its words."""
        matches = global_index.search_symbols(query, kind)
        if matches:
            return matches
        tokens = query.split()
        if len(tokens) < 2:
            return []
        merged = {}
        for token in tokens:
            for sym, path in global_index.search_symbols(token, kind):
                merged.setdefault(sym, path)
        return list(merged.items())

    @staticmethod
    def _rank_candidates(candidates: List[Tuple[str, str]], query: str) -> List[Tuple[str, str]]:
        """Order candidates: exact match > prefix match > shorter name > alphabetical.

        Names are compared without the ``xml:`` prefix, and also by their last
        dotted segment, so that XML defs and namespaced symbols can rank as
        exact or prefix matches. Each word of a multi-word query counts as a
        search term.
        """
        whole = query.lower()
        terms = [whole] + [tok for tok in whole.split() if tok != whole]
        term_tuple = tuple(terms)

        def sort_key(item):
            sym = item[0]
            bare = (sym[4:] if sym.startswith('xml:') else sym).lower()
            names = (bare, bare.rsplit('.', 1)[-1])
            is_exact = any(name in terms for name in names)
            is_prefix = any(name.startswith(term_tuple) for name in names)
            return (not is_exact, not is_prefix, len(bare), sym)

        return sorted(candidates, key=sort_key)

    @staticmethod
    def _result_score(res):
        """Read the relevance score from one rerank result entry.

        Prefers the ``relevance_score`` key and falls back to ``score``.
        """
        # NOTE(review): ``relevance_score`` is the field name as recalled from
        # the DashScope rerank API - confirm against the API reference.
        return res['relevance_score'] if 'relevance_score' in res else res['score']

    def _semantic_rerank(self, query: str, pool: List[Tuple[str, str]], limit: int) -> Optional[List[Dict]]:
        """Rerank *pool* with DashScope; return ``None`` on any failure."""
        docs = [
            f"Title: {sym}\nContent: {read_file_content(path)[:self.SNIPPET_CHARS]}"
            for sym, path in pool
        ]
        try:
            response = dashscope.TextReRank.call(
                model='gte-rerank',
                query=query,
                documents=docs,
                top_n=limit
            )
            if response.status_code != HTTPStatus.OK:
                logging.error(f"Rerank API error: {response}")
                return None
            ranked = []
            for res in response.output['results']:
                sym, path = pool[res['index']]
                ranked.append(self._format_result(sym, path, self._result_score(res)))
            return ranked
        except Exception as e:
            logging.error(f"Rerank exception: {e}")
            return None

    def _format_result(self, symbol, path, score):
        """Build the result dict returned to MCP clients for one symbol."""
        is_xml = symbol.startswith('xml:')
        return {
            "symbolId": symbol,
            "kind": "xml" if is_xml else "csharp",
            "symbolKind": "definition" if is_xml else "class",
            "path": path,
            "title": symbol,
            "score": float(score),
            "preview": "Use get_item to view content"
        }
|
||||
|
||||
class GraphQuerier:
    """Heuristic stand-in for the C# project's GraphQuerier.

    There is no precomputed graph data here, so relations are derived on
    demand by running regular expressions over the file contents.
    """

    # Namespaces tried when a bare class name is not itself an index key.
    _NAMESPACE_GUESSES = ("RimWorld", "Verse")

    def __init__(self):
        global_index.initialize()

    def query_uses(self, symbol: str, kind: str = "all") -> List[Dict]:
        """Return the downstream dependencies of *symbol* (what it references).

        Args:
            symbol: Symbol id as stored in the global index.
            kind: ``"all"``, ``"xml"`` or ``"csharp"``; filters by target type.

        Returns:
            A de-duplicated list of ``{"target": ..., "kind": ...}`` edges in
            a deterministic order. Empty when the symbol is unknown.
        """
        file_path = global_index.symbol_map.get(symbol)
        if file_path is None:
            return []

        # NOTE: the whole file is analysed, so for an XML file holding several
        # defs the edges of all of them are attributed to *symbol*.
        content = read_file_content(file_path)
        if symbol.startswith('xml:'):
            edges = self._xml_edges(symbol, content)
        else:
            edges = self._csharp_edges(symbol, content)

        return self._filter_edges(self._dedupe_edges(edges), kind)

    def _xml_edges(self, symbol: str, content: str) -> List[Dict]:
        """Collect inheritance, class-binding and def-reference edges from XML."""
        edges = []

        # 1. ParentName, as an attribute or an element (Inherits).
        parents = re.findall(r'ParentName="([^"]+)"', content)
        parents += re.findall(r'<ParentName>([^<]+)</ParentName>', content)
        for parent in parents:
            edges.append({"target": f"xml:{parent}", "kind": "Inherits"})

        # 2. <xxxClass> elements and Class="..." attributes (XmlBindsClass).
        classes = re.findall(r'<[\w]+Class>([\w\.]+)</[\w]+Class>', content)  # e.g. <thingClass>
        classes += re.findall(r'Class="([\w\.]+)"', content)  # e.g. <li Class="...">
        for cls in classes:
            target_cls = cls
            if '.' not in cls:
                # Simple namespace inference for unqualified class names.
                for namespace in self._NAMESPACE_GUESSES:
                    qualified = f"{namespace}.{cls}"
                    if qualified in global_index.symbol_map:
                        target_cls = qualified
                        break
            edges.append({"target": target_cls, "kind": "XmlBindsClass"})

        # 3. Element texts that are known defNames (XmlReferences).
        for ref in re.findall(r'>([\w]+)</', content):
            xml_ref = f"xml:{ref}"
            if xml_ref in global_index.symbol_map and xml_ref != symbol:
                edges.append({"target": xml_ref, "kind": "XmlReferences"})

        return edges

    def _csharp_edges(self, symbol: str, content: str) -> List[Dict]:
        """Collect inheritance and reference edges from C# source text."""
        edges = []

        # 1. Inheritance: "class Name : Base", optionally "class Name<T> : Base".
        inherits = re.search(r'class\s+\w+(?:\s*<[^<>{}]*>)?\s*:\s*([\w\.]+)', content)
        if inherits:
            edges.append({"target": inherits.group(1), "kind": "Inherits"})

        # 2. References (simplified): every capitalised word that resolves to
        #    an indexed symbol. Sorted so the edge order is stable across runs;
        #    iterating the raw set would make truncated results vary.
        for token in sorted(set(re.findall(r'\b[A-Z]\w+\b', content))):
            candidates = [token] + [f"{ns}.{token}" for ns in self._NAMESPACE_GUESSES]
            for candidate in candidates:
                if candidate in global_index.symbol_map and candidate != symbol:
                    edges.append({"target": candidate, "kind": "References"})
                    break

        return edges

    @staticmethod
    def _dedupe_edges(edges: List[Dict]) -> List[Dict]:
        """Drop repeated (target, kind) edges, keeping first occurrences."""
        seen = set()
        unique = []
        for edge in edges:
            key = (edge['target'], edge['kind'])
            if key not in seen:
                seen.add(key)
                unique.append(edge)
        return unique

    @staticmethod
    def _filter_edges(edges: List[Dict], kind: str) -> List[Dict]:
        """Keep edges whose target matches *kind*; unknown kinds match nothing."""
        if kind == 'all':
            return edges
        if kind == 'xml':
            return [e for e in edges if e['target'].startswith('xml:')]
        if kind == 'csharp':
            return [e for e in edges if not e['target'].startswith('xml:')]
        return []

    def query_used_by(self, symbol: str, kind: str = "all", limit: int = 50) -> List[Dict]:
        """Return the upstream dependencies of *symbol* (who references it).

        This is expensive: there is no inverted index, so every indexed file
        of the requested type is read and searched.

        Args:
            symbol: Symbol id; an ``xml:`` prefix is removed before searching.
            kind: ``"all"``, ``"xml"`` or ``"csharp"``; selects which files
                are scanned.
            limit: Stop after this many referencing symbols were found.

        Returns:
            A list of ``{"source": ..., "kind": ..., "distance": 1}`` edges,
            one per referencing symbol.
        """
        is_xml_symbol = symbol.startswith('xml:')
        search_term = symbol[len('xml:'):] if is_xml_symbol else symbol
        if not search_term or limit <= 0:
            return []

        # A namespaced C# id such as "Ns.Type" rarely appears verbatim in
        # source code, so the unqualified name is searched as well.
        terms = [search_term]
        if not is_xml_symbol and '.' in search_term:
            short_name = search_term.rsplit('.', 1)[-1]
            if short_name:
                terms.append(short_name)
        # Whole-identifier match: "Pawn" must not hit "PawnKind".
        matcher = re.compile(
            r'(?<!\w)(?:' + '|'.join(re.escape(t) for t in terms) + r')(?!\w)'
        )

        scan_xml = kind in ['all', 'xml']
        scan_cs = kind in ['all', 'csharp']

        results = []
        seen_sources = set()
        for file_path in global_index.files_cache:
            is_xml_file = file_path.endswith('.xml')
            if is_xml_file and not scan_xml:
                continue
            if not is_xml_file and not scan_cs:
                continue

            content = read_file_content(file_path)
            if not matcher.search(content):
                continue

            # Derive the referencing symbol's id from the file.
            if is_xml_file:
                # Heuristic: the first defName of the file represents it.
                first_def = re.search(r'<defName>(.*?)</defName>', content)
                if not first_def:
                    continue
                source_symbol = f"xml:{first_def.group(1)}"
            else:
                # For C# the file name is the symbol.
                source_symbol = os.path.splitext(os.path.basename(file_path))[0]

            if source_symbol == symbol or source_symbol in seen_sources:
                continue
            seen_sources.add(source_symbol)
            results.append({
                "source": source_symbol,
                "kind": "References" if not is_xml_file else "XmlReferences",
                "distance": 1
            })
            if len(results) >= limit:
                break

        return results
|
||||
|
||||
# 5. --- MCP Server 定义 ---
|
||||
|
||||
mcp = FastMCP(name="rimworld-code-rag")


# Rough search tool: keyword lookup over C# symbols and XML defs, optionally
# reranked semantically.
#   query:       search keywords, e.g. "weapon gun" or "pawn health"
#   kind:        optional filter - "csharp" (alias "cs") or "xml" (alias "def")
#   max_results: maximum number of results to return
# Returns {"results": [...], "totalFound": <number of results returned>}.
# The docstring below is left in its original wording because FastMCP
# publishes a tool's docstring to clients as the tool description.
@mcp.tool()
def rough_search(query: str, kind: str = None, max_results: int = 20) -> Dict[str, Any]:
    """
    粗略搜索工具:使用自然语言查询 RimWorld 代码符号和 XML 定义。

    Args:
        query: 搜索关键词,如 "weapon gun" 或 "pawn health"
        kind: 过滤类型,可选 "csharp" (或 "cs") 或 "xml" (或 "def")
        max_results: 最大返回结果数
    """
    logging.info(f"rough_search: {query} (kind={kind})")

    # A blank query would match every indexed symbol; treat it as "no results".
    query = (query or "").strip()
    if not query:
        return {"results": [], "totalFound": 0}

    # Map the documented aliases onto the two kinds the index understands.
    # Anything else (including None) means "no filter".
    kind_key = kind.strip().lower() if isinstance(kind, str) else ""
    normalized_kind = {
        "csharp": "csharp", "cs": "csharp", "c#": "csharp",
        "xml": "xml", "def": "xml",
    }.get(kind_key)

    searcher = RoughSearcher({"kind": normalized_kind, "max_results": max_results})
    results = searcher.search(query)

    return {
        "results": results,
        "totalFound": len(results)
    }
|
||||
|
||||
# Exact retrieval tool: returns the source of one symbol.
#   symbol:    symbol id, e.g. "RimWorld.Pawn" or "xml:Gun_Revolver"
#   max_lines: maximum number of lines returned; 0 means no extra limit here
# Returns a dict with symbolId/path/language/sourceCode/totalLines, or
# {"error": ...} when the symbol is unknown or its file cannot be read.
# NOTE: the extract_* helpers may already have truncated long files, so
# totalLines counts the lines of the extracted text, not of the file.
# The docstring below is left in its original wording because FastMCP
# publishes a tool's docstring to clients as the tool description.
@mcp.tool()
def get_item(symbol: str, max_lines: int = 0) -> Dict[str, Any]:
    """
    精确检索工具:获取特定符号的完整源代码或 Definition。

    Args:
        symbol: 符号ID,例如 "RimWorld.Pawn" 或 "xml:Gun_Revolver"
        max_lines: 最大返回行数,0 表示不限制
    """
    logging.info(f"get_item: {symbol}")
    global_index.initialize()

    file_path = global_index.symbol_map.get(symbol)
    if file_path is None:
        return {"error": f"未找到符号: {symbol},请先使用 rough_search 确认名称。"}

    is_xml = symbol.startswith("xml:")
    if is_xml:
        # Strip only the leading prefix; str.replace would also remove any
        # later "xml:" occurrence inside the name.
        source_code = extract_xml_fragment(file_path, symbol[len("xml:"):])
    else:
        source_code = extract_csharp_fragment(file_path, symbol)

    if not source_code:
        # The helpers return "" when the file cannot be read; report that
        # instead of presenting an empty result as a success.
        logging.warning(f"get_item: empty content for {symbol} ({file_path})")
        return {"error": f"无法读取符号 {symbol} 的源文件: {file_path}"}

    # Line limit.
    lines = source_code.split('\n')
    if max_lines > 0 and len(lines) > max_lines:
        source_code = "\n".join(lines[:max_lines]) + f"\n... (剩余 {len(lines)-max_lines} 行已截断)"

    return {
        "symbolId": symbol,
        "path": file_path,
        "language": "xml" if is_xml else "csharp",
        "sourceCode": source_code,
        "totalLines": len(lines)
    }
|
||||
|
||||
# Dependency analysis tool: what does the symbol reference (downstream)?
#   symbol:      symbol id
#   kind:        target type filter - "csharp", "xml" or "all"
#   max_results: maximum number of edges returned
# Returns {"sourceSymbol", "edges", "total"}; "total" counts all edges found
# before truncation.
# The docstring below is left in its original wording because FastMCP
# publishes a tool's docstring to clients as the tool description.
@mcp.tool()
def get_uses(symbol: str, kind: str = "all", max_results: int = 50) -> Dict[str, Any]:
    """
    依赖分析:查找该符号引用了什么(下游依赖)。

    Args:
        symbol: 符号ID
        kind: 过滤目标类型 ("csharp", "xml", "all")
    """
    logging.info(f"get_uses: {symbol}")

    # A missing or blank kind means "all"; accept the same aliases that the
    # rough_search tool documents.
    kind_key = kind.strip().lower() if isinstance(kind, str) else ""
    kind_key = kind_key or "all"
    normalized_kind = {"cs": "csharp", "c#": "csharp", "def": "xml"}.get(kind_key, kind_key)

    querier = GraphQuerier()
    edges = querier.query_uses(symbol, normalized_kind)

    # A negative limit would slice from the tail of the list; clamp it to 0.
    limit = len(edges) if max_results is None else max(int(max_results), 0)

    return {
        "sourceSymbol": symbol,
        "edges": edges[:limit],
        "total": len(edges)
    }
|
||||
|
||||
# Reverse dependency tool: who references the symbol (upstream)?
# There is no precomputed index, so this scans files and can be slow.
#   symbol:      symbol id
#   kind:        source type filter - "csharp", "xml" or "all"
#   max_results: maximum number of edges returned
# Returns {"targetSymbol", "edges", "total"}. query_used_by stops scanning
# after a fixed number of hits, so "total" is a lower bound on the real count.
# The docstring below is left in its original wording because FastMCP
# publishes a tool's docstring to clients as the tool description.
@mcp.tool()
def get_used_by(symbol: str, kind: str = "all", max_results: int = 20) -> Dict[str, Any]:
    """
    反向依赖分析:查找谁使用了该符号(上游依赖)。
    注意:由于没有预计算索引,此操作涉及文件扫描,可能较慢。

    Args:
        symbol: 符号ID
        kind: 过滤源类型 ("csharp", "xml", "all")
    """
    logging.info(f"get_used_by: {symbol}")

    # A missing or blank kind means "all"; accept the same aliases that the
    # rough_search tool documents.
    kind_key = kind.strip().lower() if isinstance(kind, str) else ""
    kind_key = kind_key or "all"
    normalized_kind = {"cs": "csharp", "c#": "csharp", "def": "xml"}.get(kind_key, kind_key)

    querier = GraphQuerier()
    edges = querier.query_used_by(symbol, normalized_kind)

    # A negative limit would slice from the tail of the list; clamp it to 0.
    limit = len(edges) if max_results is None else max(int(max_results), 0)

    return {
        "targetSymbol": symbol,
        "edges": edges[:limit],
        "total": len(edges)
    }
|
||||
|
||||
def _serve() -> None:
    """Log the startup, warm the symbol index, then run the MCP server."""
    logging.info("RimWorld Code RAG Python Server Starting...")
    # Build the index up front so the first tool call does not pay for it.
    global_index.initialize()
    mcp.run()


if __name__ == "__main__":
    _serve()
|
||||
Reference in New Issue
Block a user