新增原生工具调用数据结构与解析:SimpleAIClient.cs

AITool 增加 Schema 构造器与函数定义生成,所有工具补齐 GetParametersSchema():AITool.cs 与 *.cs
This commit is contained in:
2025-12-31 15:44:57 +08:00
parent 244ba3d354
commit 1e64302d21
21 changed files with 936 additions and 4 deletions

View File

@@ -7,9 +7,51 @@ using UnityEngine;
using Verse;
using System.Linq;
using System.Text.RegularExpressions;
using WulaFallenEmpire.EventSystem.AI.Utils;
namespace WulaFallenEmpire.EventSystem.AI
{
public sealed class ToolCallRequest
{
public string Id;
public string Name;
public string ArgumentsJson;
}
public sealed class ChatMessage
{
public string Role;
public string Content;
public string ToolCallId;
public List<ToolCallRequest> ToolCalls;
public static ChatMessage User(string content)
{
return new ChatMessage { Role = "user", Content = content };
}
public static ChatMessage Assistant(string content)
{
return new ChatMessage { Role = "assistant", Content = content };
}
public static ChatMessage AssistantWithToolCalls(List<ToolCallRequest> toolCalls, string content = null)
{
return new ChatMessage { Role = "assistant", Content = content, ToolCalls = toolCalls };
}
public static ChatMessage ToolResult(string toolCallId, string content)
{
return new ChatMessage { Role = "tool", ToolCallId = toolCallId, Content = content };
}
}
public sealed class ChatCompletionResult
{
public string Content;
public List<ToolCallRequest> ToolCalls;
}
public class SimpleAIClient
{
private readonly string _apiKey;
@@ -125,6 +167,31 @@ namespace WulaFallenEmpire.EventSystem.AI
return response;
}
public async Task<ChatCompletionResult> GetChatCompletionWithToolsAsync(string instruction, List<ChatMessage> messages, List<Dictionary<string, object>> tools, int? maxTokens = null, float? temperature = null)
{
if (_useGemini)
{
WulaLog.Debug("[WulaAI] Native tool calling is not supported with Gemini protocol.");
return null;
}
if (string.IsNullOrEmpty(_baseUrl))
{
WulaLog.Debug("[WulaAI] Base URL is missing.");
return null;
}
string endpoint = $"{_baseUrl}/chat/completions";
if (_baseUrl.EndsWith("/chat/completions")) endpoint = _baseUrl;
else if (!_baseUrl.EndsWith("/v1")) endpoint = $"{_baseUrl}/v1/chat/completions";
string jsonBody = BuildChatRequestBody(instruction, messages, tools, maxTokens, temperature);
string response = await SendRequestRawAsync(endpoint, jsonBody, _apiKey);
if (response == null) return null;
return ExtractChatCompletionResult(response);
}
private async Task<string> GetGeminiCompletionAsync(string instruction, List<(string role, string message)> messages, int? maxTokens = null, float? temperature = null, string base64Image = null)
{
// Ensure messages is not empty to avoid Gemini 400 Error (Invalid Argument)
@@ -185,6 +252,13 @@ namespace WulaFallenEmpire.EventSystem.AI
}
private async Task<string> SendRequestAsync(string endpoint, string jsonBody, string apiKey)
{
string response = await SendRequestRawAsync(endpoint, jsonBody, apiKey);
if (response == null) return null;
return ExtractContent(response);
}
private async Task<string> SendRequestRawAsync(string endpoint, string jsonBody, string apiKey)
{
if (Prefs.DevMode)
{
@@ -228,10 +302,251 @@ namespace WulaFallenEmpire.EventSystem.AI
{
WulaLog.Debug($"[WulaAI] Response Body:\n{TruncateForLog(response)}");
}
return ExtractContent(response);
return response;
}
}
private string BuildChatRequestBody(string instruction, List<ChatMessage> messages, List<Dictionary<string, object>> tools, int? maxTokens, float? temperature)
{
var body = new Dictionary<string, object>
{
["model"] = _model,
["stream"] = false
};
if (maxTokens.HasValue) body["max_tokens"] = Math.Max(1, maxTokens.Value);
if (temperature.HasValue) body["temperature"] = temperature.Value;
var messageList = new List<object>();
if (!string.IsNullOrEmpty(instruction))
{
messageList.Add(new Dictionary<string, object>
{
["role"] = "system",
["content"] = instruction
});
}
if (messages != null)
{
foreach (var msg in messages)
{
if (msg == null) continue;
string role = string.IsNullOrWhiteSpace(msg.Role) ? "user" : msg.Role;
var entry = new Dictionary<string, object>
{
["role"] = role
};
if (string.Equals(role, "tool", StringComparison.OrdinalIgnoreCase))
{
entry["tool_call_id"] = msg.ToolCallId ?? "";
entry["content"] = msg.Content ?? "";
messageList.Add(entry);
continue;
}
if (string.Equals(role, "assistant", StringComparison.OrdinalIgnoreCase) && msg.ToolCalls != null && msg.ToolCalls.Count > 0)
{
var toolCalls = new List<object>();
foreach (var call in msg.ToolCalls)
{
if (call == null) continue;
var callEntry = new Dictionary<string, object>
{
["type"] = "function",
["function"] = new Dictionary<string, object>
{
["name"] = call.Name ?? "",
["arguments"] = call.ArgumentsJson ?? "{}"
}
};
if (!string.IsNullOrWhiteSpace(call.Id))
{
callEntry["id"] = call.Id;
}
toolCalls.Add(callEntry);
}
entry["content"] = string.IsNullOrWhiteSpace(msg.Content) ? null : msg.Content;
entry["tool_calls"] = toolCalls;
messageList.Add(entry);
continue;
}
entry["content"] = msg.Content ?? "";
messageList.Add(entry);
}
}
body["messages"] = messageList;
if (tools != null && tools.Count > 0)
{
var toolList = new List<object>();
foreach (var tool in tools)
{
if (tool == null) continue;
toolList.Add(tool);
}
if (toolList.Count > 0)
{
body["tools"] = toolList;
}
}
return JsonToolCallParser.SerializeToJson(body);
}
private ChatCompletionResult ExtractChatCompletionResult(string json)
{
if (string.IsNullOrWhiteSpace(json)) return null;
if (!JsonToolCallParser.TryParseObject(json, out var root))
{
return new ChatCompletionResult { Content = ExtractContent(json) };
}
if (!TryGetList(root, "choices", out var choices) || choices.Count == 0)
{
return new ChatCompletionResult { Content = ExtractContent(json) };
}
var firstChoice = choices[0] as Dictionary<string, object>;
if (firstChoice == null)
{
return new ChatCompletionResult { Content = ExtractContent(json) };
}
Dictionary<string, object> message = null;
if (TryGetObject(firstChoice, "message", out var msgObj))
{
message = msgObj;
}
else if (TryGetObject(firstChoice, "delta", out var deltaObj))
{
message = deltaObj;
}
if (message == null)
{
return new ChatCompletionResult { Content = ExtractContent(json) };
}
string content = TryGetString(message, "content");
var result = new ChatCompletionResult
{
Content = content,
ToolCalls = ParseToolCalls(message)
};
return result;
}
private static List<ToolCallRequest> ParseToolCalls(Dictionary<string, object> message)
{
if (!TryGetList(message, "tool_calls", out var calls) || calls.Count == 0)
{
return null;
}
var results = new List<ToolCallRequest>();
foreach (var callObj in calls)
{
if (callObj is not Dictionary<string, object> callDict) continue;
string id = TryGetString(callDict, "id");
string name = null;
object argsObj = null;
if (TryGetObject(callDict, "function", out var fnObj))
{
name = TryGetString(fnObj, "name");
TryGetValue(fnObj, "arguments", out argsObj);
}
else
{
name = TryGetString(callDict, "name");
TryGetValue(callDict, "arguments", out argsObj);
}
if (string.IsNullOrWhiteSpace(name)) continue;
string argsJson = "{}";
if (argsObj is string argsString)
{
argsJson = string.IsNullOrWhiteSpace(argsString) ? "{}" : argsString;
}
else if (argsObj is Dictionary<string, object> argsDict)
{
argsJson = JsonToolCallParser.SerializeToJson(argsDict);
}
else if (argsObj != null)
{
argsJson = JsonToolCallParser.SerializeToJson(argsObj);
}
results.Add(new ToolCallRequest
{
Id = id,
Name = name,
ArgumentsJson = argsJson
});
}
return results.Count > 0 ? results : null;
}
private static bool TryGetList(Dictionary<string, object> obj, string key, out List<object> list)
{
list = null;
if (!TryGetValue(obj, key, out object raw)) return false;
if (raw is List<object> rawList)
{
list = rawList;
return true;
}
return false;
}
private static bool TryGetObject(Dictionary<string, object> obj, string key, out Dictionary<string, object> value)
{
value = null;
if (!TryGetValue(obj, key, out object raw)) return false;
if (raw is Dictionary<string, object> dict)
{
value = dict;
return true;
}
return false;
}
private static string TryGetString(Dictionary<string, object> obj, string key)
{
if (TryGetValue(obj, key, out object value) && value != null)
{
return Convert.ToString(value);
}
return null;
}
private static bool TryGetValue(Dictionary<string, object> obj, string key, out object value)
{
if (obj == null)
{
value = null;
return false;
}
foreach (var kvp in obj)
{
if (string.Equals(kvp.Key, key, StringComparison.OrdinalIgnoreCase))
{
value = kvp.Value;
return true;
}
}
value = null;
return false;
}
private string ExtractContent(string json)
{
if (string.IsNullOrWhiteSpace(json)) return null;