feat(ai): implement AI memory system with context extraction and storage

Add comprehensive AI memory functionality including:
- Memory context preparation and caching
- Fact extraction from conversations
- Memory search and retrieval with category support
- JSON-based memory update operations (add/update/delete)
- Memory deduplication and normalization
- Integration with AI conversation flow

The system extracts relevant facts from conversations, stores them with
categories, and provides contextual memory to enhance AI responses.
This commit is contained in:
2025-12-27 12:41:39 +08:00
parent e0eb790346
commit 971675f6ea
7 changed files with 1169 additions and 3 deletions

View File

@@ -7,6 +7,7 @@ using System.Threading.Tasks;
using RimWorld;
using UnityEngine;
using Verse;
using WulaFallenEmpire.EventSystem.AI;
using WulaFallenEmpire.EventSystem.AI.Tools;
using System.Text.RegularExpressions;
@@ -44,6 +45,13 @@ namespace WulaFallenEmpire.EventSystem.AI.UI
private const int DefaultMaxHistoryTokens = 100000;
private const int CharsPerToken = 4;
private const int ThinkingPhaseTotal = 3;
private const int MemorySearchLimit = 6;
private const int MemoryFactMaxChars = 200;
private const int MemoryPromptMaxChars = 1600;
private const int MemoryUpdateMaxMemories = 40;
private string _memoryContextCache = "";
private string _memoryContextQuery = "";
private bool _memoryExtractionInFlight = false;
private enum RequestPhase
{
@@ -59,6 +67,20 @@ namespace WulaFallenEmpire.EventSystem.AI.UI
public bool AnyActionError;
}
private struct MemoryFact
{
public string Text;
public string Category;
}
private struct MemoryUpdate
{
public string Event;
public string Id;
public string Text;
public string Category;
}
private void SetThinkingPhase(int phaseIndex, bool isRetry)
{
_thinkingPhaseIndex = Math.Max(1, Math.Min(ThinkingPhaseTotal, phaseIndex));
@@ -233,10 +255,15 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori
{
// Use XML persona if available, otherwise default
string persona = !string.IsNullOrEmpty(def.aiSystemInstruction) ? def.aiSystemInstruction : DefaultPersona;
string memoryContext = _memoryContextCache;
string personaWithMemory = string.IsNullOrWhiteSpace(memoryContext)
? persona
: persona + "\n\n" + memoryContext;
string fullInstruction = toolsEnabled
? (persona + "\n" + ToolRulesInstruction + "\n" + toolsForThisPhase)
: persona;
? (personaWithMemory + "\n" + ToolRulesInstruction + "\n" + toolsForThisPhase)
: personaWithMemory;
string language = LanguageDatabase.activeLanguage.FriendlyNameNative;
var eventVarManager = Find.World.GetComponent<EventVariableManager>();
@@ -494,6 +521,542 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori
return context;
}
private void PrepareMemoryContext()
{
string lastUserMessage = _history.LastOrDefault(entry => string.Equals(entry.role, "user", StringComparison.OrdinalIgnoreCase)).message;
if (string.IsNullOrWhiteSpace(lastUserMessage))
{
_memoryContextCache = "";
_memoryContextQuery = "";
return;
}
_memoryContextQuery = lastUserMessage;
_memoryContextCache = BuildMemoryContext(lastUserMessage);
}
private string BuildMemoryContext(string userQuery)
{
if (string.IsNullOrWhiteSpace(userQuery)) return "";
var memoryManager = Find.World?.GetComponent<AIMemoryManager>();
if (memoryManager == null) return "";
var memories = memoryManager.SearchMemories(userQuery, MemorySearchLimit);
if (memories == null || memories.Count == 0) return "";
StringBuilder sb = new StringBuilder();
sb.AppendLine("# MEMORY (Relevant Facts)");
foreach (var memory in memories)
{
if (memory == null || string.IsNullOrWhiteSpace(memory.Fact)) continue;
string category = string.IsNullOrWhiteSpace(memory.Category) ? "misc" : memory.Category.Trim();
string fact = TrimMemoryFact(memory.Fact, MemoryFactMaxChars);
if (string.IsNullOrWhiteSpace(fact)) continue;
sb.AppendLine($"- [{category}] {fact}");
}
return sb.ToString().TrimEnd();
}
private static string TrimMemoryFact(string fact, int maxChars)
{
if (string.IsNullOrWhiteSpace(fact)) return "";
string trimmed = fact.Trim();
if (trimmed.Length <= maxChars) return trimmed;
return trimmed.Substring(0, maxChars) + "...";
}
private string BuildMemoryExtractionConversation()
{
if (_history == null || _history.Count == 0) return "";
const int maxTurns = 6;
var lines = new List<string>();
for (int i = _history.Count - 1; i >= 0 && lines.Count < maxTurns; i--)
{
var entry = _history[i];
if (!string.Equals(entry.role, "user", StringComparison.OrdinalIgnoreCase) &&
!string.Equals(entry.role, "assistant", StringComparison.OrdinalIgnoreCase))
{
continue;
}
string content = entry.role == "assistant"
? ParseResponseForDisplay(entry.message)
: entry.message;
if (string.IsNullOrWhiteSpace(content)) continue;
lines.Add($"{entry.role}: {content.Trim()}");
}
lines.Reverse();
string snippet = string.Join("\n", lines);
return TrimForPrompt(snippet, MemoryPromptMaxChars);
}
private async Task ExtractAndUpdateMemoriesAsync()
{
if (_memoryExtractionInFlight) return;
_memoryExtractionInFlight = true;
try
{
var world = Find.World;
if (world == null) return;
var memoryManager = world.GetComponent<AIMemoryManager>();
if (memoryManager == null) return;
string conversation = BuildMemoryExtractionConversation();
if (string.IsNullOrWhiteSpace(conversation)) return;
var settings = WulaFallenEmpire.WulaFallenEmpireMod.settings;
if (settings == null) return;
var client = new SimpleAIClient(settings.apiKey, settings.baseUrl, settings.model);
string extractPrompt = MemoryPrompts.BuildFactExtractionPrompt(conversation);
string extractResponse = await client.GetChatCompletionAsync(extractPrompt, new List<(string role, string message)>(), maxTokens: 256, temperature: 0.2f);
if (string.IsNullOrWhiteSpace(extractResponse)) return;
List<MemoryFact> facts = ParseFactsResponse(extractResponse);
if (facts.Count == 0) return;
var existing = memoryManager.GetAllMemories()
.OrderByDescending(m => m.UpdatedTicks)
.ThenByDescending(m => m.CreatedTicks)
.Take(MemoryUpdateMaxMemories)
.ToList();
if (existing.Count == 0)
{
AddFactsToMemory(memoryManager, facts);
return;
}
string existingJson = BuildMemoriesJson(existing);
string factsJson = BuildFactsJson(facts);
string updatePrompt = MemoryPrompts.BuildMemoryUpdatePrompt(existingJson, factsJson);
string updateResponse = await client.GetChatCompletionAsync(updatePrompt, new List<(string role, string message)>(), maxTokens: 256, temperature: 0.2f);
if (string.IsNullOrWhiteSpace(updateResponse))
{
AddFactsToMemory(memoryManager, facts);
return;
}
List<MemoryUpdate> updates = ParseMemoryUpdateResponse(updateResponse);
if (updates.Count == 0)
{
AddFactsToMemory(memoryManager, facts);
return;
}
ApplyMemoryUpdates(memoryManager, updates, facts);
}
catch (Exception ex)
{
WulaLog.Debug($"[WulaAI] Memory extraction failed: {ex}");
}
finally
{
_memoryExtractionInFlight = false;
}
}
private static void AddFactsToMemory(AIMemoryManager memoryManager, List<MemoryFact> facts)
{
if (memoryManager == null || facts == null) return;
foreach (var fact in facts)
{
if (string.IsNullOrWhiteSpace(fact.Text)) continue;
memoryManager.AddMemory(fact.Text, fact.Category);
}
}
private static void ApplyMemoryUpdates(AIMemoryManager memoryManager, List<MemoryUpdate> updates, List<MemoryFact> fallbackFacts)
{
if (memoryManager == null || updates == null) return;
bool applied = false;
bool anyDecision = false;
foreach (var update in updates)
{
if (string.IsNullOrWhiteSpace(update.Event)) continue;
switch (update.Event.Trim().ToUpperInvariant())
{
case "ADD":
anyDecision = true;
if (!string.IsNullOrWhiteSpace(update.Text))
{
memoryManager.AddMemory(update.Text, update.Category);
applied = true;
}
break;
case "UPDATE":
anyDecision = true;
if (!string.IsNullOrWhiteSpace(update.Id))
{
if (memoryManager.UpdateMemory(update.Id, update.Text, update.Category))
{
applied = true;
}
else if (!string.IsNullOrWhiteSpace(update.Text))
{
memoryManager.AddMemory(update.Text, update.Category);
applied = true;
}
}
else if (!string.IsNullOrWhiteSpace(update.Text))
{
memoryManager.AddMemory(update.Text, update.Category);
applied = true;
}
break;
case "DELETE":
anyDecision = true;
if (!string.IsNullOrWhiteSpace(update.Id) && memoryManager.DeleteMemory(update.Id))
{
applied = true;
}
break;
case "NONE":
anyDecision = true;
break;
}
}
if (!applied && !anyDecision)
{
AddFactsToMemory(memoryManager, fallbackFacts);
}
}
private static string BuildMemoriesJson(List<AIMemoryEntry> memories)
{
if (memories == null || memories.Count == 0) return "[]";
StringBuilder sb = new StringBuilder();
sb.Append("[");
bool first = true;
foreach (var memory in memories)
{
if (memory == null || string.IsNullOrWhiteSpace(memory.Fact)) continue;
if (!first) sb.Append(",");
first = false;
sb.Append("{");
sb.Append("\"id\":\"").Append(EscapeJson(memory.Id)).Append("\",");
sb.Append("\"text\":\"").Append(EscapeJson(memory.Fact)).Append("\",");
sb.Append("\"category\":\"").Append(EscapeJson(memory.Category)).Append("\"");
sb.Append("}");
}
sb.Append("]");
return sb.ToString();
}
private static string BuildFactsJson(List<MemoryFact> facts)
{
if (facts == null || facts.Count == 0) return "[]";
StringBuilder sb = new StringBuilder();
sb.Append("[");
bool first = true;
foreach (var fact in facts)
{
if (string.IsNullOrWhiteSpace(fact.Text)) continue;
if (!first) sb.Append(",");
first = false;
sb.Append("{");
sb.Append("\"text\":\"").Append(EscapeJson(fact.Text)).Append("\",");
sb.Append("\"category\":\"").Append(EscapeJson(fact.Category)).Append("\"");
sb.Append("}");
}
sb.Append("]");
return sb.ToString();
}
private static List<MemoryFact> ParseFactsResponse(string response)
{
var facts = new List<MemoryFact>();
string array = TryExtractJsonArray(response, "facts");
if (string.IsNullOrWhiteSpace(array)) return facts;
var objects = ExtractJsonObjects(array);
if (objects.Count > 0)
{
foreach (string obj in objects)
{
var dict = SimpleJsonParser.Parse(obj);
if (dict == null || dict.Count == 0) continue;
string text = GetDictionaryValue(dict, "text") ?? GetDictionaryValue(dict, "fact");
if (string.IsNullOrWhiteSpace(text)) continue;
string category = NormalizeMemoryCategory(GetDictionaryValue(dict, "category"));
facts.Add(new MemoryFact { Text = text.Trim(), Category = category });
}
}
else
{
foreach (string item in SplitJsonArrayValues(array))
{
if (string.IsNullOrWhiteSpace(item)) continue;
facts.Add(new MemoryFact { Text = item.Trim(), Category = "misc" });
}
}
var deduped = new List<MemoryFact>();
var seen = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
foreach (var fact in facts)
{
string hash = AIMemoryEntry.ComputeHash(fact.Text);
if (string.IsNullOrWhiteSpace(hash) || !seen.Add(hash)) continue;
deduped.Add(fact);
}
return deduped;
}
private static List<MemoryUpdate> ParseMemoryUpdateResponse(string response)
{
var updates = new List<MemoryUpdate>();
string array = TryExtractJsonArray(response, "memory") ?? TryExtractJsonArray(response, "memories");
if (string.IsNullOrWhiteSpace(array)) return updates;
foreach (string obj in ExtractJsonObjects(array))
{
var dict = SimpleJsonParser.Parse(obj);
if (dict == null || dict.Count == 0) continue;
string evt = GetDictionaryValue(dict, "event") ?? GetDictionaryValue(dict, "action");
if (string.IsNullOrWhiteSpace(evt)) continue;
updates.Add(new MemoryUpdate
{
Event = evt.Trim().ToUpperInvariant(),
Id = GetDictionaryValue(dict, "id"),
Text = GetDictionaryValue(dict, "text") ?? GetDictionaryValue(dict, "fact"),
Category = NormalizeMemoryCategory(GetDictionaryValue(dict, "category"))
});
}
return updates;
}
private static string NormalizeMemoryCategory(string category)
{
if (string.IsNullOrWhiteSpace(category)) return "misc";
string normalized = category.Trim().ToLowerInvariant();
switch (normalized)
{
case "preference":
case "personal":
case "plan":
case "colony":
case "misc":
return normalized;
default:
return "misc";
}
}
private static string TryExtractJsonArray(string json, string key)
{
if (string.IsNullOrWhiteSpace(json) || string.IsNullOrWhiteSpace(key)) return null;
string keyPattern = $"\"{key}\"";
int keyIndex = json.IndexOf(keyPattern, StringComparison.OrdinalIgnoreCase);
if (keyIndex == -1) return null;
int arrayStart = json.IndexOf('[', keyIndex);
if (arrayStart == -1) return null;
int arrayEnd = FindMatchingBracket(json, arrayStart);
if (arrayEnd == -1) return null;
return json.Substring(arrayStart + 1, arrayEnd - arrayStart - 1);
}
private static List<string> ExtractJsonObjects(string arrayContent)
{
var objects = new List<string>();
if (string.IsNullOrWhiteSpace(arrayContent)) return objects;
int depth = 0;
int start = -1;
bool inString = false;
bool escaped = false;
for (int i = 0; i < arrayContent.Length; i++)
{
char c = arrayContent[i];
if (inString)
{
if (escaped)
{
escaped = false;
continue;
}
if (c == '\\')
{
escaped = true;
continue;
}
if (c == '"')
{
inString = false;
}
continue;
}
if (c == '"')
{
inString = true;
continue;
}
if (c == '{')
{
if (depth == 0) start = i;
depth++;
continue;
}
if (c == '}')
{
depth--;
if (depth == 0 && start >= 0)
{
objects.Add(arrayContent.Substring(start, i - start + 1));
start = -1;
}
}
}
return objects;
}
private static List<string> SplitJsonArrayValues(string arrayContent)
{
var items = new List<string>();
if (string.IsNullOrWhiteSpace(arrayContent)) return items;
bool inString = false;
bool escaped = false;
int start = 0;
for (int i = 0; i < arrayContent.Length; i++)
{
char c = arrayContent[i];
if (inString)
{
if (escaped)
{
escaped = false;
}
else if (c == '\\')
{
escaped = true;
}
else if (c == '"')
{
inString = false;
}
continue;
}
if (c == '"')
{
inString = true;
continue;
}
if (c == ',')
{
string part = arrayContent.Substring(start, i - start);
items.Add(UnescapeJsonString(part.Trim().Trim('"')));
start = i + 1;
}
}
if (start < arrayContent.Length)
{
string part = arrayContent.Substring(start);
items.Add(UnescapeJsonString(part.Trim().Trim('"')));
}
return items;
}
private static string UnescapeJsonString(string value)
{
if (string.IsNullOrEmpty(value)) return "";
return value.Replace("\\r", "\r").Replace("\\n", "\n").Replace("\\\"", "\"").Replace("\\\\", "\\");
}
private static string GetDictionaryValue(Dictionary<string, string> dict, string key)
{
if (dict == null || string.IsNullOrWhiteSpace(key)) return null;
return dict.TryGetValue(key, out string value) ? value : null;
}
private static string EscapeJson(string value)
{
if (string.IsNullOrEmpty(value)) return "";
return value.Replace("\\", "\\\\").Replace("\"", "\\\"").Replace("\n", "\\n").Replace("\r", "\\r");
}
private static int FindMatchingBracket(string json, int startIndex)
{
int depth = 0;
bool inString = false;
bool escaped = false;
for (int i = startIndex; i < json.Length; i++)
{
char c = json[i];
if (inString)
{
if (escaped)
{
escaped = false;
continue;
}
if (c == '\\')
{
escaped = true;
continue;
}
if (c == '"')
{
inString = false;
}
continue;
}
if (c == '"')
{
inString = true;
continue;
}
if (c == '[')
{
depth++;
continue;
}
if (c == ']')
{
depth--;
if (depth == 0) return i;
}
}
return -1;
}
private async Task RunPhasedRequestAsync()
{
if (_isThinking) return;
@@ -518,6 +1081,7 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori
try
{
PrepareMemoryContext();
CompressHistoryIfNeeded();
var settings = WulaFallenEmpireMod.settings;
@@ -801,6 +1365,7 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori
}
ParseResponse(reply);
_ = ExtractAndUpdateMemoriesAsync();
}
catch (Exception ex)
{
@@ -1517,6 +2082,8 @@ You are 'The Legion', a super AI of the Wula Empire. Your personality is authori
_inputText = "";
_history.Clear();
_memoryContextCache = "";
_memoryContextQuery = "";
try
{
var historyManager = Find.World?.GetComponent<AIHistoryManager>();