Files
WulaFallenEmpireRW/Tools/convert_to_cpt.py
2025-12-25 12:52:30 +08:00

220 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
RimWorld 代码转 CPT 训练数据脚本
将 C# 源码和 XML Defs 转换为 JSONL 格式用于继续预训练
支持智能分割超长文件
"""
import argparse
import json
import os
import re
from pathlib import Path
from typing import Generator, List, Optional
# Noise files to exclude (Unity auto-generated files, binary data, etc.).
# NOTE: each entry is matched as a *substring* of the file name, so
# '.csproj' also filters out every *.csproj project file.
EXCLUDED_FILES = {
    'UnitySourceGeneratedAssemblyMonoScriptTypes_v1.txt',
    '__JobReflectionRegistrationOutput__',
    'AssemblyInfo.cs',
    '.csproj',
}
def should_exclude(file_path: Path) -> bool:
    """Return True if the file's name matches any excluded noise pattern."""
    filename = file_path.name
    return any(pattern in filename for pattern in EXCLUDED_FILES)
def read_file_content(file_path: Path) -> Optional[str]:
    """Read a text file, trying several encodings in order.

    Tries utf-8, utf-8-sig, gbk and finally latin-1.  Because latin-1 can
    decode any byte sequence, the final attempt never raises
    UnicodeDecodeError; the trailing ``return None`` is only a defensive
    fallback (the annotation is Optional[str] because callers check for a
    falsy result).

    Args:
        file_path: Path of the file to read.

    Returns:
        The decoded file content, or None if every encoding failed.
    """
    encodings = ('utf-8', 'utf-8-sig', 'gbk', 'latin-1')
    for encoding in encodings:
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            # Try the next, more permissive encoding.
            continue
    return None
def split_csharp_content(content: str, max_length: int) -> List[str]:
    """Split C# source into chunks of roughly at most max_length characters,
    preferring to split at class-definition boundaries.

    Chunks of 100 characters or fewer are dropped from the result.
    NOTE(review): a single line longer than max_length is kept intact, so an
    individual chunk may still exceed max_length — confirm this is acceptable
    for the downstream tokenizer.
    """
    # Short enough already: return the whole content as one chunk.
    if len(content) <= max_length:
        return [content]
    chunks = []
    # Zero-width lookahead split, so each part begins at (and keeps) its
    # class declaration text.
    class_pattern = r'(?=\n(?:public|private|internal|protected)?\s*(?:static\s+)?(?:partial\s+)?class\s+\w+)'
    parts = re.split(class_pattern, content)
    current_chunk = ""
    for part in parts:
        if len(current_chunk) + len(part) <= max_length:
            # The part still fits into the chunk being accumulated.
            current_chunk += part
        else:
            # Flush whatever has been accumulated so far.
            if current_chunk:
                chunks.append(current_chunk.strip())
            # If a single part is itself too long, fall back to splitting it
            # line by line.
            if len(part) > max_length:
                lines = part.split('\n')
                current_chunk = ""
                for line in lines:
                    # +1 accounts for the newline re-added below.
                    if len(current_chunk) + len(line) + 1 <= max_length:
                        current_chunk += line + '\n'
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = line + '\n'
            else:
                current_chunk = part
    # Flush the trailing chunk.
    if current_chunk:
        chunks.append(current_chunk.strip())
    return [c for c in chunks if len(c) > 100]  # drop chunks that are too short
def split_xml_content(content: str, max_length: int) -> List[str]:
    """Split XML content into chunks of at most max_length characters,
    preferring to split at top-level Def-element boundaries.

    Chunks of 100 characters or fewer are dropped from the result.

    Args:
        content: The raw XML text.
        max_length: Target maximum chunk size in characters.

    Returns:
        A list of chunk strings (possibly just [content] when it fits).
    """
    if len(content) <= max_length:
        return [content]
    chunks = []
    # Zero-width lookahead split at top-level *Def elements (e.g. <ThingDef ...>),
    # so each part keeps its own opening tag.
    def_pattern = r'(?=\n\s*<[A-Z][a-zA-Z]*Def\s)'
    parts = re.split(def_pattern, content)
    current_chunk = ""
    # Fixed-size window for forced splitting; clamped so the range() step is
    # always positive (the previous `max_length - 200` raised ValueError and
    # produced empty slices when max_length <= 200).
    step = max(1, max_length - 200)
    for part in parts:
        if len(current_chunk) + len(part) <= max_length:
            current_chunk += part
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            if len(part) > max_length:
                # No Def boundary helps here: force a split by fixed windows.
                for i in range(0, len(part), step):
                    chunk = part[i:i + step]
                    if len(chunk) > 100:
                        chunks.append(chunk.strip())
                current_chunk = ""
            else:
                current_chunk = part
    if current_chunk:
        chunks.append(current_chunk.strip())
    return [c for c in chunks if len(c) > 100]
def create_csharp_entries(file_path: Path, content: str, max_length: int) -> List[dict]:
    """Build JSONL training entries for one C# file (split into parts if needed)."""
    relative_path = file_path.name
    # Reserve a little room for the header comment line.
    chunks = split_csharp_content(content, max_length - 50)
    total = len(chunks)
    entries = []
    for index, chunk in enumerate(chunks, start=1):
        if total > 1:
            header = f"// File: {relative_path} (Part {index}/{total})\n"
        else:
            header = f"// File: {relative_path}\n"
        entries.append({"text": header + chunk})
    return entries
def create_xml_entries(file_path: Path, content: str, max_length: int) -> List[dict]:
    """Build JSONL training entries for one XML file (split into parts if needed)."""
    relative_path = file_path.name
    chunks = split_xml_content(content, max_length - 50)  # leave room for the header
    total = len(chunks)

    def header_for(part_no: int) -> str:
        # Only multi-part files get the "(Part i/n)" suffix.
        if total > 1:
            return f"<!-- File: {relative_path} (Part {part_no}/{total}) -->\n"
        return f"<!-- File: {relative_path} -->\n"

    return [{"text": header_for(i) + chunk} for i, chunk in enumerate(chunks, 1)]
def process_csharp_files(source_dir: Path, max_length: int) -> Generator[dict, None, None]:
    """Yield training entries for every C#-like source file under source_dir."""
    csharp_markers = ('class ', 'namespace ', 'public ')
    for ext in ('*.cs', '*.txt'):
        for file_path in source_dir.rglob(ext):
            # Skip tiny files and known noise files.
            if file_path.stat().st_size < 100 or should_exclude(file_path):
                continue
            content = read_file_content(file_path)
            if not content:
                continue
            # A .txt file must actually look like C# code to be included.
            if ext == '*.txt' and not any(m in content for m in csharp_markers):
                continue
            yield from create_csharp_entries(file_path, content, max_length)
def process_xml_files(source_dir: Path, max_length: int) -> Generator[dict, None, None]:
    """Yield training entries for every XML Defs file under source_dir."""
    for file_path in source_dir.rglob('*.xml'):
        # Skip tiny files; they carry no useful training signal.
        if file_path.stat().st_size < 100:
            continue
        content = read_file_content(file_path)
        if not content:
            continue
        yield from create_xml_entries(file_path, content, max_length)
def main():
    """CLI entry point: convert C#/XML sources into a JSONL CPT dataset."""
    parser = argparse.ArgumentParser(description='转换 RimWorld 代码为 CPT 训练数据')
    parser.add_argument('--csharp-dir', type=str, help='C# 反编译代码目录')
    parser.add_argument('--xml-dir', type=str, help='XML Defs 目录')
    parser.add_argument('--output', type=str, default='rimworld_cpt_data.jsonl', help='输出文件路径')
    parser.add_argument('--max-length', type=int, default=8000, help='单条数据最大字符数')
    args = parser.parse_args()

    entries = []
    stats = {'csharp': 0, 'xml': 0, 'split_chunks': 0}

    def collect(generator, whole_file_key):
        # Accumulate entries; split parts are counted separately from whole files
        # by looking for the "(Part " marker in the header line.
        for entry in generator:
            entries.append(entry)
            if "(Part " in entry['text'][:100]:
                stats['split_chunks'] += 1
            else:
                stats[whole_file_key] += 1

    # Process C# files, if a directory was given and exists.
    if args.csharp_dir:
        csharp_path = Path(args.csharp_dir)
        if csharp_path.exists():
            print(f"处理 C# 文件: {csharp_path}")
            collect(process_csharp_files(csharp_path, args.max_length), 'csharp')

    # Process XML files, if a directory was given and exists.
    if args.xml_dir:
        xml_path = Path(args.xml_dir)
        if xml_path.exists():
            print(f"处理 XML 文件: {xml_path}")
            collect(process_xml_files(xml_path, args.max_length), 'xml')

    # Write one JSON object per line (JSONL).
    output_path = Path(args.output)
    with open(output_path, 'w', encoding='utf-8') as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    # Summary statistics.
    total_size = output_path.stat().st_size / (1024 * 1024)
    print(f"\n=== 转换完成 ===")
    print(f"C# 完整文件: {stats['csharp']}")
    print(f"XML 完整文件: {stats['xml']}")
    print(f"分割产生的块: {stats['split_chunks']}")
    print(f"总条目数: {len(entries)}")
    print(f"输出文件: {output_path}")
    print(f"文件大小: {total_size:.2f} MB")


if __name__ == '__main__':
    main()