This commit is contained in:
2025-12-25 12:52:30 +08:00
parent d73a7eefbd
commit a0c46670e3
3 changed files with 223 additions and 0 deletions

219
Tools/convert_to_cpt.py Normal file
View File

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
"""
RimWorld 代码转 CPT 训练数据脚本
将 C# 源码和 XML Defs 转换为 JSONL 格式用于继续预训练
支持智能分割超长文件
"""
import os
import json
import re
import argparse
from pathlib import Path
from typing import Generator, List
# 排除的噪声文件Unity 自动生成、二进制数据等)
EXCLUDED_FILES = {
'UnitySourceGeneratedAssemblyMonoScriptTypes_v1.txt',
'__JobReflectionRegistrationOutput__',
'AssemblyInfo.cs',
'.csproj',
}
def should_exclude(file_path: Path) -> bool:
"""检查文件是否应该被排除"""
name = file_path.name
for pattern in EXCLUDED_FILES:
if pattern in name:
return True
return False
def read_file_content(file_path: Path) -> str:
"""读取文件内容,处理多种编码"""
encodings = ['utf-8', 'utf-8-sig', 'gbk', 'latin-1']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
return None
def split_csharp_content(content: str, max_length: int) -> List[str]:
"""智能分割 C# 代码,尽量在类/方法边界分割"""
if len(content) <= max_length:
return [content]
chunks = []
# 按类定义分割
class_pattern = r'(?=\n(?:public|private|internal|protected)?\s*(?:static\s+)?(?:partial\s+)?class\s+\w+)'
parts = re.split(class_pattern, content)
current_chunk = ""
for part in parts:
if len(current_chunk) + len(part) <= max_length:
current_chunk += part
else:
if current_chunk:
chunks.append(current_chunk.strip())
# 如果单个 part 仍然太长,按行分割
if len(part) > max_length:
lines = part.split('\n')
current_chunk = ""
for line in lines:
if len(current_chunk) + len(line) + 1 <= max_length:
current_chunk += line + '\n'
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = line + '\n'
else:
current_chunk = part
if current_chunk:
chunks.append(current_chunk.strip())
return [c for c in chunks if len(c) > 100] # 过滤太短的块
def split_xml_content(content: str, max_length: int) -> List[str]:
"""智能分割 XML 内容,尽量在 Def 边界分割"""
if len(content) <= max_length:
return [content]
chunks = []
# 按顶级 Def 元素分割
def_pattern = r'(?=\n\s*<[A-Z][a-zA-Z]*Def\s)'
parts = re.split(def_pattern, content)
current_chunk = ""
for part in parts:
if len(current_chunk) + len(part) <= max_length:
current_chunk += part
else:
if current_chunk:
chunks.append(current_chunk.strip())
if len(part) > max_length:
# 强制按长度分割
for i in range(0, len(part), max_length - 200):
chunk = part[i:i + max_length - 200]
if len(chunk) > 100:
chunks.append(chunk.strip())
current_chunk = ""
else:
current_chunk = part
if current_chunk:
chunks.append(current_chunk.strip())
return [c for c in chunks if len(c) > 100]
def create_csharp_entries(file_path: Path, content: str, max_length: int) -> List[dict]:
"""创建 C# 代码的训练条目(支持分割)"""
relative_path = file_path.name
chunks = split_csharp_content(content, max_length - 50) # 留空间给 header
entries = []
for i, chunk in enumerate(chunks):
if len(chunks) > 1:
header = f"// File: {relative_path} (Part {i+1}/{len(chunks)})\n"
else:
header = f"// File: {relative_path}\n"
entries.append({"text": header + chunk})
return entries
def create_xml_entries(file_path: Path, content: str, max_length: int) -> List[dict]:
"""创建 XML 的训练条目(支持分割)"""
relative_path = file_path.name
chunks = split_xml_content(content, max_length - 50)
entries = []
for i, chunk in enumerate(chunks):
if len(chunks) > 1:
header = f"<!-- File: {relative_path} (Part {i+1}/{len(chunks)}) -->\n"
else:
header = f"<!-- File: {relative_path} -->\n"
entries.append({"text": header + chunk})
return entries
def process_csharp_files(source_dir: Path, max_length: int) -> Generator[dict, None, None]:
"""处理所有 C# 文件"""
for ext in ['*.cs', '*.txt']:
for file_path in source_dir.rglob(ext):
if file_path.stat().st_size < 100:
continue
if should_exclude(file_path):
continue
content = read_file_content(file_path)
if not content:
continue
if ext == '*.txt' and not ('class ' in content or 'namespace ' in content or 'public ' in content):
continue
for entry in create_csharp_entries(file_path, content, max_length):
yield entry
def process_xml_files(source_dir: Path, max_length: int) -> Generator[dict, None, None]:
"""处理所有 XML 文件"""
for file_path in source_dir.rglob('*.xml'):
if file_path.stat().st_size < 100:
continue
content = read_file_content(file_path)
if content:
for entry in create_xml_entries(file_path, content, max_length):
yield entry
def main():
parser = argparse.ArgumentParser(description='转换 RimWorld 代码为 CPT 训练数据')
parser.add_argument('--csharp-dir', type=str, help='C# 反编译代码目录')
parser.add_argument('--xml-dir', type=str, help='XML Defs 目录')
parser.add_argument('--output', type=str, default='rimworld_cpt_data.jsonl', help='输出文件路径')
parser.add_argument('--max-length', type=int, default=8000, help='单条数据最大字符数')
args = parser.parse_args()
entries = []
stats = {'csharp': 0, 'xml': 0, 'split_chunks': 0}
# 处理 C# 文件
if args.csharp_dir:
csharp_path = Path(args.csharp_dir)
if csharp_path.exists():
print(f"处理 C# 文件: {csharp_path}")
for entry in process_csharp_files(csharp_path, args.max_length):
entries.append(entry)
if "(Part " in entry['text'][:100]:
stats['split_chunks'] += 1
else:
stats['csharp'] += 1
# 处理 XML 文件
if args.xml_dir:
xml_path = Path(args.xml_dir)
if xml_path.exists():
print(f"处理 XML 文件: {xml_path}")
for entry in process_xml_files(xml_path, args.max_length):
entries.append(entry)
if "(Part " in entry['text'][:100]:
stats['split_chunks'] += 1
else:
stats['xml'] += 1
# 写入输出文件
output_path = Path(args.output)
with open(output_path, 'w', encoding='utf-8') as f:
for entry in entries:
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
# 统计信息
total_size = output_path.stat().st_size / (1024 * 1024)
print(f"\n=== 转换完成 ===")
print(f"C# 完整文件: {stats['csharp']}")
print(f"XML 完整文件: {stats['xml']}")
print(f"分割产生的块: {stats['split_chunks']}")
print(f"总条目数: {len(entries)}")
print(f"输出文件: {output_path}")
print(f"文件大小: {total_size:.2f} MB")
if __name__ == '__main__':
main()