using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using UnityEngine.Networking;
using UnityEngine;
using Verse;
using System.Linq;

namespace WulaFallenEmpire.EventSystem.AI
{
    /// <summary>
    /// Lightweight chat-completion client. Sends one non-streaming request either to an
    /// OpenAI-compatible <c>chat/completions</c> endpoint or to the Gemini <c>generateContent</c>
    /// endpoint, and returns the text of the first reply.
    /// </summary>
    /// <remarks>
    /// Request JSON is built by hand and the response is scanned by hand (no JSON library), so
    /// only the few fields this client needs are read from the response.
    /// </remarks>
    public class SimpleAIClient
    {
        private readonly string _apiKey;
        private readonly string _baseUrl;
        private readonly string _model;
        private readonly bool _useGemini;

        // Longest error-response body (in chars) copied into the log.
        private const int MaxLogChars = 2000;

        // Gemini output-token cap used when the caller passes no maxTokens.
        private const int DefaultGeminiMaxOutputTokens = 2048;

        // Gemini API root used when the configured base URL is not a googleapis.com URL.
        private const string DefaultGeminiBaseUrl = "https://generativelanguage.googleapis.com/v1beta";

        // UnityWebRequest timeout, in seconds.
        private const int RequestTimeoutSeconds = 60;

        // How often the pending request is polled for completion, in milliseconds.
        private const int PollIntervalMs = 50;

        /// <summary>Creates a client for one provider/model.</summary>
        /// <param name="apiKey">API key. OpenAI mode sends it as a Bearer token; Gemini mode sends it as the <c>key</c> query parameter.</param>
        /// <param name="baseUrl">API root URL. Trailing slashes are removed.</param>
        /// <param name="model">Model identifier sent to the provider.</param>
        /// <param name="useGemini">True to use the Gemini request/response format instead of the OpenAI one.</param>
        public SimpleAIClient(string apiKey, string baseUrl, string model, bool useGemini = false)
        {
            _apiKey = apiKey;
            _baseUrl = baseUrl?.TrimEnd('/');
            _model = model;
            _useGemini = useGemini;
        }

        /// <summary>
        /// Sends a chat conversation and returns the model's reply text.
        /// </summary>
        /// <param name="instruction">Optional system prompt; omitted when null or empty.</param>
        /// <param name="messages">
        /// Conversation as (role, message) pairs; a null list is treated as empty. In OpenAI mode
        /// "ai"/"assistant" map to <c>assistant</c>, "tool" maps to <c>system</c>, "toolcall" entries
        /// are skipped, and any other role is sent lower-cased.
        /// </param>
        /// <param name="maxTokens">Optional output-token limit; values below 1 are clamped to 1.</param>
        /// <param name="temperature">Optional sampling temperature.</param>
        /// <param name="base64Image">Optional base64 PNG, attached to the last message only when that message has role "user".</param>
        /// <returns>
        /// The extracted reply text, or null when the base URL is missing (OpenAI mode), the request
        /// fails, or no string reply could be found in the response.
        /// </returns>
        public async Task<string> GetChatCompletionAsync(string instruction, List<(string role, string message)> messages, int? maxTokens = null, float? temperature = null, string base64Image = null)
        {
            if (_useGemini)
            {
                return await GetGeminiCompletionAsync(instruction, messages, maxTokens, temperature, base64Image);
            }

            // OpenAI / compatible mode
            if (string.IsNullOrEmpty(_baseUrl))
            {
                WulaLog.Debug("[WulaAI] Base URL is missing.");
                return null;
            }

            string endpoint = BuildOpenAIEndpoint(_baseUrl);

            var json = new StringBuilder();
            json.Append("{");
            json.Append($"\"model\": \"{EscapeJson(_model)}\",");
            json.Append("\"stream\": false,");
            if (maxTokens.HasValue)
            {
                json.Append($"\"max_tokens\": {FormatInt(Math.Max(1, maxTokens.Value))},");
            }
            if (temperature.HasValue)
            {
                json.Append($"\"temperature\": {FormatTemperature(temperature.Value)},");
            }
            json.Append("\"messages\": [");

            // Separators are written *before* each element so that skipped entries (toolcall) or an
            // empty message list can never leave a trailing comma and produce invalid JSON.
            bool first = true;
            if (!string.IsNullOrEmpty(instruction))
            {
                json.Append($"{{\"role\": \"system\", \"content\": \"{EscapeJson(instruction)}\"}}");
                first = false;
            }

            int count = messages?.Count ?? 0;
            for (int i = 0; i < count; i++)
            {
                var msg = messages[i];
                string role = NormalizeOpenAIRole(msg.role);
                if (role == null)
                {
                    continue; // "toolcall" entries are not sent in OpenAI mode.
                }

                if (!first)
                {
                    json.Append(",");
                }
                first = false;

                json.Append($"{{\"role\": \"{EscapeJson(role)}\", ");
                bool attachImage = i == count - 1 && role == "user" && !string.IsNullOrEmpty(base64Image);
                if (attachImage)
                {
                    // Multimodal content array: the text part followed by the image as a PNG data URL.
                    json.Append("\"content\": [");
                    json.Append($"{{\"type\": \"text\", \"text\": \"{EscapeJson(msg.message)}\"}},");
                    json.Append($"{{\"type\": \"image_url\", \"image_url\": {{\"url\": \"data:image/png;base64,{base64Image}\"}}}}");
                    json.Append("]");
                }
                else
                {
                    json.Append($"\"content\": \"{EscapeJson(msg.message)}\"");
                }
                json.Append("}");
            }
            json.Append("]}");

            return await SendRequestAsync(endpoint, json.ToString(), _apiKey);
        }

        /// <summary>
        /// Gemini variant of <see cref="GetChatCompletionAsync"/>. The API key goes in the URL query
        /// string, so no Authorization header is sent.
        /// </summary>
        private async Task<string> GetGeminiCompletionAsync(string instruction, List<(string role, string message)> messages, int? maxTokens = null, float? temperature = null, string base64Image = null)
        {
            // Only honour the configured base URL if it points at Google; otherwise use the public endpoint.
            string baseUrl = _baseUrl;
            if (string.IsNullOrEmpty(baseUrl) || !baseUrl.Contains("googleapis.com"))
            {
                baseUrl = DefaultGeminiBaseUrl;
            }
            string endpoint = $"{baseUrl}/models/{_model}:generateContent?key={Uri.EscapeDataString(_apiKey ?? string.Empty)}";

            var json = new StringBuilder();
            json.Append("{");
            if (!string.IsNullOrEmpty(instruction))
            {
                json.Append("\"system_instruction\": {\"parts\": [{\"text\": \"" + EscapeJson(instruction) + "\"}]},");
            }

            json.Append("\"contents\": [");
            int count = messages?.Count ?? 0;
            for (int i = 0; i < count; i++)
            {
                var msg = messages[i];
                string lowered = (msg.role ?? "user").ToLowerInvariant();
                // Gemini only knows "user" and "model".
                // NOTE(review): unlike OpenAI mode, "toolcall" entries are sent here as "user" - confirm this is intended.
                string role = (lowered == "assistant" || lowered == "ai") ? "model" : "user";

                if (i > 0)
                {
                    json.Append(",");
                }
                json.Append($"{{\"role\": \"{role}\", \"parts\": [");
                json.Append($"{{\"text\": \"{EscapeJson(msg.message)}\"}}");
                if (i == count - 1 && role == "user" && !string.IsNullOrEmpty(base64Image))
                {
                    json.Append($", {{\"inline_data\": {{\"mime_type\": \"image/png\", \"data\": \"{base64Image}\"}}}}");
                }
                json.Append("]}");
            }
            json.Append("],");

            json.Append("\"generationConfig\": {");
            if (temperature.HasValue)
            {
                json.Append($"\"temperature\": {FormatTemperature(temperature.Value)},");
            }
            int outputTokens = maxTokens.HasValue ? Math.Max(1, maxTokens.Value) : DefaultGeminiMaxOutputTokens;
            json.Append($"\"maxOutputTokens\": {FormatInt(outputTokens)}");
            json.Append("}");
            json.Append("}");

            return await SendRequestAsync(endpoint, json.ToString(), null);
        }

        /// <summary>
        /// POSTs <paramref name="jsonBody"/> to <paramref name="endpoint"/> and extracts the reply text.
        /// </summary>
        /// <param name="apiKey">When non-empty, sent as <c>Authorization: Bearer</c>.</param>
        /// <returns>The extracted reply, or null on a non-success result (which is logged).</returns>
        private async Task<string> SendRequestAsync(string endpoint, string jsonBody, string apiKey)
        {
            if (Prefs.DevMode)
            {
                // The Gemini endpoint carries the API key in its query string; never write it to the log.
                WulaLog.Debug($"[WulaAI] Sending request to {RedactApiKeyForLog(endpoint)}");
            }

            using (UnityWebRequest request = new UnityWebRequest(endpoint, "POST"))
            {
                byte[] bodyRaw = Encoding.UTF8.GetBytes(jsonBody);
                request.uploadHandler = new UploadHandlerRaw(bodyRaw);
                request.downloadHandler = new DownloadHandlerBuffer();
                request.SetRequestHeader("Content-Type", "application/json");
                if (!string.IsNullOrEmpty(apiKey))
                {
                    request.SetRequestHeader("Authorization", $"Bearer {apiKey}");
                }
                request.timeout = RequestTimeoutSeconds;

                var operation = request.SendWebRequest();
                // Poll instead of blocking so the awaiting caller is not stalled.
                while (!operation.isDone)
                {
                    await Task.Delay(PollIntervalMs);
                }

                if (request.result != UnityWebRequest.Result.Success)
                {
                    string errText = request.downloadHandler.text;
                    WulaLog.Debug($"[WulaAI] API Error ({request.responseCode}): {request.error}\nResponse: {TruncateForLog(errText)}");
                    return null;
                }

                return ExtractContent(request.downloadHandler.text);
            }
        }

        /// <summary>
        /// Pulls the reply text out of a Gemini or OpenAI-style response body.
        /// Returns null when no string reply is present or parsing throws (the exception is logged).
        /// </summary>
        private string ExtractContent(string json)
        {
            if (string.IsNullOrWhiteSpace(json))
            {
                return null;
            }

            try
            {
                // 1. Gemini format: candidates[].content.parts[].text
                if (json.Contains("\"candidates\""))
                {
                    int partsIndex = json.IndexOf("\"parts\"", StringComparison.Ordinal);
                    if (partsIndex != -1)
                    {
                        return ExtractJsonValue(json, "text", partsIndex);
                    }
                }

                // 2. OpenAI format: choices[0].message.content / choices[0].delta.content / choices[0].text
                if (json.Contains("\"choices\""))
                {
                    int choicesIndex = json.IndexOf("\"choices\"", StringComparison.Ordinal);
                    string firstChoice = TryExtractFirstChoiceObject(json, choicesIndex);
                    if (!string.IsNullOrEmpty(firstChoice))
                    {
                        int messageIndex = firstChoice.IndexOf("\"message\"", StringComparison.Ordinal);
                        if (messageIndex != -1)
                        {
                            return ExtractJsonValue(firstChoice, "content", messageIndex);
                        }

                        int deltaIndex = firstChoice.IndexOf("\"delta\"", StringComparison.Ordinal);
                        if (deltaIndex != -1)
                        {
                            return ExtractJsonValue(firstChoice, "content", deltaIndex);
                        }

                        return ExtractJsonValue(firstChoice, "text", 0);
                    }
                }

                // 3. Last fallback: any top-level-looking "content" string.
                return ExtractJsonValue(json, "content");
            }
            catch (Exception ex)
            {
                WulaLog.Debug($"[WulaAI] Error parsing response: {ex.Message}");
            }
            return null;
        }

        /// <summary>
        /// Returns the text of the first <c>{...}</c> object after the <c>[</c> following
        /// <paramref name="choicesKeyIndex"/>, or null if none is found or it is unterminated.
        /// </summary>
        private static string TryExtractFirstChoiceObject(string json, int choicesKeyIndex)
        {
            int arrayStart = json.IndexOf('[', choicesKeyIndex);
            if (arrayStart == -1)
            {
                return null;
            }

            int objStart = json.IndexOf('{', arrayStart);
            if (objStart == -1)
            {
                return null;
            }

            int objEnd = FindMatchingBrace(json, objStart);
            if (objEnd == -1)
            {
                return null;
            }

            return json.Substring(objStart, objEnd - objStart + 1);
        }

        /// <summary>
        /// Finds the index of the <c>}</c> closing the object that opens at or after
        /// <paramref name="startIndex"/>, ignoring braces inside string literals. Returns -1 if unbalanced.
        /// </summary>
        private static int FindMatchingBrace(string json, int startIndex)
        {
            int depth = 0;
            bool inString = false;
            bool escaped = false;
            for (int i = startIndex; i < json.Length; i++)
            {
                char c = json[i];
                if (inString)
                {
                    if (escaped)
                    {
                        escaped = false;
                        continue;
                    }
                    if (c == '\\')
                    {
                        escaped = true;
                        continue;
                    }
                    if (c == '"')
                    {
                        inString = false;
                    }
                    continue;
                }

                if (c == '"')
                {
                    inString = true;
                    continue;
                }
                if (c == '{')
                {
                    depth++;
                }
                if (c == '}')
                {
                    depth--;
                    if (depth == 0)
                    {
                        return i;
                    }
                }
            }
            return -1;
        }

        /// <summary>
        /// Finds the first occurrence of <c>"key":</c> at or after <paramref name="startIndex"/> and
        /// returns its unescaped string value.
        /// </summary>
        /// <returns>
        /// The decoded string, or null if the key is absent, its value is not a string literal
        /// (e.g. <c>null</c>, a number, an object), or the literal is unterminated.
        /// </returns>
        private static string ExtractJsonValue(string json, string key, int startIndex = 0)
        {
            string keyPattern = $"\"{key}\"";
            int searchFrom = startIndex;
            int valueStart;
            while (true)
            {
                int keyIndex = json.IndexOf(keyPattern, searchFrom, StringComparison.Ordinal);
                if (keyIndex == -1)
                {
                    return null;
                }

                // Only a match followed by ':' is a key; the same text may also appear as a value
                // (e.g. "type": "text"), in which case keep searching.
                int afterKey = SkipWhitespace(json, keyIndex + keyPattern.Length);
                if (afterKey < json.Length && json[afterKey] == ':')
                {
                    valueStart = SkipWhitespace(json, afterKey + 1);
                    break;
                }
                searchFrom = keyIndex + 1;
            }

            // Non-string values (null, numbers, objects, arrays) carry no reply text. Previously the
            // scan jumped to the next quote anywhere, returning e.g. the following key's name.
            if (valueStart >= json.Length || json[valueStart] != '"')
            {
                return null;
            }

            var sb = new StringBuilder();
            for (int i = valueStart + 1; i < json.Length; i++)
            {
                char c = json[i];
                if (c == '"')
                {
                    return sb.ToString();
                }
                if (c != '\\')
                {
                    sb.Append(c);
                    continue;
                }

                // Escape sequence (RFC 8259 section 7).
                i++;
                if (i >= json.Length)
                {
                    break;
                }
                char esc = json[i];
                switch (esc)
                {
                    case 'n': sb.Append('\n'); break;
                    case 'r': sb.Append('\r'); break;
                    case 't': sb.Append('\t'); break;
                    case 'b': sb.Append('\b'); break;
                    case 'f': sb.Append('\f'); break;
                    case 'u':
                        // \uXXXX; surrogate pairs arrive as two escapes and combine naturally in the builder.
                        if (i + 4 < json.Length &&
                            int.TryParse(json.Substring(i + 1, 4), System.Globalization.NumberStyles.AllowHexSpecifier, System.Globalization.CultureInfo.InvariantCulture, out int code))
                        {
                            sb.Append((char)code);
                            i += 4;
                        }
                        else
                        {
                            sb.Append('u');
                        }
                        break;
                    default:
                        // Covers \" \\ \/ (and passes any unknown escape through verbatim).
                        sb.Append(esc);
                        break;
                }
            }
            return null;
        }

        /// <summary>Returns the first index at or after <paramref name="index"/> that is not whitespace.</summary>
        private static int SkipWhitespace(string s, int index)
        {
            while (index < s.Length && char.IsWhiteSpace(s[index]))
            {
                index++;
            }
            return index;
        }

        /// <summary>
        /// Escapes <paramref name="s"/> for use inside a JSON string literal (quotes, backslashes and
        /// control characters). Null becomes the empty string.
        /// </summary>
        private static string EscapeJson(string s)
        {
            if (s == null)
            {
                return "";
            }

            var sb = new StringBuilder(s.Length + 16);
            foreach (char c in s)
            {
                switch (c)
                {
                    case '\\': sb.Append("\\\\"); break;
                    case '"': sb.Append("\\\""); break;
                    case '\n': sb.Append("\\n"); break;
                    case '\r': sb.Append("\\r"); break;
                    case '\t': sb.Append("\\t"); break;
                    default:
                        if (c < 0x20)
                        {
                            sb.Append("\\u");
                            sb.Append(((int)c).ToString("x4"));
                        }
                        else
                        {
                            sb.Append(c);
                        }
                        break;
                }
            }
            return sb.ToString();
        }

        /// <summary>
        /// Resolves the chat-completions URL: a base already ending in "/chat/completions" is used
        /// as-is, a base ending in "/v1" gets "/chat/completions", anything else gets "/v1/chat/completions".
        /// </summary>
        private static string BuildOpenAIEndpoint(string baseUrl)
        {
            if (baseUrl.EndsWith("/chat/completions", StringComparison.Ordinal))
            {
                return baseUrl;
            }
            if (baseUrl.EndsWith("/v1", StringComparison.Ordinal))
            {
                return $"{baseUrl}/chat/completions";
            }
            return $"{baseUrl}/v1/chat/completions";
        }

        /// <summary>
        /// Maps a caller role to an OpenAI role; returns null for roles that must not be sent ("toolcall").
        /// </summary>
        private static string NormalizeOpenAIRole(string role)
        {
            string lowered = (role ?? "user").ToLowerInvariant();
            switch (lowered)
            {
                case "ai":
                case "assistant":
                    return "assistant";
                case "tool":
                    return "system";
                case "toolcall":
                    return null;
                default:
                    return lowered;
            }
        }

        /// <summary>Formats a temperature with up to three decimals, culture-independently (JSON needs '.').</summary>
        private static string FormatTemperature(float value)
        {
            return value.ToString("0.###", System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>Formats an integer culture-independently for JSON.</summary>
        private static string FormatInt(int value)
        {
            return value.ToString(System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>Replaces the value of any <c>key=</c> query parameter with "***" so secrets stay out of logs.</summary>
        private static string RedactApiKeyForLog(string url)
        {
            if (string.IsNullOrEmpty(url))
            {
                return url;
            }

            foreach (string marker in new[] { "?key=", "&key=" })
            {
                int idx = url.IndexOf(marker, StringComparison.Ordinal);
                if (idx == -1)
                {
                    continue;
                }
                int valueStart = idx + marker.Length;
                int valueEnd = url.IndexOf('&', valueStart);
                if (valueEnd == -1)
                {
                    valueEnd = url.Length;
                }
                url = url.Substring(0, valueStart) + "***" + url.Substring(valueEnd);
            }
            return url;
        }

        /// <summary>Truncates <paramref name="s"/> to <see cref="MaxLogChars"/> characters for logging.</summary>
        private static string TruncateForLog(string s)
        {
            if (string.IsNullOrEmpty(s))
            {
                return s;
            }
            if (s.Length <= MaxLogChars)
            {
                return s;
            }
            return s.Substring(0, MaxLogChars) + "... (truncated)";
        }
    }
}