mcpv3
This commit is contained in:
@@ -87,23 +87,121 @@ def get_embedding(text: str):
|
||||
logging.error(f"调用向量API时出错: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def find_most_similar_file(question_embedding, file_embeddings):
|
||||
"""在文件向量中找到与问题向量最相似的一个"""
|
||||
if not question_embedding or not file_embeddings:
|
||||
return None
|
||||
|
||||
# 将文件嵌入列表转换为NumPy数组
|
||||
file_vectors = np.array([emb['embedding'] for emb in file_embeddings])
|
||||
question_vector = np.array(question_embedding).reshape(1, -1)
|
||||
|
||||
# 计算余弦相似度
|
||||
similarities = cosine_similarity(question_vector, file_vectors)[0]
|
||||
|
||||
# 找到最相似的文件的索引
|
||||
most_similar_index = np.argmax(similarities)
|
||||
|
||||
# 返回最相似的文件路径
|
||||
return file_embeddings[most_similar_index]['path']
|
||||
def find_most_similar_files(question_embedding, file_embeddings, top_n=3, min_similarity=0.5):
|
||||
"""在文件向量中找到与问题向量最相似的 top_n 个文件。"""
|
||||
if not question_embedding or not file_embeddings:
|
||||
return []
|
||||
|
||||
file_vectors = np.array([emb['embedding'] for emb in file_embeddings])
|
||||
question_vector = np.array(question_embedding).reshape(1, -1)
|
||||
|
||||
similarities = cosine_similarity(question_vector, file_vectors)[0]
|
||||
|
||||
# 获取排序后的索引
|
||||
sorted_indices = np.argsort(similarities)[::-1]
|
||||
|
||||
# 筛选出最相关的结果
|
||||
results = []
|
||||
for i in sorted_indices:
|
||||
similarity_score = similarities[i]
|
||||
if similarity_score >= min_similarity and len(results) < top_n:
|
||||
results.append({
|
||||
'path': file_embeddings[i]['path'],
|
||||
'similarity': similarity_score
|
||||
})
|
||||
else:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def extract_relevant_code(file_path, keyword):
|
||||
"""从文件中智能提取包含关键词的完整代码块 (C#类 或 XML Def)。"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
lines = content.split('\n')
|
||||
keyword_lower = keyword.lower()
|
||||
|
||||
found_line_index = -1
|
||||
for i, line in enumerate(lines):
|
||||
if keyword_lower in line.lower():
|
||||
found_line_index = i
|
||||
break
|
||||
|
||||
if found_line_index == -1:
|
||||
return ""
|
||||
|
||||
# 根据文件类型选择提取策略
|
||||
if file_path.endswith(('.cs', '.txt')):
|
||||
# C# 提取策略:寻找完整的类
|
||||
return extract_csharp_class(lines, found_line_index)
|
||||
elif file_path.endswith('.xml'):
|
||||
# XML 提取策略:寻找完整的 Def
|
||||
return extract_xml_def(lines, found_line_index)
|
||||
else:
|
||||
return "" # 不支持的文件类型
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"提取代码时出错 {file_path}: {e}")
|
||||
return f"# Error reading file: {e}"
|
||||
|
||||
def extract_csharp_class(lines, start_index):
|
||||
"""从C#代码行中提取完整的类定义。"""
|
||||
# 向上找到 class 声明
|
||||
class_start_index = -1
|
||||
brace_level_at_class_start = -1
|
||||
for i in range(start_index, -1, -1):
|
||||
line = lines[i]
|
||||
if 'class ' in line:
|
||||
class_start_index = i
|
||||
brace_level_at_class_start = line.count('{') - line.count('}')
|
||||
break
|
||||
|
||||
if class_start_index == -1: return "" # 没找到类
|
||||
|
||||
# 从 class 声明开始,向下找到匹配的 '}'
|
||||
brace_count = brace_level_at_class_start
|
||||
class_end_index = -1
|
||||
for i in range(class_start_index + 1, len(lines)):
|
||||
line = lines[i]
|
||||
brace_count += line.count('{')
|
||||
brace_count -= line.count('}')
|
||||
if brace_count <= 0: # 找到匹配的闭合括号
|
||||
class_end_index = i
|
||||
break
|
||||
|
||||
if class_end_index != -1:
|
||||
return "\n".join(lines[class_start_index:class_end_index+1])
|
||||
return "" # 未找到完整的类块
|
||||
|
||||
def extract_xml_def(lines, start_index):
|
||||
"""从XML行中提取完整的Def块。"""
|
||||
import re
|
||||
# 向上找到 <DefName> 或 <defName>
|
||||
def_start_index = -1
|
||||
def_tag = ""
|
||||
for i in range(start_index, -1, -1):
|
||||
line = lines[i].strip()
|
||||
match = re.match(r'<(\w+)\s+.*>', line) or re.match(r'<(\w+)>', line)
|
||||
if match and ('Def' in match.group(1) or 'def' in match.group(1)):
|
||||
# 这是一个简化的判断,实际中可能需要更复杂的逻辑
|
||||
def_start_index = i
|
||||
def_tag = match.group(1)
|
||||
break
|
||||
|
||||
if def_start_index == -1: return ""
|
||||
|
||||
# 向下找到匹配的 </DefName>
|
||||
def_end_index = -1
|
||||
for i in range(def_start_index + 1, len(lines)):
|
||||
if f'</{def_tag}>' in lines[i]:
|
||||
def_end_index = i
|
||||
break
|
||||
|
||||
if def_end_index != -1:
|
||||
return "\n".join(lines[def_start_index:def_end_index+1])
|
||||
return ""
|
||||
|
||||
# 5. --- 核心功能函数 ---
|
||||
def find_files_with_keyword(roots, keyword, extensions=['.xml', '.cs', '.txt']):
|
||||
@@ -183,8 +281,8 @@ mcp = FastMCP(
|
||||
@mcp.tool()
|
||||
def get_context(question: str) -> str:
|
||||
"""
|
||||
根据问题中的关键词和向量相似度,在RimWorld知识库中搜索最相关的XML或C#文件。
|
||||
返回最匹配的文件路径。
|
||||
根据问题中的关键词和向量相似度,在RimWorld知识库中搜索最相关的多个代码片段,
|
||||
并将其整合后返回。
|
||||
"""
|
||||
logging.info(f"收到问题: {question}")
|
||||
keyword = find_keyword_in_question(question)
|
||||
@@ -196,9 +294,9 @@ def get_context(question: str) -> str:
|
||||
|
||||
# 1. 检查缓存
|
||||
if keyword in knowledge_cache:
|
||||
cached_path = knowledge_cache[keyword]
|
||||
logging.info(f"缓存命中: 关键词 '{keyword}' -> {cached_path}")
|
||||
return f"根据知识库缓存,与 '{keyword}' 最相关的定义文件是:\n{cached_path}"
|
||||
cached_result = knowledge_cache[keyword]
|
||||
logging.info(f"缓存命中: 关键词 '{keyword}'")
|
||||
return cached_result
|
||||
|
||||
logging.info(f"缓存未命中,开始实时搜索: {keyword}")
|
||||
|
||||
@@ -221,7 +319,6 @@ def get_context(question: str) -> str:
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
# v4模型支持更长的输入
|
||||
file_embedding = get_embedding(content[:8000])
|
||||
if file_embedding:
|
||||
file_embeddings.append({'path': file_path, 'embedding': file_embedding})
|
||||
@@ -231,18 +328,51 @@ def get_context(question: str) -> str:
|
||||
if not file_embeddings:
|
||||
return "无法为任何候选文件生成向量。"
|
||||
|
||||
# 找到最相似的文件
|
||||
best_match_path = find_most_similar_file(question_embedding, file_embeddings)
|
||||
# 找到最相似的多个文件
|
||||
best_matches = find_most_similar_files(question_embedding, file_embeddings, top_n=3)
|
||||
|
||||
if not best_match_path:
|
||||
return "计算向量相似度失败。"
|
||||
if not best_matches:
|
||||
return "计算向量相似度失败或没有找到足够相似的文件。"
|
||||
|
||||
# 4. 更新缓存并返回结果
|
||||
logging.info(f"向量搜索完成。最匹配的文件是: {best_match_path}")
|
||||
knowledge_cache[keyword] = best_match_path
|
||||
# 4. 提取代码并格式化输出
|
||||
output_parts = [f"根据向量相似度分析,与 '{keyword}' 最相关的代码定义如下:\n"]
|
||||
|
||||
for match in best_matches:
|
||||
file_path = match['path']
|
||||
similarity = match['similarity']
|
||||
|
||||
# 智能提取代码块
|
||||
code_block = extract_relevant_code(file_path, keyword)
|
||||
|
||||
# 如果提取失败,则跳过这个文件
|
||||
if not code_block or code_block.startswith("# Error"):
|
||||
logging.warning(f"未能从 {file_path} 提取到完整的代码块。")
|
||||
continue
|
||||
|
||||
# 确定语言类型用于markdown高亮
|
||||
lang = "csharp" if file_path.endswith(('.cs', '.txt')) else "xml"
|
||||
|
||||
output_parts.append(
|
||||
f"---\n"
|
||||
f"**文件路径:** `{file_path}`\n"
|
||||
f"**相似度:** {similarity:.4f}\n\n"
|
||||
f"```{lang}\n"
|
||||
f"{code_block}\n"
|
||||
f"```"
|
||||
)
|
||||
|
||||
# 如果没有任何代码块被成功提取
|
||||
if len(output_parts) <= 1:
|
||||
return f"虽然找到了相似的文件,但无法在其中提取到关于 '{keyword}' 的完整代码块。"
|
||||
|
||||
final_output = "\n".join(output_parts)
|
||||
|
||||
# 5. 更新缓存并返回结果
|
||||
logging.info(f"向量搜索完成。找到了 {len(best_matches)} 个匹配项并成功提取了代码。")
|
||||
knowledge_cache[keyword] = final_output
|
||||
save_cache(knowledge_cache)
|
||||
|
||||
return f"根据向量相似度分析,与 '{keyword}' 最相关的定义文件是:\n{best_match_path}"
|
||||
return final_output
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"处理请求时发生意外错误: {e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user