// Source: easyai-ai-gateway/apps/api/internal/runner/param_processor_script.go
// (197 lines, 6.5 KiB, Go)

package runner
import (
"context"
"fmt"
"reflect"
"strings"
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// preprocessRequestWithScripts runs the built-in parameter preprocessing for a
// request and then, when the platform configures one, its custom preprocess
// script.
//
// If the platform config sets skipParamNormalization (or
// skip_param_normalization), both stages are bypassed: a copy of body is
// returned untouched together with a log that records no changes.
//
// Otherwise the result of preprocessRequestWithLog is the baseline. It is
// returned as-is when it carries an error, when no customPreprocessScript /
// custom_preprocess_script is configured, or when the service has no script
// executor. When the script does run:
//   - an execution error is recorded as an "error" change and surfaced through
//     result.Err, with the body left as the built-in stage produced it;
//   - a return value that is not a non-nil map[string]any is ignored;
//   - a returned map is shallow-merged over the body key by key, so the script
//     can add or overwrite top-level keys but cannot remove them.
func (s *Service) preprocessRequestWithScripts(ctx context.Context, kind string, body map[string]any, candidate store.RuntimeModelCandidate) parameterPreprocessResult {
	if platformConfigBool(candidate.PlatformConfig, "skipParamNormalization", "skip_param_normalization") {
		// Prefer the candidate's declared model type; derive it from the
		// request only when the candidate leaves it blank.
		modelType := strings.TrimSpace(candidate.ModelType)
		if modelType == "" {
			modelType = modelTypeFromKind(kind, body)
		}
		passthroughLog := parameterPreprocessingLog{
			ModelType: modelType,
			Input:     cloneMap(body),
			Output:    cloneMap(body),
			Changed:   false,
			Changes:   []parameterPreprocessChange{},
			Model:     preprocessingModelSnapshot(candidate),
		}
		return parameterPreprocessResult{
			Body: cloneMap(body),
			Log:  passthroughLog,
		}
	}

	result := preprocessRequestWithLog(kind, body, candidate)
	if result.Err != nil {
		return result
	}

	scriptSource := platformConfigString(candidate.PlatformConfig, "customPreprocessScript", "custom_preprocess_script")
	if strings.TrimSpace(scriptSource) == "" || s.scriptExecutor == nil {
		return result
	}

	// finalize brings the log's output snapshot and changed flag in line with
	// the current state of result before it is handed back.
	finalize := func() parameterPreprocessResult {
		result.Log.Output = cloneMap(result.Body)
		result.Log.Changed = len(result.Log.Changes) > 0
		return result
	}

	modelType := result.Log.ModelType
	// Snapshot of the body before the script runs, used for change detection
	// and for the change log entry.
	preScript := cloneMap(result.Body)
	scriptCtx := s.scriptContext(candidate, modelType, nil, map[string]any{
		"modelCapability": effectiveModelCapability(candidate),
		"platformModel":   result.Log.Model,
		"platform":        candidate.PlatformConfig,
	})

	// The script receives its own copy of the body so it cannot mutate
	// result.Body directly.
	returned, execErr := s.scriptExecutor.Execute(ctx, scriptengine.Options{
		Script:              scriptSource,
		Args:                []any{cloneMap(result.Body), modelType, scriptCtx},
		ContextData:         scriptCtx,
		ScriptName:          "custom_preprocess_script:" + modelType,
		PreferredEntryNames: []string{"preprocessParams", "preprocess", "main", "handler"},
		Timeout:             scriptengine.PreprocessTimeout,
	})
	if execErr != nil {
		result.Log.recordScriptChange("CustomPreprocessScript", "error", "$", preScript, result.Body, execErr.Error())
		result.Err = execErr
		return finalize()
	}

	updates, isMap := returned.(map[string]any)
	if !isMap || updates == nil {
		// Anything other than a non-nil map is treated as "no updates".
		return finalize()
	}

	merged := cloneMap(result.Body)
	for name, value := range updates {
		merged[name] = value
	}
	if !mapsEqual(preScript, merged) {
		result.Log.recordScriptChange("CustomPreprocessScript", "rewrite", "$", preScript, merged, "platform custom preprocess script returned parameter updates")
	}
	result.Body = merged
	return finalize()
}
// scriptContext assembles the context map handed to platform scripts for the
// given candidate.
//
// The map carries the candidate's trimmed base URL (trailing slashes removed),
// the configured getTaskURL / get_task_url template, copies of the candidate's
// credentials ("authValues"), platform config ("env") and payload, the model
// type, an "options" map of model identifiers, a model snapshot, and helper
// functions:
//   - createRequestURL(path, base...) returns path unchanged when it starts
//     with http:// or https://; otherwise it joins path onto the first
//     non-blank base argument, falling back to the candidate's base URL.
//   - creatRequestURL is the same function under a misspelled key (presumably
//     kept so existing scripts keep working — confirm before removing).
//   - resolveGetTaskURL(taskID) expands the task URL template with taskID as
//     the upstream task ID and an empty local task ID.
//
// Entries in extra are applied last and override any key of the same name.
func (s *Service) scriptContext(candidate store.RuntimeModelCandidate, modelType string, payload map[string]any, extra map[string]any) map[string]any {
	taskURLTemplate := platformConfigString(candidate.PlatformConfig, "getTaskURL", "get_task_url")
	defaultBase := strings.TrimRight(strings.TrimSpace(candidate.BaseURL), "/")

	buildRequestURL := func(path string, base ...string) string {
		// Absolute URLs pass through untouched.
		if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
			return path
		}
		root := defaultBase
		if len(base) > 0 {
			if override := strings.TrimSpace(base[0]); override != "" {
				root = strings.TrimRight(override, "/")
			}
		}
		return root + "/" + strings.TrimLeft(path, "/")
	}

	scriptCtx := map[string]any{
		"__easyaiScriptContext": true,
		"baseURL":               defaultBase,
		"getTaskURL":            taskURLTemplate,
		"authValues":            cloneMap(candidate.Credentials),
		"headers":               map[string]any{},
		"payload":               cloneMap(payload),
		"type":                  modelType,
		"options": map[string]any{
			"model":              candidate.ModelName,
			"providerModelName":  candidate.ProviderModelName,
			"platformId":         candidate.PlatformID,
			"platformModelId":    candidate.PlatformModelID,
			"canonicalModelKey":  candidate.CanonicalModelKey,
			"sourceProviderCode": candidate.Provider,
		},
		"env":              cloneMap(candidate.PlatformConfig),
		"candidate":        preprocessingModelSnapshot(candidate),
		"createRequestURL": buildRequestURL,
		"creatRequestURL":  buildRequestURL,
		"resolveGetTaskURL": func(taskID string) string {
			return resolveTaskURLTemplate(taskURLTemplate, taskID, "")
		},
	}

	// Caller-supplied extras win over the defaults above.
	for name, value := range extra {
		scriptCtx[name] = value
	}
	return scriptCtx
}
// preprocessingModelSnapshot returns a map describing the candidate's model
// identity (name, alias, provider model name, provider, platform and platform
// model IDs) plus a copy of its capabilities, for use in preprocessing logs and
// script contexts.
func preprocessingModelSnapshot(candidate store.RuntimeModelCandidate) map[string]any {
	snapshot := make(map[string]any, 7)
	snapshot["modelName"] = candidate.ModelName
	snapshot["modelAlias"] = candidate.ModelAlias
	snapshot["providerModelName"] = candidate.ProviderModelName
	snapshot["provider"] = candidate.Provider
	snapshot["platformId"] = candidate.PlatformID
	snapshot["platformModelId"] = candidate.PlatformModelID
	// Copied so later edits to the snapshot do not reach the candidate.
	snapshot["capabilities"] = cloneMap(candidate.Capabilities)
	return snapshot
}
// recordScriptChange appends one change entry to the log, storing copies of
// the before and after values made with cloneAny. Calling it on a nil log is a
// no-op.
func (log *parameterPreprocessingLog) recordScriptChange(processor string, action string, path string, before any, after any, reason string) {
	if log != nil {
		entry := parameterPreprocessChange{
			Processor: processor,
			Action:    action,
			Path:      path,
			Before:    cloneAny(before),
			After:     cloneAny(after),
			Reason:    reason,
		}
		log.Changes = append(log.Changes, entry)
	}
}
// platformConfigString returns the first usable value found under keys, tried
// in order. Each value is rendered with fmt.Sprint and trimmed of surrounding
// whitespace; a key is skipped when the rendered text is empty or "<nil>"
// (which is what a missing or nil entry renders as). It returns "" when no key
// yields a value.
func platformConfigString(config map[string]any, keys ...string) string {
	for _, key := range keys {
		text := strings.TrimSpace(fmt.Sprint(config[key]))
		if text == "" || text == "<nil>" {
			continue
		}
		return text
	}
	return ""
}
// platformConfigBool reads a boolean flag from config, trying keys in order.
// The first key whose value is a bool or a string decides the result: a bool
// is returned as-is, and a string is true only when it equals "true" after
// trimming whitespace and ignoring case. Missing keys and values of any other
// type are skipped. It returns false when no key decides.
func platformConfigBool(config map[string]any, keys ...string) bool {
	for _, key := range keys {
		raw := config[key]
		if flag, isBool := raw.(bool); isBool {
			return flag
		}
		if text, isString := raw.(string); isString {
			return strings.EqualFold(strings.TrimSpace(text), "true")
		}
	}
	return false
}
// resolveTaskURLTemplate expands task ID placeholders in a URL template and
// returns the result with surrounding whitespace trimmed from the template.
//
// Each placeholder name is accepted in three syntaxes: ${name}, {{name}} and
// {name}.
//   - upstream_task_id expands to upstreamTaskID.
//   - task_id expands to taskID.
//   - taskId and taskID expand to upstreamTaskID (not taskID).
//
// Unrecognized placeholders are left in place.
//
// Substitution is done in a single pass over the template, so text inserted
// for one placeholder is never re-scanned: an ID that happens to contain
// something like "{task_id}" is emitted verbatim instead of being expanded
// again.
func resolveTaskURLTemplate(template string, upstreamTaskID string, taskID string) string {
	// Argument order is the match priority. No two patterns can match at the
	// same position (they differ in their first or second byte), so the
	// longer ${name} and {{name}} forms are always consumed whole before the
	// bare {name} form could match inside them.
	placeholders := strings.NewReplacer(
		"${upstream_task_id}", upstreamTaskID,
		"{{upstream_task_id}}", upstreamTaskID,
		"{upstream_task_id}", upstreamTaskID,
		"${task_id}", taskID,
		"{{task_id}}", taskID,
		"{task_id}", taskID,
		"${taskId}", upstreamTaskID,
		"${taskID}", upstreamTaskID,
		"{{taskId}}", upstreamTaskID,
		"{{taskID}}", upstreamTaskID,
		"{taskId}", upstreamTaskID,
		"{taskID}", upstreamTaskID,
	)
	return placeholders.Replace(strings.TrimSpace(template))
}
// mapsEqual reports whether two maps are deeply equal according to
// reflect.DeepEqual. Note that under those rules a nil map and an empty
// non-nil map are not equal.
func mapsEqual(left map[string]any, right map[string]any) bool {
	same := reflect.DeepEqual(left, right)
	return same
}