197 lines
6.5 KiB
Go
197 lines
6.5 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
func (s *Service) preprocessRequestWithScripts(ctx context.Context, kind string, body map[string]any, candidate store.RuntimeModelCandidate) parameterPreprocessResult {
|
|
if platformConfigBool(candidate.PlatformConfig, "skipParamNormalization", "skip_param_normalization") {
|
|
modelType := strings.TrimSpace(candidate.ModelType)
|
|
if modelType == "" {
|
|
modelType = modelTypeFromKind(kind, body)
|
|
}
|
|
input := cloneMap(body)
|
|
return parameterPreprocessResult{
|
|
Body: cloneMap(body),
|
|
Log: parameterPreprocessingLog{
|
|
ModelType: modelType,
|
|
Input: input,
|
|
Output: cloneMap(body),
|
|
Changed: false,
|
|
Changes: []parameterPreprocessChange{},
|
|
Model: preprocessingModelSnapshot(candidate),
|
|
},
|
|
}
|
|
}
|
|
|
|
result := preprocessRequestWithLog(kind, body, candidate)
|
|
if result.Err != nil {
|
|
return result
|
|
}
|
|
scriptText := platformConfigString(candidate.PlatformConfig, "customPreprocessScript", "custom_preprocess_script")
|
|
if strings.TrimSpace(scriptText) == "" || s.scriptExecutor == nil {
|
|
return result
|
|
}
|
|
|
|
before := cloneMap(result.Body)
|
|
scriptContext := s.scriptContext(candidate, result.Log.ModelType, nil, map[string]any{
|
|
"modelCapability": effectiveModelCapability(candidate),
|
|
"platformModel": result.Log.Model,
|
|
"platform": candidate.PlatformConfig,
|
|
})
|
|
out, err := s.scriptExecutor.Execute(ctx, scriptengine.Options{
|
|
Script: scriptText,
|
|
Args: []any{cloneMap(result.Body), result.Log.ModelType, scriptContext},
|
|
ContextData: scriptContext,
|
|
ScriptName: "custom_preprocess_script:" + result.Log.ModelType,
|
|
PreferredEntryNames: []string{"preprocessParams", "preprocess", "main", "handler"},
|
|
Timeout: scriptengine.PreprocessTimeout,
|
|
})
|
|
if err != nil {
|
|
result.Log.recordScriptChange("CustomPreprocessScript", "error", "$", before, result.Body, err.Error())
|
|
result.Log.Output = cloneMap(result.Body)
|
|
result.Log.Changed = len(result.Log.Changes) > 0
|
|
result.Err = err
|
|
return result
|
|
}
|
|
rewritten, ok := out.(map[string]any)
|
|
if !ok || rewritten == nil {
|
|
result.Log.Output = cloneMap(result.Body)
|
|
result.Log.Changed = len(result.Log.Changes) > 0
|
|
return result
|
|
}
|
|
merged := cloneMap(result.Body)
|
|
for key, value := range rewritten {
|
|
merged[key] = value
|
|
}
|
|
if !mapsEqual(before, merged) {
|
|
result.Log.recordScriptChange("CustomPreprocessScript", "rewrite", "$", before, merged, "platform custom preprocess script returned parameter updates")
|
|
}
|
|
result.Body = merged
|
|
result.Log.Output = cloneMap(merged)
|
|
result.Log.Changed = len(result.Log.Changes) > 0
|
|
return result
|
|
}
|
|
|
|
func (s *Service) scriptContext(candidate store.RuntimeModelCandidate, modelType string, payload map[string]any, extra map[string]any) map[string]any {
|
|
getTaskURL := platformConfigString(candidate.PlatformConfig, "getTaskURL", "get_task_url")
|
|
baseURL := strings.TrimRight(strings.TrimSpace(candidate.BaseURL), "/")
|
|
env := cloneMap(candidate.PlatformConfig)
|
|
context := map[string]any{
|
|
"__easyaiScriptContext": true,
|
|
"baseURL": baseURL,
|
|
"getTaskURL": getTaskURL,
|
|
"authValues": cloneMap(candidate.Credentials),
|
|
"headers": map[string]any{},
|
|
"payload": cloneMap(payload),
|
|
"type": modelType,
|
|
"options": map[string]any{
|
|
"model": candidate.ModelName,
|
|
"providerModelName": candidate.ProviderModelName,
|
|
"platformId": candidate.PlatformID,
|
|
"platformModelId": candidate.PlatformModelID,
|
|
"canonicalModelKey": candidate.CanonicalModelKey,
|
|
"sourceProviderCode": candidate.Provider,
|
|
},
|
|
"env": env,
|
|
"candidate": preprocessingModelSnapshot(candidate),
|
|
}
|
|
context["createRequestURL"] = func(path string, base ...string) string {
|
|
selectedBase := baseURL
|
|
if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
|
|
selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
|
|
}
|
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
|
return path
|
|
}
|
|
return selectedBase + "/" + strings.TrimLeft(path, "/")
|
|
}
|
|
context["creatRequestURL"] = context["createRequestURL"]
|
|
context["resolveGetTaskURL"] = func(taskID string) string {
|
|
return resolveTaskURLTemplate(getTaskURL, taskID, "")
|
|
}
|
|
for key, value := range extra {
|
|
context[key] = value
|
|
}
|
|
return context
|
|
}
|
|
|
|
func preprocessingModelSnapshot(candidate store.RuntimeModelCandidate) map[string]any {
|
|
return map[string]any{
|
|
"modelName": candidate.ModelName,
|
|
"modelAlias": candidate.ModelAlias,
|
|
"providerModelName": candidate.ProviderModelName,
|
|
"provider": candidate.Provider,
|
|
"platformId": candidate.PlatformID,
|
|
"platformModelId": candidate.PlatformModelID,
|
|
"capabilities": cloneMap(candidate.Capabilities),
|
|
}
|
|
}
|
|
|
|
func (log *parameterPreprocessingLog) recordScriptChange(processor string, action string, path string, before any, after any, reason string) {
|
|
if log == nil {
|
|
return
|
|
}
|
|
log.Changes = append(log.Changes, parameterPreprocessChange{
|
|
Processor: processor,
|
|
Action: action,
|
|
Path: path,
|
|
Before: cloneAny(before),
|
|
After: cloneAny(after),
|
|
Reason: reason,
|
|
})
|
|
}
|
|
|
|
// platformConfigString returns the value of the first key in keys that is
// set in config.
//
// Each value is rendered with fmt.Sprint and trimmed of surrounding
// whitespace. A key is skipped when its rendering is empty or "<nil>", which
// covers missing keys, nil values and a nil config. If every key is skipped,
// the result is "".
func platformConfigString(config map[string]any, keys ...string) string {
	for _, key := range keys {
		text := strings.TrimSpace(fmt.Sprint(config[key]))
		if text == "" || text == "<nil>" {
			continue
		}
		return text
	}
	return ""
}
|
|
|
|
// platformConfigBool reports the boolean setting stored under the first
// applicable key in keys.
//
// A bool value is returned as-is. A non-blank string is true only when it
// equals "true" case-insensitively, ignoring surrounding whitespace. A key is
// skipped, and the next alias consulted, when it is missing, holds a value of
// any other type, or holds a blank string. platformConfigString uses the same
// "blank means unset" rule, so a camelCase key left empty no longer hides its
// snake_case alias. The result is false when no key applies.
func platformConfigBool(config map[string]any, keys ...string) bool {
	for _, key := range keys {
		switch value := config[key].(type) {
		case bool:
			return value
		case string:
			trimmed := strings.TrimSpace(value)
			if trimmed == "" {
				// Blank means "not set": try the next alias rather than
				// treating it as an explicit false.
				continue
			}
			return strings.EqualFold(trimmed, "true")
		}
	}
	return false
}
|
|
|
|
// resolveTaskURLTemplate expands task-ID placeholders in a task-status URL
// template after trimming surrounding whitespace.
//
// These placeholders expand to upstreamTaskID (the provider's task ID):
//
//	${upstream_task_id} {{upstream_task_id}} {upstream_task_id}
//	${taskId} {{taskId}} {taskId} ${taskID} {{taskID}} {taskID}
//
// These placeholders expand to taskID (the gateway's task ID):
//
//	${task_id} {{task_id}} {task_id}
//
// Expansion is a single left-to-right pass, so text inside a substituted ID
// is never treated as a placeholder. For example, an upstream ID containing
// "{task_id}" is inserted verbatim. Repeated ReplaceAll passes would rewrite
// it.
func resolveTaskURLTemplate(template string, upstreamTaskID string, taskID string) string {
	// No placeholder is a prefix of another, so at most one can match at any
	// position. strings.Replacer consumes each match and does not rescan its
	// output.
	replacer := strings.NewReplacer(
		"${upstream_task_id}", upstreamTaskID,
		"{{upstream_task_id}}", upstreamTaskID,
		"{upstream_task_id}", upstreamTaskID,
		"${task_id}", taskID,
		"{{task_id}}", taskID,
		"{task_id}", taskID,
		"${taskId}", upstreamTaskID,
		"${taskID}", upstreamTaskID,
		"{{taskId}}", upstreamTaskID,
		"{{taskID}}", upstreamTaskID,
		"{taskId}", upstreamTaskID,
		"{taskID}", upstreamTaskID,
	)
	return replacer.Replace(strings.TrimSpace(template))
}
|
|
|
|
// mapsEqual reports whether two maps are deeply equal according to
// reflect.DeepEqual. Values are compared recursively together with their
// dynamic types, so 1 and 1.0 differ, and a nil map never equals a non-nil
// empty map.
//
// maps.Equal is not a substitute: it compares values with ==, which panics
// when an any value holds a map or slice.
func mapsEqual(left, right map[string]any) bool {
	return reflect.DeepEqual(left, right)
}
|