perf: optimize request metadata extraction and disabled field filtering (#5009)

* perf: optimize request metadata extraction and disabled field filtering

* perf: optimize stream usage estimation path
This commit is contained in:
Seefs
2026-05-22 10:32:11 +08:00
committed by GitHub
parent 006e801652
commit ae6a03364d
5 changed files with 106 additions and 72 deletions
+14 -65
View File
@@ -1,7 +1,6 @@
package openai
import (
"encoding/json"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -92,78 +91,28 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
return nil
}
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
return err
}
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
var streamResponse dto.CompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
return err
}
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
}
return nil
}
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysLog("error processing stream response: " + err.Error())
}
}
return nil
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
+4 -7
View File
@@ -119,7 +119,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var lastStreamData string
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
@@ -140,7 +139,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
lastStreamData = data
streamItems = append(streamItems, data)
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
logger.LogError(c, "error processing stream token data: "+err.Error())
sr.Error(err)
}
}
})
@@ -175,11 +177,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
usage.CompletionTokens += toolCount * 7
+11
View File
@@ -2054,6 +2054,17 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
}
func TestRemoveDisabledFieldsNoControlledFieldsKeepsBody(t *testing.T) {
input := `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
settings := dto.ChannelOtherSettings{}
out, err := RemoveDisabledFields([]byte(input), settings, false)
if err != nil {
t.Fatalf("RemoveDisabledFields returned error: %v", err)
}
require.Equal(t, input, string(out))
}
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
input := `{
"inference_geo":"eu",
+23
View File
@@ -18,6 +18,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
)
type ThinkingContentInfo struct {
@@ -785,6 +786,9 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
return jsonData, nil
}
if !hasRemovableDisabledField(jsonData, channelOtherSettings) {
return jsonData, nil
}
var data map[string]interface{}
if err := common.Unmarshal(jsonData, &data); err != nil {
@@ -851,6 +855,25 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
return jsonDataAfter, nil
}
func hasRemovableDisabledField(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) bool {
values := gjson.GetManyBytes(
jsonData,
"service_tier",
"inference_geo",
"speed",
"store",
"safety_identifier",
"stream_options.include_obfuscation",
)
return (!channelOtherSettings.AllowServiceTier && values[0].Exists()) ||
(!channelOtherSettings.AllowInferenceGeo && values[1].Exists()) ||
(!channelOtherSettings.AllowSpeed && values[2].Exists()) ||
(channelOtherSettings.DisableStore && values[3].Exists()) ||
(!channelOtherSettings.AllowSafetyIdentifier && values[4].Exists()) ||
(!channelOtherSettings.AllowIncludeObfuscation && values[5].Exists())
}
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
// Currently supports removing functionResponse.id field which Vertex AI does not support
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {