def find_most_similar_files(question_embedding, file_embeddings, top_n=3, min_similarity=0.5):
    """Return the files whose embeddings are most similar to the question.

    Args:
        question_embedding: Embedding vector of the question (a flat sequence
            of numbers).
        file_embeddings: List of dicts, each holding a 'path' and an
            'embedding' entry.
        top_n: Maximum number of matches to return.
        min_similarity: Candidates scoring below this cosine similarity are
            dropped.

    Returns:
        A list of ``{'path': ..., 'similarity': ...}`` dicts ordered from most
        to least similar. Empty when either input is empty or when no
        candidate reaches ``min_similarity``.
    """
    if not question_embedding or not file_embeddings:
        return []

    file_vectors = np.array([emb['embedding'] for emb in file_embeddings])
    question_vector = np.array(question_embedding).reshape(1, -1)
    similarities = cosine_similarity(question_vector, file_vectors)[0]

    results = []
    # Indices are visited best-first, so the first candidate that falls below
    # the threshold (or the top_n-th hit) ends the scan.
    for i in np.argsort(similarities)[::-1]:
        score = similarities[i]
        if score < min_similarity or len(results) >= top_n:
            break
        results.append({
            'path': file_embeddings[i]['path'],
            # Plain float rather than a NumPy scalar so the value can be
            # serialized and formatted without surprises.
            'similarity': float(score),
        })
    return results


def extract_relevant_code(file_path, keyword):
    """Extract the code block (C# class or XML Def) that mentions a keyword.

    The file is searched case-insensitively for the first line containing
    ``keyword``; the enclosing block around that line is then extracted
    according to the file extension.

    Args:
        file_path: Path of the file to read (decoded as UTF-8).
        keyword: Text to look for.

    Returns:
        The extracted block as a string, ``""`` when the keyword is absent,
        the file type is unsupported or no complete block was found, or a
        string starting with ``"# Error reading file:"`` when reading or
        extraction raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')

        keyword_lower = keyword.lower()
        found_line_index = next(
            (i for i, line in enumerate(lines) if keyword_lower in line.lower()),
            -1,
        )
        if found_line_index == -1:
            return ""

        # Pick the extraction strategy from the file type.
        if file_path.endswith(('.cs', '.txt')):
            return extract_csharp_class(lines, found_line_index)
        if file_path.endswith('.xml'):
            return extract_xml_def(lines, found_line_index)
        return ""  # Unsupported file type.

    except Exception as e:
        # Best-effort by design: the failure is reported through the return
        # value instead of being raised.
        logging.error(f"提取代码时出错 {file_path}: {e}")
        return f"# Error reading file: {e}"


def extract_csharp_class(lines, start_index):
    """Extract the C# class definition that encloses ``lines[start_index]``.

    Args:
        lines: The source file split into lines.
        start_index: Index of the line the class must contain.

    Returns:
        The lines from the class declaration through its matching closing
        brace joined with newlines, or ``""`` when no declaration is found at
        or above ``start_index`` or the braces never balance.

    Note:
        Braces are counted textually, so braces inside string literals or
        comments are counted as well.
    """
    import re

    # Word boundary keeps "subclass " out; requiring a following identifier
    # keeps generic constraints such as "where T : class" out.
    class_decl = re.compile(r'\bclass\s+\w')

    # Walk upwards to the nearest class declaration.
    class_start_index = -1
    for i in range(start_index, -1, -1):
        stripped = lines[i].lstrip()
        if stripped.startswith(('//', '/*', '*')):
            continue  # A comment that merely mentions a class.
        if class_decl.search(lines[i]):
            class_start_index = i
            break

    if class_start_index == -1:
        return ""  # No class declaration found.

    # Walk downwards, starting on the declaration line itself, until the
    # class body has been opened and closed again. Waiting for the first '{'
    # avoids stopping early when the brace sits on a later line.
    depth = 0
    body_opened = False
    for i in range(class_start_index, len(lines)):
        opens = lines[i].count('{')
        closes = lines[i].count('}')
        body_opened = body_opened or opens > 0
        depth += opens - closes
        if body_opened and depth <= 0:
            return "\n".join(lines[class_start_index:i + 1])

    return ""  # The class block is incomplete.


def extract_xml_def(lines, start_index):
    """Extract the XML Def element that encloses ``lines[start_index]``.

    A Def is recognised by a simplified rule: an opening tag at the start of
    a line whose name contains ``Def`` or ``def`` and that is not closed on
    that same line.

    Args:
        lines: The XML file split into lines.
        start_index: Index of the line the Def must contain.

    Returns:
        The lines from the opening tag through the line holding the matching
        closing tag joined with newlines, or ``""`` when no opening tag is
        found at or above ``start_index`` or the closing tag is missing.
    """
    import re

    open_tag = re.compile(r'<(\w+)(?:\s[^>]*)?>')

    # Walk upwards to the opening tag of the enclosing Def.
    def_start_index = -1
    def_tag = ""
    for i in range(start_index, -1, -1):
        line = lines[i].strip()
        match = open_tag.match(line)
        if not match:
            continue
        tag = match.group(1)
        if 'Def' not in tag and 'def' not in tag:
            continue
        # Leaf elements such as <defName>X</defName> and self-closing tags
        # are fields inside a Def, not the Def block itself.
        if match.group(0).endswith('/>') or f'</{tag}>' in line:
            continue
        def_start_index = i
        def_tag = tag
        break

    if def_start_index == -1:
        return ""

    # Walk downwards to the matching closing tag.
    closing_tag = f'</{def_tag}>'
    for i in range(def_start_index + 1, len(lines)):
        if closing_tag in lines[i]:
            return "\n".join(lines[def_start_index:i + 1])

    return ""

# 5.
--- 核心功能函数 --- def find_files_with_keyword(roots, keyword, extensions=['.xml', '.cs', '.txt']): @@ -183,8 +281,8 @@ mcp = FastMCP( @mcp.tool() def get_context(question: str) -> str: """ - 根据问题中的关键词和向量相似度,在RimWorld知识库中搜索最相关的XML或C#文件。 - 返回最匹配的文件路径。 + 根据问题中的关键词和向量相似度,在RimWorld知识库中搜索最相关的多个代码片段, + 并将其整合后返回。 """ logging.info(f"收到问题: {question}") keyword = find_keyword_in_question(question) @@ -196,9 +294,9 @@ def get_context(question: str) -> str: # 1. 检查缓存 if keyword in knowledge_cache: - cached_path = knowledge_cache[keyword] - logging.info(f"缓存命中: 关键词 '{keyword}' -> {cached_path}") - return f"根据知识库缓存,与 '{keyword}' 最相关的定义文件是:\n{cached_path}" + cached_result = knowledge_cache[keyword] + logging.info(f"缓存命中: 关键词 '{keyword}'") + return cached_result logging.info(f"缓存未命中,开始实时搜索: {keyword}") @@ -221,7 +319,6 @@ def get_context(question: str) -> str: try: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() - # v4模型支持更长的输入 file_embedding = get_embedding(content[:8000]) if file_embedding: file_embeddings.append({'path': file_path, 'embedding': file_embedding}) @@ -231,18 +328,51 @@ def get_context(question: str) -> str: if not file_embeddings: return "无法为任何候选文件生成向量。" - # 找到最相似的文件 - best_match_path = find_most_similar_file(question_embedding, file_embeddings) + # 找到最相似的多个文件 + best_matches = find_most_similar_files(question_embedding, file_embeddings, top_n=3) - if not best_match_path: - return "计算向量相似度失败。" + if not best_matches: + return "计算向量相似度失败或没有找到足够相似的文件。" - # 4. 更新缓存并返回结果 - logging.info(f"向量搜索完成。最匹配的文件是: {best_match_path}") - knowledge_cache[keyword] = best_match_path + # 4. 
提取代码并格式化输出 + output_parts = [f"根据向量相似度分析,与 '{keyword}' 最相关的代码定义如下:\n"] + + for match in best_matches: + file_path = match['path'] + similarity = match['similarity'] + + # 智能提取代码块 + code_block = extract_relevant_code(file_path, keyword) + + # 如果提取失败,则跳过这个文件 + if not code_block or code_block.startswith("# Error"): + logging.warning(f"未能从 {file_path} 提取到完整的代码块。") + continue + + # 确定语言类型用于markdown高亮 + lang = "csharp" if file_path.endswith(('.cs', '.txt')) else "xml" + + output_parts.append( + f"---\n" + f"**文件路径:** `{file_path}`\n" + f"**相似度:** {similarity:.4f}\n\n" + f"```{lang}\n" + f"{code_block}\n" + f"```" + ) + + # 如果没有任何代码块被成功提取 + if len(output_parts) <= 1: + return f"虽然找到了相似的文件,但无法在其中提取到关于 '{keyword}' 的完整代码块。" + + final_output = "\n".join(output_parts) + + # 5. 更新缓存并返回结果 + logging.info(f"向量搜索完成。找到了 {len(best_matches)} 个匹配项并成功提取了代码。") + knowledge_cache[keyword] = final_output save_cache(knowledge_cache) - return f"根据向量相似度分析,与 '{keyword}' 最相关的定义文件是:\n{best_match_path}" + return final_output except Exception as e: logging.error(f"处理请求时发生意外错误: {e}", exc_info=True)