perf: optimize request metadata extraction and disabled field filtering (#5009)
* perf: optimize request metadata extraction and disabled field filtering
* perf: optimize stream usage estimation path
This commit is contained in:
@@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelRequest struct {
|
type ModelRequest struct {
|
||||||
@@ -170,6 +172,14 @@ func Distribute() func(c *gin.Context) {
|
|||||||
// - application/x-www-form-urlencoded
|
// - application/x-www-form-urlencoded
|
||||||
// - multipart/form-data
|
// - multipart/form-data
|
||||||
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||||
|
if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") {
|
||||||
|
modelRequest, err := getModelFromJSONBody(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
|
||||||
|
}
|
||||||
|
return modelRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -178,6 +188,50 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
|||||||
return &modelRequest, nil
|
return &modelRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) {
|
||||||
|
storage, err := common.GetBodyStorage(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
requestBody, err := storage.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(requestBody) {
|
||||||
|
return nil, errors.New("invalid JSON request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
values := gjson.GetManyBytes(requestBody, "model", "group")
|
||||||
|
model, err := getJSONStringValue(values[0], "model")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
group, err := getJSONStringValue(values[1], "group")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||||
|
return nil, seekErr
|
||||||
|
}
|
||||||
|
c.Request.Body = io.NopCloser(storage)
|
||||||
|
|
||||||
|
return &ModelRequest{
|
||||||
|
Model: model,
|
||||||
|
Group: group,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getJSONStringValue(result gjson.Result, field string) (string, error) {
|
||||||
|
if !result.Exists() || result.Type == gjson.Null {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if result.Type != gjson.String {
|
||||||
|
return "", fmt.Errorf("field %s must be a string", field)
|
||||||
|
}
|
||||||
|
return result.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
shouldSelectChannel := true
|
shouldSelectChannel := true
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
@@ -92,78 +91,28 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
|
||||||
|
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
|
||||||
case relayconstant.RelayModeCompletions:
|
case relayconstant.RelayModeCompletions:
|
||||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
var streamResponse dto.CompletionsStreamResponse
|
||||||
|
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
|
||||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
for _, choice := range streamResponse.Choices {
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
responseTextBuilder.WriteString(choice.Text)
|
||||||
// 一次性解析失败,逐个解析
|
|
||||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
|
||||||
for _, item := range streamItems {
|
|
||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
|
||||||
common.SysLog("error processing stream response: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 批量处理所有响应
|
|
||||||
for _, streamResponse := range streamResponses {
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
|
||||||
if choice.Delta.ToolCalls != nil {
|
|
||||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
|
||||||
*toolCount = len(choice.Delta.ToolCalls)
|
|
||||||
}
|
|
||||||
for _, tool := range choice.Delta.ToolCalls {
|
|
||||||
responseTextBuilder.WriteString(tool.Function.Name)
|
|
||||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
|
||||||
var streamResponses []dto.CompletionsStreamResponse
|
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
|
||||||
// 一次性解析失败,逐个解析
|
|
||||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
|
||||||
for _, item := range streamItems {
|
|
||||||
var streamResponse dto.CompletionsStreamResponse
|
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批量处理所有响应
|
|
||||||
for _, streamResponse := range streamResponses {
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||||
|
|||||||
@@ -119,7 +119,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
var responseTextBuilder strings.Builder
|
var responseTextBuilder strings.Builder
|
||||||
var toolCount int
|
var toolCount int
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
var streamItems []string // store stream items
|
|
||||||
var lastStreamData string
|
var lastStreamData string
|
||||||
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
|
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
|
||||||
|
|
||||||
@@ -140,7 +139,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
lastStreamData = data
|
lastStreamData = data
|
||||||
streamItems = append(streamItems, data)
|
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
|
||||||
|
logger.LogError(c, "error processing stream token data: "+err.Error())
|
||||||
|
sr.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -175,11 +177,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理token计算
|
|
||||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
|
||||||
logger.LogError(c, "error processing tokens: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
|
|||||||
@@ -2054,6 +2054,17 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
|
|||||||
assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
|
assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRemoveDisabledFieldsNoControlledFieldsKeepsBody(t *testing.T) {
|
||||||
|
input := `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
|
||||||
|
settings := dto.ChannelOtherSettings{}
|
||||||
|
|
||||||
|
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||||
|
}
|
||||||
|
require.Equal(t, input, string(out))
|
||||||
|
}
|
||||||
|
|
||||||
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||||
input := `{
|
input := `{
|
||||||
"inference_geo":"eu",
|
"inference_geo":"eu",
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ThinkingContentInfo struct {
|
type ThinkingContentInfo struct {
|
||||||
@@ -785,6 +786,9 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
|||||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
||||||
return jsonData, nil
|
return jsonData, nil
|
||||||
}
|
}
|
||||||
|
if !hasRemovableDisabledField(jsonData, channelOtherSettings) {
|
||||||
|
return jsonData, nil
|
||||||
|
}
|
||||||
|
|
||||||
var data map[string]interface{}
|
var data map[string]interface{}
|
||||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||||
@@ -851,6 +855,25 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
|||||||
return jsonDataAfter, nil
|
return jsonDataAfter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasRemovableDisabledField(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) bool {
|
||||||
|
values := gjson.GetManyBytes(
|
||||||
|
jsonData,
|
||||||
|
"service_tier",
|
||||||
|
"inference_geo",
|
||||||
|
"speed",
|
||||||
|
"store",
|
||||||
|
"safety_identifier",
|
||||||
|
"stream_options.include_obfuscation",
|
||||||
|
)
|
||||||
|
|
||||||
|
return (!channelOtherSettings.AllowServiceTier && values[0].Exists()) ||
|
||||||
|
(!channelOtherSettings.AllowInferenceGeo && values[1].Exists()) ||
|
||||||
|
(!channelOtherSettings.AllowSpeed && values[2].Exists()) ||
|
||||||
|
(channelOtherSettings.DisableStore && values[3].Exists()) ||
|
||||||
|
(!channelOtherSettings.AllowSafetyIdentifier && values[4].Exists()) ||
|
||||||
|
(!channelOtherSettings.AllowIncludeObfuscation && values[5].Exists())
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
|
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
|
||||||
// Currently supports removing functionResponse.id field which Vertex AI does not support
|
// Currently supports removing functionResponse.id field which Vertex AI does not support
|
||||||
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
|
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user