diff --git a/.gitignore b/.gitignore index a381cf08..83b7d47d 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,4 @@ node_modules/ gemini-websocket-proxy/ Tools/dark-server/dark-server.js Tools/rimworld_cpt_data.jsonl +Tools/mem0-1.0.0/ diff --git a/1.6/1.6/Assemblies/WulaFallenEmpire.dll b/1.6/1.6/Assemblies/WulaFallenEmpire.dll index dc57b138..1a46de05 100644 Binary files a/1.6/1.6/Assemblies/WulaFallenEmpire.dll and b/1.6/1.6/Assemblies/WulaFallenEmpire.dll differ diff --git a/1.6/1.6/Assemblies/WulaFallenEmpire.pdb b/1.6/1.6/Assemblies/WulaFallenEmpire.pdb index 899a481b..897cd330 100644 Binary files a/1.6/1.6/Assemblies/WulaFallenEmpire.pdb and b/1.6/1.6/Assemblies/WulaFallenEmpire.pdb differ diff --git a/Source/WulaFallenEmpire/EventSystem/AI/AIMemoryEntry.cs b/Source/WulaFallenEmpire/EventSystem/AI/AIMemoryEntry.cs new file mode 100644 index 00000000..34724ee2 --- /dev/null +++ b/Source/WulaFallenEmpire/EventSystem/AI/AIMemoryEntry.cs @@ -0,0 +1,69 @@ +using System; + +namespace WulaFallenEmpire.EventSystem.AI +{ + /// + /// Represents a single memory entry extracted from conversations. + /// Inspired by Mem0's memory structure. + /// + public class AIMemoryEntry + { + /// Unique identifier for this memory + public string Id { get; set; } + + /// The actual memory content/fact + public string Fact { get; set; } + + /// + /// Category of memory: preference, personal, plan, colony, misc + /// + public string Category { get; set; } + + /// Game ticks when this memory was created + public long CreatedTicks { get; set; } + + /// Game ticks when this memory was last updated + public long UpdatedTicks { get; set; } + + /// Number of times this memory has been accessed/retrieved + public int AccessCount { get; set; } + + /// Hash of the fact for quick duplicate detection + public string Hash { get; set; } + + public AIMemoryEntry() + { + Id = Guid.NewGuid().ToString("N").Substring(0, 12); + CreatedTicks = 0; + UpdatedTicks = 0; + AccessCount = 0; + Category = "misc"; + } + + public AIMemoryEntry(string fact, string category = "misc") : this() + { + Fact = fact; + Category = category ?? "misc"; + Hash = ComputeHash(fact); + } + + public static string ComputeHash(string text) + { + if (string.IsNullOrEmpty(text)) return ""; + // Simple hash based on normalized text + string normalized = text.ToLowerInvariant().Trim(); + return normalized.GetHashCode().ToString("X8"); + } + + public void UpdateFact(string newFact) + { + Fact = newFact; + Hash = ComputeHash(newFact); + } + + public void MarkAccessed() + { + AccessCount++; + } + } +} diff --git a/Source/WulaFallenEmpire/EventSystem/AI/AIMemoryManager.cs b/Source/WulaFallenEmpire/EventSystem/AI/AIMemoryManager.cs new file mode 100644 index 00000000..b3635105 --- /dev/null +++ b/Source/WulaFallenEmpire/EventSystem/AI/AIMemoryManager.cs @@ -0,0 +1,484 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Text; +using RimWorld.Planet; +using Verse; + +namespace WulaFallenEmpire.EventSystem.AI +{ + public class AIMemoryManager : WorldComponent + { + private const string MemoryFolderName = "WulaAIMemory"; + private const string MemoryVersion = "1.0"; + private const int RecencyTickWindow = 60000; + private string _saveId; + private List _memories = new List(); + private bool _loaded; + + public AIMemoryManager(World world) : base(world) + { + } + + public IReadOnlyList GetAllMemories() + { + EnsureLoaded(); + return _memories.ToList(); + } + + public AIMemoryEntry AddMemory(string fact, string category = "misc") + { + EnsureLoaded(); + if (string.IsNullOrWhiteSpace(fact)) return null; + + string normalizedCategory = NormalizeCategory(category); + string hash = AIMemoryEntry.ComputeHash(fact); + string normalizedFact = NormalizeFact(fact); + var existing = _memories.FirstOrDefault(m => m != null && + (string.Equals(m.Hash, hash, StringComparison.OrdinalIgnoreCase) || + string.Equals(NormalizeFact(m.Fact), normalizedFact, StringComparison.Ordinal))); + long now = GetCurrentTicks(); + if (existing != null) + { + existing.UpdateFact(fact); + existing.Category = normalizedCategory; + existing.UpdatedTicks = now; + SaveToFile(); + return existing; + } + + var entry = new AIMemoryEntry(fact, normalizedCategory) + { + CreatedTicks = now, + UpdatedTicks = now, + AccessCount = 0 + }; + _memories.Add(entry); + SaveToFile(); + return entry; + } + + public bool UpdateMemory(string id, string newFact, string category = null) + { + EnsureLoaded(); + if (string.IsNullOrWhiteSpace(id)) return false; + + var entry = _memories.FirstOrDefault(m => string.Equals(m.Id, id, StringComparison.OrdinalIgnoreCase)); + if (entry == null) return false; + + if (!string.IsNullOrWhiteSpace(newFact)) + { + entry.UpdateFact(newFact); + } + + if (!string.IsNullOrWhiteSpace(category)) + { + entry.Category = NormalizeCategory(category); + } + + entry.UpdatedTicks = GetCurrentTicks(); + SaveToFile(); + return true; + } + + public bool DeleteMemory(string id) + { + EnsureLoaded(); + if (string.IsNullOrWhiteSpace(id)) return false; + + int removed = _memories.RemoveAll(m => string.Equals(m.Id, id, StringComparison.OrdinalIgnoreCase)); + if (removed > 0) + { + SaveToFile(); + return true; + } + return false; + } + + public List SearchMemories(string query, int limit = 5) + { + EnsureLoaded(); + if (string.IsNullOrWhiteSpace(query)) return new List(); + + string normalizedQuery = query.Trim(); + List tokens = Tokenize(normalizedQuery); + + long now = GetCurrentTicks(); + var scored = new List<(AIMemoryEntry entry, float score)>(); + + foreach (var entry in _memories) + { + if (entry == null || string.IsNullOrWhiteSpace(entry.Fact)) continue; + float score = ComputeScore(entry, normalizedQuery, tokens, now); + if (score <= 0f) continue; + scored.Add((entry, score)); + } + + var results = scored + .OrderByDescending(s => s.score) + .ThenByDescending(s => s.entry.UpdatedTicks) + .Take(Math.Max(1, limit)) + .Select(s => s.entry) + .ToList(); + + if (results.Count > 0) + { + foreach (var entry in results) + { + entry.MarkAccessed(); + entry.UpdatedTicks = now; + } + SaveToFile(); + } + + return results; + } + + public List GetRecentMemories(int limit = 5) + { + EnsureLoaded(); + return _memories + .Where(m => m != null && !string.IsNullOrWhiteSpace(m.Fact)) + .OrderByDescending(m => m.UpdatedTicks) + .ThenByDescending(m => m.CreatedTicks) + .Take(Math.Max(1, limit)) + .ToList(); + } + + private void EnsureLoaded() + { + if (_loaded) return; + LoadFromFile(); + _loaded = true; + } + + private string GetSaveDirectory() + { + string path = Path.Combine(GenFilePaths.SaveDataFolderPath, MemoryFolderName); + if (!Directory.Exists(path)) + { + Directory.CreateDirectory(path); + } + return path; + } + + private string GetFilePath() + { + if (string.IsNullOrEmpty(_saveId)) + { + _saveId = Guid.NewGuid().ToString("N"); + } + return Path.Combine(GetSaveDirectory(), $"{_saveId}.json"); + } + + private void LoadFromFile() + { + _memories = new List(); + + string path = GetFilePath(); + if (!File.Exists(path)) return; + + try + { + string json = File.ReadAllText(path); + if (string.IsNullOrWhiteSpace(json)) return; + + string array = ExtractJsonArray(json, "memories"); + if (string.IsNullOrWhiteSpace(array)) return; + + foreach (string obj in ExtractJsonObjects(array)) + { + var dict = SimpleJsonParser.Parse(obj); + if (dict == null || dict.Count == 0) continue; + + var entry = new AIMemoryEntry(); + if (dict.TryGetValue("id", out string id) && !string.IsNullOrWhiteSpace(id)) entry.Id = id; + if (dict.TryGetValue("fact", out string fact)) entry.Fact = fact; + if (dict.TryGetValue("category", out string category)) entry.Category = NormalizeCategory(category); + if (dict.TryGetValue("createdTicks", out string created) && long.TryParse(created, NumberStyles.Integer, CultureInfo.InvariantCulture, out long createdTicks)) entry.CreatedTicks = createdTicks; + if (dict.TryGetValue("updatedTicks", out string updated) && long.TryParse(updated, NumberStyles.Integer, CultureInfo.InvariantCulture, out long updatedTicks)) entry.UpdatedTicks = updatedTicks; + if (dict.TryGetValue("accessCount", out string access) && int.TryParse(access, NumberStyles.Integer, CultureInfo.InvariantCulture, out int accessCount)) entry.AccessCount = accessCount; + if (dict.TryGetValue("hash", out string hash)) entry.Hash = hash; + if (string.IsNullOrWhiteSpace(entry.Hash)) + { + entry.Hash = AIMemoryEntry.ComputeHash(entry.Fact); + } + if (string.IsNullOrWhiteSpace(entry.Category)) entry.Category = "misc"; + _memories.Add(entry); + } + } + catch (Exception ex) + { + WulaLog.Debug($"[WulaAI] Failed to load memory file: {ex}"); + } + } + + private void SaveToFile() + { + string path = GetFilePath(); + try + { + StringBuilder sb = new StringBuilder(); + sb.Append("{"); + sb.Append("\"version\":\"").Append(MemoryVersion).Append("\","); + sb.Append("\"memories\":["); + bool first = true; + foreach (var memory in _memories) + { + if (memory == null) continue; + if (!first) sb.Append(","); + first = false; + sb.Append("{"); + sb.Append("\"id\":\"").Append(EscapeJson(memory.Id)).Append("\","); + sb.Append("\"fact\":\"").Append(EscapeJson(memory.Fact)).Append("\","); + sb.Append("\"category\":\"").Append(EscapeJson(memory.Category)).Append("\","); + sb.Append("\"createdTicks\":").Append(memory.CreatedTicks.ToString(CultureInfo.InvariantCulture)).Append(","); + sb.Append("\"updatedTicks\":").Append(memory.UpdatedTicks.ToString(CultureInfo.InvariantCulture)).Append(","); + sb.Append("\"accessCount\":").Append(memory.AccessCount.ToString(CultureInfo.InvariantCulture)).Append(","); + sb.Append("\"hash\":\"").Append(EscapeJson(memory.Hash)).Append("\""); + sb.Append("}"); + } + sb.Append("]}"); + File.WriteAllText(path, sb.ToString()); + } + catch (Exception ex) + { + WulaLog.Debug($"[WulaAI] Failed to save memory file: {ex}"); + } + } + + public override void ExposeData() + { + base.ExposeData(); + Scribe_Values.Look(ref _saveId, "WulaAIMemoryId"); + + if (Scribe.mode == LoadSaveMode.PostLoadInit && string.IsNullOrEmpty(_saveId)) + { + _saveId = Guid.NewGuid().ToString("N"); + _loaded = false; + } + } + + private static long GetCurrentTicks() + { + return Find.TickManager?.TicksGame ?? 0; + } + + private static string NormalizeCategory(string category) + { + if (string.IsNullOrWhiteSpace(category)) return "misc"; + string lower = category.Trim().ToLowerInvariant(); + switch (lower) + { + case "preference": + case "personal": + case "plan": + case "colony": + case "misc": + return lower; + default: + return "misc"; + } + } + + private static string NormalizeFact(string fact) + { + return string.IsNullOrWhiteSpace(fact) ? "" : fact.Trim().ToLowerInvariant(); + } + + private static float ComputeScore(AIMemoryEntry entry, string query, List tokens, long now) + { + string fact = entry.Fact ?? ""; + if (string.IsNullOrWhiteSpace(fact)) return 0f; + + string factLower = fact.ToLowerInvariant(); + string queryLower = query.ToLowerInvariant(); + + float score = 0f; + if (string.Equals(factLower, queryLower, StringComparison.OrdinalIgnoreCase)) + { + score = 1.2f; + } + else if (factLower.Contains(queryLower) || queryLower.Contains(factLower)) + { + score = 0.9f; + } + + if (tokens.Count > 0) + { + int matches = 0; + foreach (string token in tokens) + { + if (string.IsNullOrWhiteSpace(token)) continue; + if (factLower.Contains(token)) matches++; + } + float coverage = matches / (float)Math.Max(1, tokens.Count); + score = Math.Max(score, 0.3f * coverage); + } + + long updated = entry.UpdatedTicks > 0 ? entry.UpdatedTicks : entry.CreatedTicks; + long age = Math.Max(0, now - updated); + float recency = 1f / (1f + (age / (float)RecencyTickWindow)); + float accessBoost = 1f + Math.Min(0.2f, entry.AccessCount * 0.02f); + return score * recency * accessBoost; + } + + private static List Tokenize(string text) + { + var tokens = new List(); + if (string.IsNullOrWhiteSpace(text)) return tokens; + + var sb = new StringBuilder(); + foreach (char c in text) + { + if (char.IsLetterOrDigit(c)) + { + sb.Append(char.ToLowerInvariant(c)); + } + else + { + if (sb.Length > 0) + { + tokens.Add(sb.ToString()); + sb.Length = 0; + } + } + } + if (sb.Length > 0) tokens.Add(sb.ToString()); + return tokens; + } + + private static string EscapeJson(string value) + { + if (string.IsNullOrEmpty(value)) return ""; + return value.Replace("\\", "\\\\").Replace("\"", "\\\"").Replace("\n", "\\n").Replace("\r", "\\r"); + } + + private static string ExtractJsonArray(string json, string key) + { + if (string.IsNullOrWhiteSpace(json) || string.IsNullOrWhiteSpace(key)) return null; + + string keyPattern = $"\"{key}\""; + int keyIndex = json.IndexOf(keyPattern, StringComparison.OrdinalIgnoreCase); + if (keyIndex == -1) return null; + + int arrayStart = json.IndexOf('[', keyIndex); + if (arrayStart == -1) return null; + + int arrayEnd = FindMatchingBracket(json, arrayStart); + if (arrayEnd == -1) return null; + + return json.Substring(arrayStart + 1, arrayEnd - arrayStart - 1); + } + + private static List ExtractJsonObjects(string arrayContent) + { + var objects = new List(); + if (string.IsNullOrWhiteSpace(arrayContent)) return objects; + + int depth = 0; + int start = -1; + bool inString = false; + bool escaped = false; + + for (int i = 0; i < arrayContent.Length; i++) + { + char c = arrayContent[i]; + if (inString) + { + if (escaped) + { + escaped = false; + continue; + } + if (c == '\\') + { + escaped = true; + continue; + } + if (c == '"') + { + inString = false; + } + continue; + } + + if (c == '"') + { + inString = true; + continue; + } + + if (c == '{') + { + if (depth == 0) start = i; + depth++; + continue; + } + if (c == '}') + { + depth--; + if (depth == 0 && start >= 0) + { + objects.Add(arrayContent.Substring(start, i - start + 1)); + start = -1; + } + } + } + + return objects; + } + + private static int FindMatchingBracket(string json, int startIndex) + { + int depth = 0; + bool inString = false; + bool escaped = false; + + for (int i = startIndex; i < json.Length; i++) + { + char c = json[i]; + if (inString) + { + if (escaped) + { + escaped = false; + continue; + } + if (c == '\\') + { + escaped = true; + continue; + } + if (c == '"') + { + inString = false; + } + continue; + } + + if (c == '"') + { + inString = true; + continue; + } + + if (c == '[') + { + depth++; + continue; + } + + if (c == ']') + { + depth--; + if (depth == 0) return i; + } + } + + return -1; + } + } +} diff --git a/Source/WulaFallenEmpire/EventSystem/AI/MemoryPrompts.cs b/Source/WulaFallenEmpire/EventSystem/AI/MemoryPrompts.cs new file mode 100644 index 00000000..105c406f --- /dev/null +++ b/Source/WulaFallenEmpire/EventSystem/AI/MemoryPrompts.cs @@ -0,0 +1,45 @@ +using System.Globalization; + +namespace WulaFallenEmpire.EventSystem.AI +{ + public static class MemoryPrompts + { + public const string FactExtractionPrompt = +@"You are extracting long-term memory about the player from the conversation below. +Return JSON only, no extra text. +Schema: +{{""facts"":[{{""text"":""..."",""category"":""preference|personal|plan|colony|misc""}}]}} +Rules: +- Keep only stable, reusable facts about the player or colony. +- Ignore transient tool results, numbers, or one-off actions. +- Do not invent facts. +Conversation: +{0}"; + + public const string MemoryUpdatePrompt = +@"You are updating a memory store. +Given existing memories and new facts, decide ADD, UPDATE, DELETE, or NONE. +Return JSON only, no extra text. +Schema: +{{""memory"":[{{""id"":""..."",""text"":""..."",""category"":""preference|personal|plan|colony|misc"",""event"":""ADD|UPDATE|DELETE|NONE""}}]}} +Rules: +- UPDATE if a new fact refines or corrects an existing memory. +- DELETE if a memory is contradicted by new facts. +- ADD for genuinely new information. +- NONE if no change is needed. +Existing memories (JSON): +{0} +New facts (JSON): +{1}"; + + public static string BuildFactExtractionPrompt(string conversation) + { + return string.Format(CultureInfo.InvariantCulture, FactExtractionPrompt, conversation ?? ""); + } + + public static string BuildMemoryUpdatePrompt(string existingMemoriesJson, string newFactsJson) + { + return string.Format(CultureInfo.InvariantCulture, MemoryUpdatePrompt, existingMemoriesJson ?? "[]", newFactsJson ?? "[]"); + } + } +} diff --git a/Source/WulaFallenEmpire/EventSystem/AI/UI/Dialog_AIConversation.cs b/Source/WulaFallenEmpire/EventSystem/AI/UI/Dialog_AIConversation.cs index 996b673f..960ae8b1 100644 --- a/Source/WulaFallenEmpire/EventSystem/AI/UI/Dialog_AIConversation.cs +++ b/Source/WulaFallenEmpire/EventSystem/AI/UI/Dialog_AIConversation.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using RimWorld; using UnityEngine; using Verse; +using WulaFallenEmpire.EventSystem.AI; using WulaFallenEmpire.EventSystem.AI.Tools; using System.Text.RegularExpressions; @@ -44,6 +45,13 @@ namespace WulaFallenEmpire.EventSystem.AI.UI private const int DefaultMaxHistoryTokens = 100000; private const int CharsPerToken = 4; private const int ThinkingPhaseTotal = 3; + private const int MemorySearchLimit = 6; + private const int MemoryFactMaxChars = 200; + private const int MemoryPromptMaxChars = 1600; + private const int MemoryUpdateMaxMemories = 40; + private string _memoryContextCache = ""; + private string _memoryContextQuery = ""; + private bool _memoryExtractionInFlight = false; private enum RequestPhase { @@ -59,6 +67,20 @@ namespace WulaFallenEmpire.EventSystem.AI.UI public bool AnyActionError; } + private struct MemoryFact + { + public string Text; + public string Category; + } + + private struct MemoryUpdate + { + public string Event; + public string Id; + public string Text; + public string Category; + } + private void SetThinkingPhase(int phaseIndex, bool isRetry) { _thinkingPhaseIndex = Math.Max(1, Math.Min(ThinkingPhaseTotal, phaseIndex)); @@ -233,10 +255,15 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori { // Use XML persona if available, otherwise default string persona = !string.IsNullOrEmpty(def.aiSystemInstruction) ? def.aiSystemInstruction : DefaultPersona; - + + string memoryContext = _memoryContextCache; + string personaWithMemory = string.IsNullOrWhiteSpace(memoryContext) + ? persona + : persona + "\n\n" + memoryContext; + string fullInstruction = toolsEnabled - ? (persona + "\n" + ToolRulesInstruction + "\n" + toolsForThisPhase) - : persona; + ? (personaWithMemory + "\n" + ToolRulesInstruction + "\n" + toolsForThisPhase) + : personaWithMemory; string language = LanguageDatabase.activeLanguage.FriendlyNameNative; var eventVarManager = Find.World.GetComponent(); @@ -494,6 +521,542 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori return context; } + private void PrepareMemoryContext() + { + string lastUserMessage = _history.LastOrDefault(entry => string.Equals(entry.role, "user", StringComparison.OrdinalIgnoreCase)).message; + if (string.IsNullOrWhiteSpace(lastUserMessage)) + { + _memoryContextCache = ""; + _memoryContextQuery = ""; + return; + } + + _memoryContextQuery = lastUserMessage; + _memoryContextCache = BuildMemoryContext(lastUserMessage); + } + + private string BuildMemoryContext(string userQuery) + { + if (string.IsNullOrWhiteSpace(userQuery)) return ""; + + var memoryManager = Find.World?.GetComponent(); + if (memoryManager == null) return ""; + + var memories = memoryManager.SearchMemories(userQuery, MemorySearchLimit); + if (memories == null || memories.Count == 0) return ""; + + StringBuilder sb = new StringBuilder(); + sb.AppendLine("# MEMORY (Relevant Facts)"); + foreach (var memory in memories) + { + if (memory == null || string.IsNullOrWhiteSpace(memory.Fact)) continue; + string category = string.IsNullOrWhiteSpace(memory.Category) ? "misc" : memory.Category.Trim(); + string fact = TrimMemoryFact(memory.Fact, MemoryFactMaxChars); + if (string.IsNullOrWhiteSpace(fact)) continue; + sb.AppendLine($"- [{category}] {fact}"); + } + + return sb.ToString().TrimEnd(); + } + + private static string TrimMemoryFact(string fact, int maxChars) + { + if (string.IsNullOrWhiteSpace(fact)) return ""; + string trimmed = fact.Trim(); + if (trimmed.Length <= maxChars) return trimmed; + return trimmed.Substring(0, maxChars) + "..."; + } + + private string BuildMemoryExtractionConversation() + { + if (_history == null || _history.Count == 0) return ""; + + const int maxTurns = 6; + var lines = new List(); + + for (int i = _history.Count - 1; i >= 0 && lines.Count < maxTurns; i--) + { + var entry = _history[i]; + if (!string.Equals(entry.role, "user", StringComparison.OrdinalIgnoreCase) && + !string.Equals(entry.role, "assistant", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + string content = entry.role == "assistant" + ? ParseResponseForDisplay(entry.message) + : entry.message; + + if (string.IsNullOrWhiteSpace(content)) continue; + + lines.Add($"{entry.role}: {content.Trim()}"); + } + + lines.Reverse(); + string snippet = string.Join("\n", lines); + return TrimForPrompt(snippet, MemoryPromptMaxChars); + } + + private async Task ExtractAndUpdateMemoriesAsync() + { + if (_memoryExtractionInFlight) return; + _memoryExtractionInFlight = true; + try + { + var world = Find.World; + if (world == null) return; + + var memoryManager = world.GetComponent(); + if (memoryManager == null) return; + + string conversation = BuildMemoryExtractionConversation(); + if (string.IsNullOrWhiteSpace(conversation)) return; + + var settings = WulaFallenEmpire.WulaFallenEmpireMod.settings; + if (settings == null) return; + + var client = new SimpleAIClient(settings.apiKey, settings.baseUrl, settings.model); + string extractPrompt = MemoryPrompts.BuildFactExtractionPrompt(conversation); + string extractResponse = await client.GetChatCompletionAsync(extractPrompt, new List<(string role, string message)>(), maxTokens: 256, temperature: 0.2f); + if (string.IsNullOrWhiteSpace(extractResponse)) return; + + List facts = ParseFactsResponse(extractResponse); + if (facts.Count == 0) return; + + var existing = memoryManager.GetAllMemories() + .OrderByDescending(m => m.UpdatedTicks) + .ThenByDescending(m => m.CreatedTicks) + .Take(MemoryUpdateMaxMemories) + .ToList(); + + if (existing.Count == 0) + { + AddFactsToMemory(memoryManager, facts); + return; + } + + string existingJson = BuildMemoriesJson(existing); + string factsJson = BuildFactsJson(facts); + string updatePrompt = MemoryPrompts.BuildMemoryUpdatePrompt(existingJson, factsJson); + string updateResponse = await client.GetChatCompletionAsync(updatePrompt, new List<(string role, string message)>(), maxTokens: 256, temperature: 0.2f); + if (string.IsNullOrWhiteSpace(updateResponse)) + { + AddFactsToMemory(memoryManager, facts); + return; + } + + List updates = ParseMemoryUpdateResponse(updateResponse); + if (updates.Count == 0) + { + AddFactsToMemory(memoryManager, facts); + return; + } + + ApplyMemoryUpdates(memoryManager, updates, facts); + } + catch (Exception ex) + { + WulaLog.Debug($"[WulaAI] Memory extraction failed: {ex}"); + } + finally + { + _memoryExtractionInFlight = false; + } + } + + private static void AddFactsToMemory(AIMemoryManager memoryManager, List facts) + { + if (memoryManager == null || facts == null) return; + + foreach (var fact in facts) + { + if (string.IsNullOrWhiteSpace(fact.Text)) continue; + memoryManager.AddMemory(fact.Text, fact.Category); + } + } + + private static void ApplyMemoryUpdates(AIMemoryManager memoryManager, List updates, List fallbackFacts) + { + if (memoryManager == null || updates == null) return; + + bool applied = false; + bool anyDecision = false; + foreach (var update in updates) + { + if (string.IsNullOrWhiteSpace(update.Event)) continue; + + switch (update.Event.Trim().ToUpperInvariant()) + { + case "ADD": + anyDecision = true; + if (!string.IsNullOrWhiteSpace(update.Text)) + { + memoryManager.AddMemory(update.Text, update.Category); + applied = true; + } + break; + case "UPDATE": + anyDecision = true; + if (!string.IsNullOrWhiteSpace(update.Id)) + { + if (memoryManager.UpdateMemory(update.Id, update.Text, update.Category)) + { + applied = true; + } + else if (!string.IsNullOrWhiteSpace(update.Text)) + { + memoryManager.AddMemory(update.Text, update.Category); + applied = true; + } + } + else if (!string.IsNullOrWhiteSpace(update.Text)) + { + memoryManager.AddMemory(update.Text, update.Category); + applied = true; + } + break; + case "DELETE": + anyDecision = true; + if (!string.IsNullOrWhiteSpace(update.Id) && memoryManager.DeleteMemory(update.Id)) + { + applied = true; + } + break; + case "NONE": + anyDecision = true; + break; + } + } + + if (!applied && !anyDecision) + { + AddFactsToMemory(memoryManager, fallbackFacts); + } + } + + private static string BuildMemoriesJson(List memories) + { + if (memories == null || memories.Count == 0) return "[]"; + + StringBuilder sb = new StringBuilder(); + sb.Append("["); + bool first = true; + foreach (var memory in memories) + { + if (memory == null || string.IsNullOrWhiteSpace(memory.Fact)) continue; + if (!first) sb.Append(","); + first = false; + sb.Append("{"); + sb.Append("\"id\":\"").Append(EscapeJson(memory.Id)).Append("\","); + sb.Append("\"text\":\"").Append(EscapeJson(memory.Fact)).Append("\","); + sb.Append("\"category\":\"").Append(EscapeJson(memory.Category)).Append("\""); + sb.Append("}"); + } + sb.Append("]"); + return sb.ToString(); + } + + private static string BuildFactsJson(List facts) + { + if (facts == null || facts.Count == 0) return "[]"; + + StringBuilder sb = new StringBuilder(); + sb.Append("["); + bool first = true; + foreach (var fact in facts) + { + if (string.IsNullOrWhiteSpace(fact.Text)) continue; + if (!first) sb.Append(","); + first = false; + sb.Append("{"); + sb.Append("\"text\":\"").Append(EscapeJson(fact.Text)).Append("\","); + sb.Append("\"category\":\"").Append(EscapeJson(fact.Category)).Append("\""); + sb.Append("}"); + } + sb.Append("]"); + return sb.ToString(); + } + + private static List ParseFactsResponse(string response) + { + var facts = new List(); + string array = TryExtractJsonArray(response, "facts"); + if (string.IsNullOrWhiteSpace(array)) return facts; + + var objects = ExtractJsonObjects(array); + if (objects.Count > 0) + { + foreach (string obj in objects) + { + var dict = SimpleJsonParser.Parse(obj); + if (dict == null || dict.Count == 0) continue; + + string text = GetDictionaryValue(dict, "text") ?? GetDictionaryValue(dict, "fact"); + if (string.IsNullOrWhiteSpace(text)) continue; + + string category = NormalizeMemoryCategory(GetDictionaryValue(dict, "category")); + facts.Add(new MemoryFact { Text = text.Trim(), Category = category }); + } + } + else + { + foreach (string item in SplitJsonArrayValues(array)) + { + if (string.IsNullOrWhiteSpace(item)) continue; + facts.Add(new MemoryFact { Text = item.Trim(), Category = "misc" }); + } + } + + var deduped = new List(); + var seen = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var fact in facts) + { + string hash = AIMemoryEntry.ComputeHash(fact.Text); + if (string.IsNullOrWhiteSpace(hash) || !seen.Add(hash)) continue; + deduped.Add(fact); + } + + return deduped; + } + + private static List ParseMemoryUpdateResponse(string response) + { + var updates = new List(); + string array = TryExtractJsonArray(response, "memory") ?? TryExtractJsonArray(response, "memories"); + if (string.IsNullOrWhiteSpace(array)) return updates; + + foreach (string obj in ExtractJsonObjects(array)) + { + var dict = SimpleJsonParser.Parse(obj); + if (dict == null || dict.Count == 0) continue; + + string evt = GetDictionaryValue(dict, "event") ?? GetDictionaryValue(dict, "action"); + if (string.IsNullOrWhiteSpace(evt)) continue; + + updates.Add(new MemoryUpdate + { + Event = evt.Trim().ToUpperInvariant(), + Id = GetDictionaryValue(dict, "id"), + Text = GetDictionaryValue(dict, "text") ?? GetDictionaryValue(dict, "fact"), + Category = NormalizeMemoryCategory(GetDictionaryValue(dict, "category")) + }); + } + + return updates; + } + + private static string NormalizeMemoryCategory(string category) + { + if (string.IsNullOrWhiteSpace(category)) return "misc"; + string normalized = category.Trim().ToLowerInvariant(); + switch (normalized) + { + case "preference": + case "personal": + case "plan": + case "colony": + case "misc": + return normalized; + default: + return "misc"; + } + } + + private static string TryExtractJsonArray(string json, string key) + { + if (string.IsNullOrWhiteSpace(json) || string.IsNullOrWhiteSpace(key)) return null; + + string keyPattern = $"\"{key}\""; + int keyIndex = json.IndexOf(keyPattern, StringComparison.OrdinalIgnoreCase); + if (keyIndex == -1) return null; + + int arrayStart = json.IndexOf('[', keyIndex); + if (arrayStart == -1) return null; + + int arrayEnd = FindMatchingBracket(json, arrayStart); + if (arrayEnd == -1) return null; + + return json.Substring(arrayStart + 1, arrayEnd - arrayStart - 1); + } + + private static List ExtractJsonObjects(string arrayContent) + { + var objects = new List(); + if (string.IsNullOrWhiteSpace(arrayContent)) return objects; + + int depth = 0; + int start = -1; + bool inString = false; + bool escaped = false; + + for (int i = 0; i < arrayContent.Length; i++) + { + char c = arrayContent[i]; + if (inString) + { + if (escaped) + { + escaped = false; + continue; + } + if (c == '\\') + { + escaped = true; + continue; + } + if (c == '"') + { + inString = false; + } + continue; + } + + if (c == '"') + { + inString = true; + continue; + } + + if (c == '{') + { + if (depth == 0) start = i; + depth++; + continue; + } + if (c == '}') + { + depth--; + if (depth == 0 && start >= 0) + { + objects.Add(arrayContent.Substring(start, i - start + 1)); + start = -1; + } + } + } + + return objects; + } + + private static List SplitJsonArrayValues(string arrayContent) + { + var items = new List(); + if (string.IsNullOrWhiteSpace(arrayContent)) return items; + + bool inString = false; + bool escaped = false; + int start = 0; + + for (int i = 0; i < arrayContent.Length; i++) + { + char c = arrayContent[i]; + if (inString) + { + if (escaped) + { + escaped = false; + } + else if (c == '\\') + { + escaped = true; + } + else if (c == '"') + { + inString = false; + } + continue; + } + + if (c == '"') + { + inString = true; + continue; + } + + if (c == ',') + { + string part = arrayContent.Substring(start, i - start); + items.Add(UnescapeJsonString(part.Trim().Trim('"'))); + start = i + 1; + } + } + + if (start < arrayContent.Length) + { + string part = arrayContent.Substring(start); + items.Add(UnescapeJsonString(part.Trim().Trim('"'))); + } + + return items; + } + + private static string UnescapeJsonString(string value) + { + if (string.IsNullOrEmpty(value)) return ""; + return value.Replace("\\r", "\r").Replace("\\n", "\n").Replace("\\\"", "\"").Replace("\\\\", "\\"); + } + + private static string GetDictionaryValue(Dictionary dict, string key) + { + if (dict == null || string.IsNullOrWhiteSpace(key)) return null; + return dict.TryGetValue(key, out string value) ? value : null; + } + + private static string EscapeJson(string value) + { + if (string.IsNullOrEmpty(value)) return ""; + return value.Replace("\\", "\\\\").Replace("\"", "\\\"").Replace("\n", "\\n").Replace("\r", "\\r"); + } + + private static int FindMatchingBracket(string json, int startIndex) + { + int depth = 0; + bool inString = false; + bool escaped = false; + + for (int i = startIndex; i < json.Length; i++) + { + char c = json[i]; + if (inString) + { + if (escaped) + { + escaped = false; + continue; + } + if (c == '\\') + { + escaped = true; + continue; + } + if (c == '"') + { + inString = false; + } + continue; + } + + if (c == '"') + { + inString = true; + continue; + } + + if (c == '[') + { + depth++; + continue; + } + + if (c == ']') + { + depth--; + if (depth == 0) return i; + } + } + + return -1; + } + private async Task RunPhasedRequestAsync() { if (_isThinking) return; @@ -518,6 +1081,7 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori try { + PrepareMemoryContext(); CompressHistoryIfNeeded(); var settings = WulaFallenEmpireMod.settings; @@ -801,6 +1365,7 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori } ParseResponse(reply); + _ = ExtractAndUpdateMemoriesAsync(); } catch (Exception ex) { @@ -1517,6 +2082,8 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori _inputText = ""; _history.Clear(); + _memoryContextCache = ""; + _memoryContextQuery = ""; try { var historyManager = Find.World?.GetComponent();