Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: correcting some errors in 'rewriting logic'
Browse files Browse the repository at this point in the history
pepesi committed Jan 24, 2025
1 parent 21df731 commit 506dcaf
Showing 27 changed files with 110 additions and 70 deletions.
33 changes: 21 additions & 12 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf

rawPath := ctx.Path()
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
@@ -103,20 +103,24 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
if err != nil {
if providerConfig.PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
return types.ActionContinue
}
ctx.DontReadRequestBody()
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}

util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
hasRequestBody := wrapper.HasRequestBody()
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
}
ctx.DontReadRequestBody()
return types.ActionContinue
}

@@ -151,6 +155,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil {
return action
}
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
}
return types.ActionContinue
@@ -267,7 +275,8 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
}
}

func getOpenAiApiName(path string) provider.ApiName {
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
@@ -280,7 +289,7 @@ func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// rerank
// cohere style
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ func (m *ai360Provider) GetProviderType() string {

func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
@@ -59,7 +59,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam

func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
@@ -64,15 +64,15 @@ func (m *azureProvider) GetProviderType() string {

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
@@ -51,15 +51,15 @@ func (m *baichuanProvider) GetProviderType() string {

func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
@@ -52,15 +52,15 @@ func (g *baiduProvider) GetProviderType() string {

func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
8 changes: 6 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
@@ -112,7 +112,7 @@ func (c *claudeProvider) GetProviderType() string {

func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !c.config.isSupportedAPI(apiName) {
return c.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
@@ -133,7 +133,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, c.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
@@ -169,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// only process the response from chat completion, skip other responses
if name != ApiNameChatCompletion {
return chunk, nil
}

responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Original file line number Diff line number Diff line change
@@ -50,15 +50,15 @@ func (c *cloudflareProvider) GetProviderType() string {

func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !c.config.isSupportedAPI(apiName) {
return c.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return nil
}

func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, c.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
@@ -67,15 +67,15 @@ func (m *cohereProvider) GetProviderType() string {

func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
@@ -84,7 +84,7 @@ func (d *deeplProvider) GetProviderType() string {

func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !d.config.isSupportedAPI(apiName) {
return d.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return nil
@@ -97,7 +97,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName

func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !d.config.isSupportedAPI(apiName) {
return types.ActionContinue, d.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
@@ -119,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
}

func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
Original file line number Diff line number Diff line change
@@ -53,15 +53,15 @@ func (m *deepseekProvider) GetProviderType() string {

func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
10 changes: 8 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/dify.go
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ func (d *difyProvider) GetProviderType() string {

func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return d.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
return nil
@@ -78,7 +78,7 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName

func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, d.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
@@ -99,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
}

func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
difyResponse := &DifyChatResponse{}
if err := json.Unmarshal(body, difyResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal dify response: %v", err)
@@ -150,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}
// sample event response:
// data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"}

4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
Original file line number Diff line number Diff line change
@@ -51,15 +51,15 @@ func (m *doubaoProvider) GetProviderType() string {

func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
@@ -58,7 +58,7 @@ func (g *geminiProvider) GetProviderType() string {

func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
@@ -72,7 +72,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam

func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
@@ -115,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}
// sample end event response:
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
responseBuilder := &strings.Builder{}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/github.go
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ func (m *githubProvider) GetProviderType() string {

func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
@@ -62,7 +62,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa

func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/groq.go
Original file line number Diff line number Diff line change
@@ -50,15 +50,15 @@ func (g *groqProvider) GetProviderType() string {

func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !g.config.isSupportedAPI(apiName) {
return g.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
return nil
}

func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, g.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
9 changes: 6 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Original file line number Diff line number Diff line change
@@ -137,7 +137,7 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {

func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
@@ -161,7 +161,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
if m.useOpenAICompatibleAPI() {
return types.ActionContinue, nil
@@ -321,7 +321,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
}

func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion {
return chunk, nil
}

@@ -440,6 +440,9 @@ func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
return body, nil
}
if apiName != ApiNameChatCompletion {
return body, nil
}
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
10 changes: 8 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@ func (m *minimaxProvider) GetProviderType() string {

func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
@@ -90,7 +90,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa

func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
if minimaxApiTypePro == m.config.minimaxApiType {
// Use chat completion Pro API.
@@ -167,6 +167,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}
// Sample event response:
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}

@@ -200,6 +203,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name

// TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
minimaxResp := &minimaxChatCompletionProResp{}
if err := json.Unmarshal(body, minimaxResp); err != nil {
return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
Original file line number Diff line number Diff line change
@@ -49,15 +49,15 @@ func (m *mistralProvider) GetProviderType() string {

func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
7 changes: 5 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
Original file line number Diff line number Diff line change
@@ -65,7 +65,7 @@ func (m *moonshotProvider) GetProviderType() string {

func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
@@ -82,7 +82,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN
// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
// 非chat类型的请求,不做处理
if apiName != ApiNameChatCompletion {
@@ -165,6 +165,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
}

func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if name != ApiNameChatCompletion {
return chunk, nil
}
receivedBody := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
receivedBody = append(bufferedStreamingBody, chunk...)
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
Original file line number Diff line number Diff line change
@@ -55,15 +55,15 @@ func (m *ollamaProvider) GetProviderType() string {

func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
9 changes: 3 additions & 6 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
@@ -267,7 +267,7 @@ type ProviderConfig struct {
capabilities map[string]string
// @Title zh-CN 是否开启透传
// @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false
passthrsough bool
passthrough bool
}

func (c *ProviderConfig) GetId() string {
@@ -453,11 +453,8 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings)
}

func (c *ProviderConfig) handleUnsupportedAPI() error {
if c.passthrsough {
return nil
}
return errUnsupportedApiName
func (c *ProviderConfig) PassthroughUnsupportedAPI() bool {
return c.passthrough
}

func CreateProvider(pc ProviderConfig) (Provider, error) {
6 changes: 3 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ func (m *qwenProvider) GetProviderType() string {

func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}

m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -150,7 +150,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
}

if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
@@ -290,7 +290,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
if m.config.isSupportedAPI(apiName) {
return body, nil
}
return nil, m.config.handleUnsupportedAPI()
return nil, errUnsupportedApiName
}

func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
10 changes: 8 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/spark.go
Original file line number Diff line number Diff line change
@@ -75,20 +75,23 @@ func (p *sparkProvider) GetProviderType() string {

func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !p.config.isSupportedAPI(apiName) {
return p.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
p.config.handleRequestHeaders(p, ctx, apiName, log)
return nil
}

func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !p.config.isSupportedAPI(apiName) {
return types.ActionContinue, p.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
}

func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
sparkResponse := &sparkResponse{}
if err := json.Unmarshal(body, sparkResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
@@ -104,6 +107,9 @@ func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Ap
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
return chunk, nil
}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
Original file line number Diff line number Diff line change
@@ -50,15 +50,15 @@ func (m *stepfunProvider) GetProviderType() string {

func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go
Original file line number Diff line number Diff line change
@@ -49,15 +49,15 @@ func (m *togetherAIProvider) GetProviderType() string {

func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/yi.go
Original file line number Diff line number Diff line change
@@ -49,15 +49,15 @@ func (m *yiProvider) GetProviderType() string {

func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
Original file line number Diff line number Diff line change
@@ -51,15 +51,15 @@ func (m *zhipuAiProvider) GetProviderType() string {

func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if !m.config.isSupportedAPI(apiName) {
return m.config.handleUnsupportedAPI()
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}

func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, m.config.handleUnsupportedAPI()
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

0 comments on commit 506dcaf

Please sign in to comment.