From 59a93cf5c7bb4f7e428e36bfaa2458b474c281a3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 10 Jun 2026 17:47:15 +0800 Subject: [PATCH] fix(openai): align image streaming relay governance Route OpenAI image streaming through shared stream handling, split image/realtime/usage helpers for maintainability, and include the related image request and rate limit updates. --- common/init.go | 4 +- dto/openai_image.go | 4 +- dto/openai_image_test.go | 16 - relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/image_edit_test.go | 163 +++--- relay/channel/openai/image_stream_test.go | 212 +++---- relay/channel/openai/relay-openai.go | 675 ---------------------- relay/channel/openai/relay_image.go | 287 +++++++++ relay/channel/openai/relay_realtime.go | 242 ++++++++ relay/channel/openai/usage.go | 133 +++++ relay/channel/xai/adaptor.go | 2 +- relay/helper/openai_image_request_test.go | 94 ++- relay/helper/stream_scanner.go | 4 +- relay/helper/valid_request.go | 2 +- 14 files changed, 853 insertions(+), 987 deletions(-) delete mode 100644 dto/openai_image_test.go create mode 100644 relay/channel/openai/relay_image.go create mode 100644 relay/channel/openai/relay_realtime.go create mode 100644 relay/channel/openai/usage.go diff --git a/common/init.go b/common/init.go index 6b9fca83..e8724d91 100644 --- a/common/init.go +++ b/common/init.go @@ -112,11 +112,11 @@ func InitEnv() { // Initialize rate limit variables GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true) - GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 360) GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180)) GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true) - GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 120) 
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180)) CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true) diff --git a/dto/openai_image.go b/dto/openai_image.go index 416697e3..547b0d18 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -26,7 +26,7 @@ type ImageRequest struct { OutputFormat json.RawMessage `json:"output_format,omitempty"` OutputCompression json.RawMessage `json:"output_compression,omitempty"` PartialImages json.RawMessage `json:"partial_images,omitempty"` - Stream bool `json:"stream,omitempty"` + Stream *bool `json:"stream,omitempty"` Images json.RawMessage `json:"images,omitempty"` Mask json.RawMessage `json:"mask,omitempty"` InputFidelity json.RawMessage `json:"input_fidelity,omitempty"` @@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { } func (i *ImageRequest) IsStream(c *gin.Context) bool { - return i.Stream + return i.Stream != nil && *i.Stream } func (i *ImageRequest) SetModelName(modelName string) { diff --git a/dto/openai_image_test.go b/dto/openai_image_test.go deleted file mode 100644 index 27e13637..00000000 --- a/dto/openai_image_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package dto - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -// TestImageRequestStreamJSON verifies that image requests preserve stream=true. 
-func TestImageRequestStreamJSON(t *testing.T) { - var req ImageRequest - require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`))) - - require.True(t, req.Stream) - require.True(t, req.IsStream(nil)) -} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index fae2e174..2c230107 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -632,7 +632,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { usage, err = OpenaiImageStreamHandler(c, info, resp) } else { - usage, err = OpenaiHandlerWithUsage(c, info, resp) + usage, err = OpenaiImageHandler(c, info, resp) } case relayconstant.RelayModeRerank: usage, err = common_handler.RerankHandler(c, info, resp) diff --git a/relay/channel/openai/image_edit_test.go b/relay/channel/openai/image_edit_test.go index b37551b0..857ab243 100644 --- a/relay/channel/openai/image_edit_test.go +++ b/relay/channel/openai/image_edit_test.go @@ -16,106 +16,83 @@ import ( "github.com/stretchr/testify/require" ) -// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay. -func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) { +// TestConvertImageEditRequestMultipart verifies that ConvertImageRequest +// re-serializes multipart image edit requests with all fields (including +// stream) and the file intact, both when the form was already parsed and when +// it must be re-parsed from the reusable body. 
+func TestConvertImageEditRequestMultipart(t *testing.T) { gin.SetMode(gin.TestMode) - var body bytes.Buffer - writer := multipart.NewWriter(&body) - require.NoError(t, writer.WriteField("model", "gpt-image-1")) - require.NoError(t, writer.WriteField("prompt", "edit this image")) - require.NoError(t, writer.WriteField("stream", "true")) - require.NoError(t, writer.WriteField("partial_images", "3")) - part, err := writer.CreateFormFile("image", "input.png") - require.NoError(t, err) - _, err = part.Write([]byte("fake image")) - require.NoError(t, err) - require.NoError(t, writer.Close()) + newMultipartContext := func(t *testing.T, prompt string) *gin.Context { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", prompt)) + require.NoError(t, writer.WriteField("stream", "true")) + require.NoError(t, writer.WriteField("partial_images", "3")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) - c.Request.Header.Set("Content-Type", writer.FormDataContentType()) - require.NoError(t, c.Request.ParseMultipartForm(32<<20)) - - info := &relaycommon.RelayInfo{ - RelayMode: relayconstant.RelayModeImagesEdits, - } - request := dto.ImageRequest{ - Model: "gpt-image-1", - Prompt: "edit this image", - Stream: true, + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return c } - converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) - require.NoError(t, err) + convertAndReplay := func(t *testing.T, c 
*gin.Context, prompt string) { + info := &relaycommon.RelayInfo{ + RelayMode: relayconstant.RelayModeImagesEdits, + } + request := dto.ImageRequest{ + Model: "gpt-image-1", + Prompt: prompt, + Stream: common.GetPointer(true), + } - convertedBody, ok := converted.(*bytes.Buffer) - require.True(t, ok) + converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) + require.NoError(t, err) + convertedBody, ok := converted.(*bytes.Buffer) + require.True(t, ok) - contentType := c.Request.Header.Get("Content-Type") - replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) - replayedRequest.Header.Set("Content-Type", contentType) - require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) + replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) + replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) - require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model")) - require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt")) - require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) - require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images")) - require.Len(t, replayedRequest.MultipartForm.File["image"], 1) + require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model")) + require.Equal(t, prompt, replayedRequest.PostForm.Get("prompt")) + require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) + require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images")) + require.Len(t, replayedRequest.MultipartForm.File["image"], 1) - file, err := replayedRequest.MultipartForm.File["image"][0].Open() - require.NoError(t, err) - defer file.Close() - fileBytes, err := io.ReadAll(file) - require.NoError(t, err) - require.Equal(t, []byte("fake image"), fileBytes) -} - -// 
TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing verifies fallback parsing. -func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *testing.T) { - gin.SetMode(gin.TestMode) - - var body bytes.Buffer - writer := multipart.NewWriter(&body) - require.NoError(t, writer.WriteField("model", "gpt-image-1")) - require.NoError(t, writer.WriteField("prompt", "edit without pre-parsed form")) - require.NoError(t, writer.WriteField("stream", "true")) - part, err := writer.CreateFormFile("image", "input.png") - require.NoError(t, err) - _, err = part.Write([]byte("fake image")) - require.NoError(t, err) - require.NoError(t, writer.Close()) - - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) - c.Request.Header.Set("Content-Type", writer.FormDataContentType()) - - storage, err := common.GetBodyStorage(c) - require.NoError(t, err) - c.Request.Body = io.NopCloser(storage) - c.Request.MultipartForm = nil - c.Request.PostForm = nil - - info := &relaycommon.RelayInfo{ - RelayMode: relayconstant.RelayModeImagesEdits, - } - request := dto.ImageRequest{ - Model: "gpt-image-1", - Prompt: "edit without pre-parsed form", - Stream: true, - } - - converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) - require.NoError(t, err) - - convertedBody, ok := converted.(*bytes.Buffer) - require.True(t, ok) - replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) - replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) - require.Equal(t, "edit without pre-parsed form", replayedRequest.PostForm.Get("prompt")) - require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) - require.Len(t, replayedRequest.MultipartForm.File["image"], 1) + file, err := 
replayedRequest.MultipartForm.File["image"][0].Open() + require.NoError(t, err) + defer file.Close() + fileBytes, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, []byte("fake image"), fileBytes) + } + + t.Run("with pre-parsed form", func(t *testing.T) { + prompt := "edit this image" + c := newMultipartContext(t, prompt) + require.NoError(t, c.Request.ParseMultipartForm(32<<20)) + + convertAndReplay(t, c, prompt) + }) + + t.Run("re-parses reusable body when form is missing", func(t *testing.T) { + prompt := "edit without pre-parsed form" + c := newMultipartContext(t, prompt) + + storage, err := common.GetBodyStorage(c) + require.NoError(t, err) + c.Request.Body = io.NopCloser(storage) + c.Request.MultipartForm = nil + c.Request.PostForm = nil + + convertAndReplay(t, c, prompt) + }) } diff --git a/relay/channel/openai/image_stream_test.go b/relay/channel/openai/image_stream_test.go index b060bbc4..a9b1e0b2 100644 --- a/relay/channel/openai/image_stream_test.go +++ b/relay/channel/openai/image_stream_test.go @@ -8,13 +8,34 @@ import ( "testing" "github.com/QuantumNous/new-api/constant" - "github.com/QuantumNous/new-api/dto" relaycommon "github.com/QuantumNous/new-api/relay/common" - "github.com/QuantumNous/new-api/relay/helper" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +func newImageTestContext(t *testing.T, body, contentType string, isStream bool) (*gin.Context, *httptest.ResponseRecorder, *http.Response, *relaycommon.RelayInfo) { + t.Helper() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{contentType}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: isStream, + } + return c, recorder, resp, info +} + +// 
TestOpenaiImageStreamHandlerForwardsSSEAndUsage covers the core SSE path: +// chunks are forwarded with rebuilt event lines, usage is extracted and +// normalized (input_tokens -> prompt_tokens with details), and [DONE] is +// re-emitted to the client. func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) @@ -34,19 +55,7 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) { ``, }, "\n") - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) - - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{"Content-Type": []string{"text/event-stream"}}, - } - info := &relaycommon.RelayInfo{ - ChannelMeta: &relaycommon.ChannelMeta{}, - IsStream: true, - } + c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true) usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) @@ -62,36 +71,8 @@ func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) { require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type")) } -func TestOpenaiImageStreamHandlerForwardsLargeSSELine(t *testing.T) { - oldMode := gin.Mode() - gin.SetMode(gin.TestMode) - t.Cleanup(func() { gin.SetMode(oldMode) }) - - payload := strings.Repeat("x", helper.DefaultMaxScannerBufferSize+1) - body := "data: " + payload + "\n\n" - - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) - - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{"Content-Type": []string{"text/event-stream"}}, - } - info := &relaycommon.RelayInfo{ - ChannelMeta: &relaycommon.ChannelMeta{}, - IsStream: true, - } - - usage, err := OpenaiImageStreamHandler(c, 
info, resp) - require.Nil(t, err) - require.NotNil(t, usage) - require.Contains(t, recorder.Body.String(), payload) - require.NotNil(t, info.StreamStatus) - require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason) -} - +// TestOpenaiImageStreamHandlerWrapsJSONResponse covers the non-SSE fallback: +// a JSON upstream response is wrapped into pseudo-SSE completed events. func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) @@ -99,19 +80,7 @@ func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) { body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}` - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) - - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{"Content-Type": []string{"application/json"}}, - } - info := &relaycommon.RelayInfo{ - ChannelMeta: &relaycommon.ChannelMeta{}, - IsStream: true, - } + c, recorder, resp, info := newImageTestContext(t, body, "application/json", true) usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) @@ -129,73 +98,54 @@ func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) { require.Contains(t, recorder.Body.String(), `data: [DONE]`) } -func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) { +// TestOpenaiImageHandlersReturnJSONError covers JSON error responses for both +// entry points: the non-streaming handler and the stream handler's non-SSE +// fallback. Neither must leak the error body to the client. 
+func TestOpenaiImageHandlersReturnJSONError(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}` - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + t.Run("non-streaming handler", func(t *testing.T) { + c, recorder, resp, info := newImageTestContext(t, body, "application/json", false) - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{"Content-Type": []string{"application/json"}}, - } - info := &relaycommon.RelayInfo{ - ChannelMeta: &relaycommon.ChannelMeta{}, - IsStream: false, - } + usage, err := OpenaiImageHandler(c, info, resp) + require.Nil(t, usage) + require.NotNil(t, err) + require.Equal(t, http.StatusOK, err.StatusCode) + oaiError := err.ToOpenAIError() + require.Equal(t, "content moderation failed", oaiError.Message) + require.Equal(t, "upstream_error", oaiError.Type) + require.Equal(t, "content_moderation_failed", oaiError.Code) + require.Empty(t, recorder.Body.String()) + }) - usage, err := OpenaiHandlerWithUsage(c, info, resp) - require.Nil(t, usage) - require.NotNil(t, err) - require.Equal(t, http.StatusOK, err.StatusCode) - oaiError := err.ToOpenAIError() - require.Equal(t, "content moderation failed", oaiError.Message) - require.Equal(t, "upstream_error", oaiError.Type) - require.Equal(t, "content_moderation_failed", oaiError.Code) - require.Empty(t, recorder.Body.String()) -} - -func TestOpenaiImageStreamHandlerReturnsJSONErrorFallback(t *testing.T) { - oldMode := gin.Mode() - gin.SetMode(gin.TestMode) - t.Cleanup(func() { gin.SetMode(oldMode) }) - - body := `{"error":{"message":"image edit failed","type":"upstream_error","code":"content_moderation_failed","status":502}}` - - recorder 
:= httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) - - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{"Content-Type": []string{"application/json"}}, - } - info := &relaycommon.RelayInfo{ - ChannelMeta: &relaycommon.ChannelMeta{}, - IsStream: true, - } - - usage, err := OpenaiImageStreamHandler(c, info, resp) - require.Nil(t, usage) - require.NotNil(t, err) - require.Equal(t, http.StatusOK, err.StatusCode) - oaiError := err.ToOpenAIError() - require.Equal(t, "image edit failed", oaiError.Message) - require.Empty(t, recorder.Body.String()) + t.Run("stream handler JSON fallback", func(t *testing.T) { + c, recorder, resp, info := newImageTestContext(t, body, "application/json", true) + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, usage) + require.NotNil(t, err) + require.Equal(t, http.StatusOK, err.StatusCode) + require.Equal(t, "content moderation failed", err.ToOpenAIError().Message) + require.Empty(t, recorder.Body.String()) + }) } +// TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent verifies that an error +// event inside the SSE stream is recorded as a soft error while the payload is +// still forwarded to the client. 
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) + body := strings.Join([]string{ `event: image_generation.partial_image`, `data: {"type":"image_generation.partial_image","b64_json":"partial"}`, @@ -205,49 +155,19 @@ func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) { ``, }, "\n") - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) - - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{"Content-Type": []string{"text/event-stream"}}, - } - info := &relaycommon.RelayInfo{ - ChannelMeta: &relaycommon.ChannelMeta{}, - IsStream: true, - } + c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true) usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) require.NotNil(t, usage) require.NotNil(t, info.StreamStatus) - require.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) + require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason) require.True(t, info.StreamStatus.HasErrors()) require.Equal(t, 1, info.StreamStatus.TotalErrorCount()) require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR") - require.Contains(t, recorder.Body.String(), `event: error`) + // The scanner strips the upstream "event: error" line; the event name is + // rebuilt from the JSON "type" field (upstream_error). The error message + // is still forwarded in the data: payload (stream ID 77). 
+ require.Contains(t, recorder.Body.String(), `event: upstream_error`) require.Contains(t, recorder.Body.String(), `stream ID 77`) } - -func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) { - usage := &dto.Usage{ - InputTokens: 5000, - OutputTokens: 4000, - InputTokensDetails: &dto.InputTokenDetails{ - CachedCreationTokens: 200, - ImageTokens: 1000, - TextTokens: 4000, - }, - } - - normalizeOpenAIUsage(usage) - - require.Equal(t, 5000, usage.PromptTokens) - require.Equal(t, 4000, usage.CompletionTokens) - require.Equal(t, 9000, usage.TotalTokens) - require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens) - require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens) - require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens) -} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 8f90eeda..de40fe70 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,13 +1,10 @@ package openai import ( - "bufio" - "encoding/json" "fmt" "io" "net/http" "strings" - "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" @@ -17,12 +14,9 @@ import ( relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/types" - "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" ) func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { @@ -296,672 +290,3 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo return &simpleResponse.Usage, nil } - -func streamTTSResponse(c *gin.Context, resp *http.Response) { - c.Writer.WriteHeaderNow() - - flusher, ok := c.Writer.(http.Flusher) - if !ok { - logger.LogWarn(c, "streaming not supported") - _, err := io.Copy(c.Writer, 
resp.Body) - if err != nil { - logger.LogWarn(c, err.Error()) - } - return - } - - buffer := make([]byte, 4096) - for { - n, err := resp.Body.Read(buffer) - //logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n)) - if n > 0 { - if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil { - logger.LogError(c, writeErr.Error()) - break - } - flusher.Flush() - } - if err != nil { - if err != io.EOF { - logger.LogError(c, err.Error()) - } - break - } - } -} - -func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) { - if info == nil || info.ClientWs == nil || info.TargetWs == nil { - return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil - } - - info.IsStream = true - clientConn := info.ClientWs - targetConn := info.TargetWs - - clientClosed := make(chan struct{}) - targetClosed := make(chan struct{}) - sendChan := make(chan []byte, 100) - receiveChan := make(chan []byte, 100) - errChan := make(chan error, 2) - - usage := &dto.RealtimeUsage{} - localUsage := &dto.RealtimeUsage{} - sumUsage := &dto.RealtimeUsage{} - - gopool.Go(func() { - defer func() { - if r := recover(); r != nil { - errChan <- fmt.Errorf("panic in client reader: %v", r) - } - }() - for { - select { - case <-c.Done(): - return - default: - _, message, err := clientConn.ReadMessage() - if err != nil { - if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - errChan <- fmt.Errorf("error reading from client: %v", err) - } - close(clientClosed) - return - } - - realtimeEvent := &dto.RealtimeEvent{} - err = common.Unmarshal(message, realtimeEvent) - if err != nil { - errChan <- fmt.Errorf("error unmarshalling message: %v", err) - return - } - - if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate { - if realtimeEvent.Session != nil { - if realtimeEvent.Session.Tools != nil { - info.RealtimeTools = realtimeEvent.Session.Tools - } - } - } - - 
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) - if err != nil { - errChan <- fmt.Errorf("error counting text token: %v", err) - return - } - logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) - localUsage.TotalTokens += textToken + audioToken - localUsage.InputTokens += textToken + audioToken - localUsage.InputTokenDetails.TextTokens += textToken - localUsage.InputTokenDetails.AudioTokens += audioToken - - err = helper.WssString(c, targetConn, string(message)) - if err != nil { - errChan <- fmt.Errorf("error writing to target: %v", err) - return - } - - select { - case sendChan <- message: - default: - } - } - } - }) - - gopool.Go(func() { - defer func() { - if r := recover(); r != nil { - errChan <- fmt.Errorf("panic in target reader: %v", r) - } - }() - for { - select { - case <-c.Done(): - return - default: - _, message, err := targetConn.ReadMessage() - if err != nil { - if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - errChan <- fmt.Errorf("error reading from target: %v", err) - } - close(targetClosed) - return - } - info.SetFirstResponseTime() - realtimeEvent := &dto.RealtimeEvent{} - err = common.Unmarshal(message, realtimeEvent) - if err != nil { - errChan <- fmt.Errorf("error unmarshalling message: %v", err) - return - } - - if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone { - realtimeUsage := realtimeEvent.Response.Usage - if realtimeUsage != nil { - usage.TotalTokens += realtimeUsage.TotalTokens - usage.InputTokens += realtimeUsage.InputTokens - usage.OutputTokens += realtimeUsage.OutputTokens - usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens - usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens - usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens - usage.OutputTokenDetails.AudioTokens 
+= realtimeUsage.OutputTokenDetails.AudioTokens - usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens - err := preConsumeUsage(c, info, usage, sumUsage) - if err != nil { - errChan <- fmt.Errorf("error consume usage: %v", err) - return - } - // 本次计费完成,清除 - usage = &dto.RealtimeUsage{} - - localUsage = &dto.RealtimeUsage{} - } else { - textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) - if err != nil { - errChan <- fmt.Errorf("error counting text token: %v", err) - return - } - logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) - localUsage.TotalTokens += textToken + audioToken - info.IsFirstRequest = false - localUsage.InputTokens += textToken + audioToken - localUsage.InputTokenDetails.TextTokens += textToken - localUsage.InputTokenDetails.AudioTokens += audioToken - err = preConsumeUsage(c, info, localUsage, sumUsage) - if err != nil { - errChan <- fmt.Errorf("error consume usage: %v", err) - return - } - // 本次计费完成,清除 - localUsage = &dto.RealtimeUsage{} - // print now usage - } - logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - - } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { - realtimeSession := realtimeEvent.Session - if realtimeSession != nil { - // update audio format - info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat) - info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat) - } - } else { - textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) - if err != nil { - errChan <- 
fmt.Errorf("error counting text token: %v", err) - return - } - logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) - localUsage.TotalTokens += textToken + audioToken - localUsage.OutputTokens += textToken + audioToken - localUsage.OutputTokenDetails.TextTokens += textToken - localUsage.OutputTokenDetails.AudioTokens += audioToken - } - - err = helper.WssString(c, clientConn, string(message)) - if err != nil { - errChan <- fmt.Errorf("error writing to client: %v", err) - return - } - - select { - case receiveChan <- message: - default: - } - } - } - }) - - select { - case <-clientClosed: - case <-targetClosed: - case err := <-errChan: - //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil - logger.LogError(c, "realtime error: "+err.Error()) - case <-c.Done(): - } - - if usage.TotalTokens != 0 { - _ = preConsumeUsage(c, info, usage, sumUsage) - } - - if localUsage.TotalTokens != 0 { - _ = preConsumeUsage(c, info, localUsage, sumUsage) - } - - // check usage total tokens, if 0, use local usage - - return nil, sumUsage -} - -func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error { - if usage == nil || totalUsage == nil { - return fmt.Errorf("invalid usage pointer") - } - - totalUsage.TotalTokens += usage.TotalTokens - totalUsage.InputTokens += usage.InputTokens - totalUsage.OutputTokens += usage.OutputTokens - totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens - totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens - totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens - totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens - totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens - // clear usage - err := service.PreWssConsumeQuota(ctx, info, usage) - 
return err -} - -func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer service.CloseResponseBodyGracefully(resp) - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) - } - - var usageResp dto.SimpleResponse - err = common.Unmarshal(responseBody, &usageResp) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - - if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { - return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) - } - - // 写入新的 response body - service.IOCopyBytesGracefully(c, resp, responseBody) - - normalizeOpenAIUsage(&usageResp.Usage) - applyUsagePostProcessing(info, &usageResp.Usage, responseBody) - return &usageResp.Usage, nil -} - -func normalizeOpenAIUsage(usage *dto.Usage) { - if usage == nil { - return - } - if usage.InputTokens != 0 { - usage.PromptTokens = usage.InputTokens - } - if usage.OutputTokens != 0 { - usage.CompletionTokens = usage.OutputTokens - } - if usage.InputTokensDetails != nil { - usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens - usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens - usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens - usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens - usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens - } - if usage.TotalTokens == 0 { - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - } -} - -func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - if resp == nil || resp.Body == nil { - logger.LogError(c, "invalid image stream response") - 
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) - } - - contentType := strings.ToLower(resp.Header.Get("Content-Type")) - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return OpenaiHandlerWithUsage(c, info, resp) - } - if !strings.Contains(contentType, "text/event-stream") { - return OpenaiImageJSONAsStreamHandler(c, info, resp) - } - defer service.CloseResponseBodyGracefully(resp) - - usage := &dto.Usage{} - var lastStreamData []byte - - helper.SetEventStreamHeaders(c) - if info != nil && info.StreamStatus == nil { - info.StreamStatus = relaycommon.NewStreamStatus() - } - - reader := bufio.NewReader(resp.Body) - currentEvent := "" - var readErr error - for { - line, err := reader.ReadString('\n') - if err != nil { - readErr = err - if len(line) == 0 { - break - } - } - line = strings.TrimSuffix(line, "\n") - line = strings.TrimSuffix(line, "\r") - if strings.HasPrefix(line, "event:") { - currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if data == "[DONE]" { - if info != nil && info.StreamStatus != nil { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) - } - } else if data != "" { - if info != nil { - info.SetFirstResponseTime() - info.ReceivedResponseCount++ - } - lastStreamData = common.StringToByteSlice(data) - if info != nil && info.StreamStatus != nil && isOpenAIImageStreamErrorEvent(currentEvent, lastStreamData) { - info.StreamStatus.RecordError(extractOpenAIImageStreamErrorMessage(lastStreamData)) - } - var usageResp dto.SimpleResponse - if err := common.Unmarshal(lastStreamData, &usageResp); err == nil { - normalizeOpenAIUsage(&usageResp.Usage) - if service.ValidUsage(&usageResp.Usage) { - usage = &usageResp.Usage - } - } - } - } - if _, err := c.Writer.Write(append([]byte(line), '\n')); 
err != nil { - if info != nil && info.StreamStatus != nil { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) - } - return usage, nil - } - if line == "" { - if err := helper.FlushWriter(c); err != nil { - if info != nil && info.StreamStatus != nil { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) - } - return usage, nil - } - currentEvent = "" - } - if readErr != nil { - break - } - } - if info != nil && info.StreamStatus != nil { - if readErr != nil && readErr != io.EOF { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, readErr) - } else if info.StreamStatus.HasErrors() { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("upstream image stream returned error event")) - } else if info.StreamStatus.EndReason == relaycommon.StreamEndReasonNone { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil) - } - } - _ = helper.FlushWriter(c) - - applyUsagePostProcessing(info, usage, lastStreamData) - return usage, nil -} - -func isOpenAIImageStreamErrorEvent(eventName string, data []byte) bool { - if strings.EqualFold(strings.TrimSpace(eventName), "error") { - return true - } - if !json.Valid(data) { - return false - } - var payload struct { - Type string `json:"type"` - Error json.RawMessage `json:"error"` - } - if err := common.Unmarshal(data, &payload); err != nil { - return false - } - payloadType := strings.ToLower(strings.TrimSpace(payload.Type)) - return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0 -} - -func extractOpenAIImageStreamErrorMessage(data []byte) string { - if len(data) == 0 || !json.Valid(data) { - return "upstream image stream returned error event" - } - var payload struct { - Message string `json:"message"` - Error json.RawMessage `json:"error"` - } - if err := common.Unmarshal(data, &payload); err != nil { - return "upstream image stream returned error event" - } - if msg := 
strings.TrimSpace(payload.Message); msg != "" { - return msg - } - if len(payload.Error) > 0 { - var nested struct { - Message string `json:"message"` - } - if err := common.Unmarshal(payload.Error, &nested); err == nil { - if msg := strings.TrimSpace(nested.Message); msg != "" { - return msg - } - } - if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" { - return msg - } - } - return "upstream image stream returned error event" -} - -func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer service.CloseResponseBodyGracefully(resp) - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) - } - - var imageResp dto.ImageResponse - if err := common.Unmarshal(responseBody, &imageResp); err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - - var usageResp dto.SimpleResponse - _ = common.Unmarshal(responseBody, &usageResp) - if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { - return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) - } - normalizeOpenAIUsage(&usageResp.Usage) - applyUsagePostProcessing(info, &usageResp.Usage, responseBody) - - helper.SetEventStreamHeaders(c) - c.Status(http.StatusOK) - - created := imageResp.Created - if created == 0 { - created = time.Now().Unix() - } - if info != nil { - info.SetFirstResponseTime() - } - for _, image := range imageResp.Data { - payload := map[string]any{ - "type": "image_generation.completed", - "created_at": created, - } - if image.Url != "" { - payload["url"] = image.Url - } - if image.B64Json != "" { - payload["b64_json"] = image.B64Json - } - if image.RevisedPrompt != "" { - payload["revised_prompt"] = image.RevisedPrompt - } - if service.ValidUsage(&usageResp.Usage) { - 
payload["usage"] = usageResp.Usage - } - if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil { - if info != nil && info.StreamStatus != nil { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) - } - return &usageResp.Usage, nil - } - } - if err := writeOpenaiImageStreamDone(c); err != nil { - if info != nil && info.StreamStatus != nil { - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) - } - return &usageResp.Usage, nil - } - if info != nil { - info.ReceivedResponseCount += len(imageResp.Data) - if info.StreamStatus == nil { - info.StreamStatus = relaycommon.NewStreamStatus() - } - info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) - } - return &usageResp.Usage, nil -} - -func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error { - data, err := common.Marshal(payload) - if err != nil { - return err - } - if eventName != "" { - if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { - return err - } - } - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil { - return err - } - return helper.FlushWriter(c) -} - -func writeOpenaiImageStreamDone(c *gin.Context) error { - if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { - return err - } - return helper.FlushWriter(c) -} - -func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) { - if info == nil || usage == nil { - return - } - - switch info.ChannelType { - case constant.ChannelTypeDeepSeek: - if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 { - usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens - } - case constant.ChannelTypeZhipu_v4: - // 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens - if usage.PromptTokensDetails.CachedTokens == 0 { - if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 
0 { - usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens - } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok { - usage.PromptTokensDetails.CachedTokens = cachedTokens - } else if usage.PromptCacheHitTokens > 0 { - usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens - } - } - case constant.ChannelTypeMoonshot: - // Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens - if usage.PromptTokensDetails.CachedTokens == 0 { - if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { - usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens - } else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok { - usage.PromptTokensDetails.CachedTokens = cachedTokens - } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok { - usage.PromptTokensDetails.CachedTokens = cachedTokens - } else if usage.PromptCacheHitTokens > 0 { - usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens - } - } - case constant.ChannelTypeOpenAI: - if usage.PromptTokensDetails.CachedTokens == 0 { - if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok { - usage.PromptTokensDetails.CachedTokens = cachedTokens - } - } - } -} - -func extractCachedTokensFromBody(body []byte) (int, bool) { - if len(body) == 0 { - return 0, false - } - - var payload struct { - Usage struct { - PromptTokensDetails struct { - CachedTokens *int `json:"cached_tokens"` - } `json:"prompt_tokens_details"` - CachedTokens *int `json:"cached_tokens"` - PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"` - } `json:"usage"` - } - - if err := common.Unmarshal(body, &payload); err != nil { - return 0, false - } - - if payload.Usage.PromptTokensDetails.CachedTokens != nil { - return *payload.Usage.PromptTokensDetails.CachedTokens, true - } - if payload.Usage.CachedTokens != nil { - return *payload.Usage.CachedTokens, true - } - if 
payload.Usage.PromptCacheHitTokens != nil { - return *payload.Usage.PromptCacheHitTokens, true - } - return 0, false -} - -// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens -// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]} -func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) { - if len(body) == 0 { - return 0, false - } - - var payload struct { - Choices []struct { - Usage struct { - CachedTokens *int `json:"cached_tokens"` - } `json:"usage"` - } `json:"choices"` - } - - if err := common.Unmarshal(body, &payload); err != nil { - return 0, false - } - - // 遍历choices查找cached_tokens - for _, choice := range payload.Choices { - if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 { - return *choice.Usage.CachedTokens, true - } - } - - return 0, false -} - -// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n -func extractLlamaCachedTokensFromBody(body []byte) (int, bool) { - if len(body) == 0 { - return 0, false - } - - var payload struct { - Timings struct { - CachedTokens *int `json:"cache_n"` - } `json:"timings"` - } - - if err := common.Unmarshal(body, &payload); err != nil { - return 0, false - } - - if payload.Timings.CachedTokens == nil { - return 0, false - } - return *payload.Timings.CachedTokens, true -} diff --git a/relay/channel/openai/relay_image.go b/relay/channel/openai/relay_image.go new file mode 100644 index 00000000..436f52cd --- /dev/null +++ b/relay/channel/openai/relay_image.go @@ -0,0 +1,287 @@ +package openai + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/relay/helper" + "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" + + "github.com/gin-gonic/gin" +) + +// OpenaiImageHandler handles 
non-streaming OpenAI image responses +// (generations/edits), returning the parsed usage for billing. +func OpenaiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) + } + + var usageResp dto.SimpleResponse + err = common.Unmarshal(responseBody, &usageResp) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) + } + + // 写入新的 response body + service.IOCopyBytesGracefully(c, resp, responseBody) + + normalizeOpenAIUsage(&usageResp.Usage) + applyUsagePostProcessing(info, &usageResp.Usage, responseBody) + return &usageResp.Usage, nil +} + +// normalizeOpenAIUsage maps the OpenAI Images usage shape (input_tokens / +// output_tokens / input_tokens_details) onto the canonical prompt/completion +// fields. It is used only on the OpenAI image relay paths (generations/edits, +// streaming and non-streaming): the image API never returns prompt_tokens / +// completion_tokens, so the overwrite (=) semantics here are equivalent to the +// previous additive (+=) behavior while avoiding any future double-counting if +// both field sets are ever populated. Do not reuse this on chat/embedding paths +// without revisiting the overwrite semantics. 
+func normalizeOpenAIUsage(usage *dto.Usage) { + if usage == nil { + return + } + if usage.InputTokens != 0 { + usage.PromptTokens = usage.InputTokens + } + if usage.OutputTokens != 0 { + usage.CompletionTokens = usage.OutputTokens + } + if usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens + usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens + usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens + usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens + usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens + } + if usage.TotalTokens == 0 { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } +} + +func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + if resp == nil || resp.Body == nil { + logger.LogError(c, "invalid image stream response") + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) + } + + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return OpenaiImageHandler(c, info, resp) + } + if !strings.Contains(contentType, "text/event-stream") { + return OpenaiImageJSONAsStreamHandler(c, info, resp) + } + // Reuse the shared streaming engine (helper.StreamScannerHandler) so the + // image streaming path gets the same ping keepalive, streaming-timeout + // watchdog, client-disconnect detection, panic recovery and goroutine + // cleanup as every other relay stream. The scanner delivers only the + // "data:" payload, so the SSE "event:" line is rebuilt from the JSON "type" + // field (real OpenAI image events keep event == type). 
+ usage := &dto.Usage{} + var lastStreamData []byte + + helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) { + raw := common.StringToByteSlice(data) + lastStreamData = raw + if isOpenAIImageStreamErrorEvent(raw) { + // Record the error as a soft error; the scanner drives the final + // EndReason. HasErrors() flags the failure for logging/handling. + sr.Error(fmt.Errorf("%s", extractOpenAIImageStreamErrorMessage(raw))) + } + var usageResp dto.SimpleResponse + if err := common.Unmarshal(raw, &usageResp); err == nil { + normalizeOpenAIUsage(&usageResp.Usage) + if service.ValidUsage(&usageResp.Usage) { + usage = &usageResp.Usage + } + } + writeOpenaiImageStreamChunk(c, raw) + }) + + // StreamScannerHandler consumes the upstream [DONE]; re-emit it so the + // client still receives a terminal data: [DONE]. + if info != nil && info.StreamStatus != nil && info.StreamStatus.EndReason == relaycommon.StreamEndReasonDone { + helper.Done(c) + } + + applyUsagePostProcessing(info, usage, lastStreamData) + return usage, nil +} + +// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk: +// it emits an "event:" line derived from the JSON "type" field (when present) +// followed by the verbatim "data:" payload, mirroring helper.ResponseChunkData. +func writeOpenaiImageStreamChunk(c *gin.Context, data []byte) { + var payload struct { + Type string `json:"type"` + } + _ = common.Unmarshal(data, &payload) + if eventName := strings.TrimSpace(payload.Type); eventName != "" { + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", eventName)}) + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(data)}) + _ = helper.FlushWriter(c) +} + +// isOpenAIImageStreamErrorEvent detects upstream error chunks by JSON content +// only ("type" of error/upstream_error, or a non-empty "error" field). The SSE +// "event:" line is not available here: StreamScannerHandler delivers only the +// "data:" payload. 
A payload carrying just a "message" key is deliberately NOT +// treated as an error to avoid false positives. +func isOpenAIImageStreamErrorEvent(data []byte) bool { + if !json.Valid(data) { + return false + } + var payload struct { + Type string `json:"type"` + Error json.RawMessage `json:"error"` + } + if err := common.Unmarshal(data, &payload); err != nil { + return false + } + payloadType := strings.ToLower(strings.TrimSpace(payload.Type)) + return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0 +} + +func extractOpenAIImageStreamErrorMessage(data []byte) string { + if len(data) == 0 || !json.Valid(data) { + return "upstream image stream returned error event" + } + var payload struct { + Message string `json:"message"` + Error json.RawMessage `json:"error"` + } + if err := common.Unmarshal(data, &payload); err != nil { + return "upstream image stream returned error event" + } + if msg := strings.TrimSpace(payload.Message); msg != "" { + return msg + } + if len(payload.Error) > 0 { + var nested struct { + Message string `json:"message"` + } + if err := common.Unmarshal(payload.Error, &nested); err == nil { + if msg := strings.TrimSpace(nested.Message); msg != "" { + return msg + } + } + if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" { + return msg + } + } + return "upstream image stream returned error event" +} + +func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) + } + + var imageResp dto.ImageResponse + if err := common.Unmarshal(responseBody, &imageResp); err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + 
var usageResp dto.SimpleResponse + _ = common.Unmarshal(responseBody, &usageResp) + if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) + } + normalizeOpenAIUsage(&usageResp.Usage) + applyUsagePostProcessing(info, &usageResp.Usage, responseBody) + + helper.SetEventStreamHeaders(c) + c.Status(http.StatusOK) + + created := imageResp.Created + if created == 0 { + created = time.Now().Unix() + } + if info != nil { + info.SetFirstResponseTime() + } + for _, image := range imageResp.Data { + payload := map[string]any{ + "type": "image_generation.completed", + "created_at": created, + } + if image.Url != "" { + payload["url"] = image.Url + } + if image.B64Json != "" { + payload["b64_json"] = image.B64Json + } + if image.RevisedPrompt != "" { + payload["revised_prompt"] = image.RevisedPrompt + } + if service.ValidUsage(&usageResp.Usage) { + payload["usage"] = usageResp.Usage + } + if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + } + return &usageResp.Usage, nil + } + } + if err := writeOpenaiImageStreamDone(c); err != nil { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + } + return &usageResp.Usage, nil + } + if info != nil { + info.ReceivedResponseCount += len(imageResp.Data) + if info.StreamStatus == nil { + info.StreamStatus = relaycommon.NewStreamStatus() + } + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) + } + return &usageResp.Usage, nil +} + +func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error { + data, err := common.Marshal(payload) + if err != nil { + return err + } + if eventName != "" { + if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { 
+ return err + } + } + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil { + return err + } + return helper.FlushWriter(c) +} + +func writeOpenaiImageStreamDone(c *gin.Context) error { + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + return err + } + return helper.FlushWriter(c) +} diff --git a/relay/channel/openai/relay_realtime.go b/relay/channel/openai/relay_realtime.go new file mode 100644 index 00000000..bb5c3587 --- /dev/null +++ b/relay/channel/openai/relay_realtime.go @@ -0,0 +1,242 @@ +package openai + +import ( + "fmt" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/relay/helper" + "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) { + if info == nil || info.ClientWs == nil || info.TargetWs == nil { + return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil + } + + info.IsStream = true + clientConn := info.ClientWs + targetConn := info.TargetWs + + clientClosed := make(chan struct{}) + targetClosed := make(chan struct{}) + sendChan := make(chan []byte, 100) + receiveChan := make(chan []byte, 100) + errChan := make(chan error, 2) + + usage := &dto.RealtimeUsage{} + localUsage := &dto.RealtimeUsage{} + sumUsage := &dto.RealtimeUsage{} + + gopool.Go(func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("panic in client reader: %v", r) + } + }() + for { + select { + case <-c.Done(): + return + default: + _, message, err := clientConn.ReadMessage() + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, 
websocket.CloseGoingAway) { + errChan <- fmt.Errorf("error reading from client: %v", err) + } + close(clientClosed) + return + } + + realtimeEvent := &dto.RealtimeEvent{} + err = common.Unmarshal(message, realtimeEvent) + if err != nil { + errChan <- fmt.Errorf("error unmarshalling message: %v", err) + return + } + + if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate { + if realtimeEvent.Session != nil { + if realtimeEvent.Session.Tools != nil { + info.RealtimeTools = realtimeEvent.Session.Tools + } + } + } + + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + localUsage.TotalTokens += textToken + audioToken + localUsage.InputTokens += textToken + audioToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken + + err = helper.WssString(c, targetConn, string(message)) + if err != nil { + errChan <- fmt.Errorf("error writing to target: %v", err) + return + } + + select { + case sendChan <- message: + default: + } + } + } + }) + + gopool.Go(func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("panic in target reader: %v", r) + } + }() + for { + select { + case <-c.Done(): + return + default: + _, message, err := targetConn.ReadMessage() + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + errChan <- fmt.Errorf("error reading from target: %v", err) + } + close(targetClosed) + return + } + info.SetFirstResponseTime() + realtimeEvent := &dto.RealtimeEvent{} + err = common.Unmarshal(message, realtimeEvent) + if err != nil { + errChan <- fmt.Errorf("error unmarshalling message: %v", err) + return + } + + if realtimeEvent.Type == 
dto.RealtimeEventTypeResponseDone { + realtimeUsage := realtimeEvent.Response.Usage + if realtimeUsage != nil { + usage.TotalTokens += realtimeUsage.TotalTokens + usage.InputTokens += realtimeUsage.InputTokens + usage.OutputTokens += realtimeUsage.OutputTokens + usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens + usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens + usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens + usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens + usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens + err := preConsumeUsage(c, info, usage, sumUsage) + if err != nil { + errChan <- fmt.Errorf("error consume usage: %v", err) + return + } + // 本次计费完成,清除 + usage = &dto.RealtimeUsage{} + + localUsage = &dto.RealtimeUsage{} + } else { + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + localUsage.TotalTokens += textToken + audioToken + info.IsFirstRequest = false + localUsage.InputTokens += textToken + audioToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken + err = preConsumeUsage(c, info, localUsage, sumUsage) + if err != nil { + errChan <- fmt.Errorf("error consume usage: %v", err) + return + } + // 本次计费完成,清除 + localUsage = &dto.RealtimeUsage{} + // print now usage + } + logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + + } else if realtimeEvent.Type == 
dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { + realtimeSession := realtimeEvent.Session + if realtimeSession != nil { + // update audio format + info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat) + info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat) + } + } else { + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + localUsage.TotalTokens += textToken + audioToken + localUsage.OutputTokens += textToken + audioToken + localUsage.OutputTokenDetails.TextTokens += textToken + localUsage.OutputTokenDetails.AudioTokens += audioToken + } + + err = helper.WssString(c, clientConn, string(message)) + if err != nil { + errChan <- fmt.Errorf("error writing to client: %v", err) + return + } + + select { + case receiveChan <- message: + default: + } + } + } + }) + + select { + case <-clientClosed: + case <-targetClosed: + case err := <-errChan: + //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil + logger.LogError(c, "realtime error: "+err.Error()) + case <-c.Done(): + } + + if usage.TotalTokens != 0 { + _ = preConsumeUsage(c, info, usage, sumUsage) + } + + if localUsage.TotalTokens != 0 { + _ = preConsumeUsage(c, info, localUsage, sumUsage) + } + + // check usage total tokens, if 0, use local usage + + return nil, sumUsage +} + +func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error { + if usage == nil || totalUsage == nil { + return fmt.Errorf("invalid usage pointer") + } + + totalUsage.TotalTokens += usage.TotalTokens + 
totalUsage.InputTokens += usage.InputTokens + totalUsage.OutputTokens += usage.OutputTokens + totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens + totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens + totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens + totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens + totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens + // clear usage + err := service.PreWssConsumeQuota(ctx, info, usage) + return err +} diff --git a/relay/channel/openai/usage.go b/relay/channel/openai/usage.go new file mode 100644 index 00000000..4085a1f3 --- /dev/null +++ b/relay/channel/openai/usage.go @@ -0,0 +1,133 @@ +package openai + +import ( + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" +) + +func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) { + if info == nil || usage == nil { + return + } + + switch info.ChannelType { + case constant.ChannelTypeDeepSeek: + if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 { + usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens + } + case constant.ChannelTypeZhipu_v4: + // 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens + if usage.PromptTokensDetails.CachedTokens == 0 { + if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens + } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok { + usage.PromptTokensDetails.CachedTokens = cachedTokens + } else if usage.PromptCacheHitTokens > 0 { + usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens + } + } + case constant.ChannelTypeMoonshot: + // 
Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens + if usage.PromptTokensDetails.CachedTokens == 0 { + if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens + } else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok { + usage.PromptTokensDetails.CachedTokens = cachedTokens + } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok { + usage.PromptTokensDetails.CachedTokens = cachedTokens + } else if usage.PromptCacheHitTokens > 0 { + usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens + } + } + case constant.ChannelTypeOpenAI: + if usage.PromptTokensDetails.CachedTokens == 0 { + if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok { + usage.PromptTokensDetails.CachedTokens = cachedTokens + } + } + } +} + +func extractCachedTokensFromBody(body []byte) (int, bool) { + if len(body) == 0 { + return 0, false + } + + var payload struct { + Usage struct { + PromptTokensDetails struct { + CachedTokens *int `json:"cached_tokens"` + } `json:"prompt_tokens_details"` + CachedTokens *int `json:"cached_tokens"` + PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"` + } `json:"usage"` + } + + if err := common.Unmarshal(body, &payload); err != nil { + return 0, false + } + + if payload.Usage.PromptTokensDetails.CachedTokens != nil { + return *payload.Usage.PromptTokensDetails.CachedTokens, true + } + if payload.Usage.CachedTokens != nil { + return *payload.Usage.CachedTokens, true + } + if payload.Usage.PromptCacheHitTokens != nil { + return *payload.Usage.PromptCacheHitTokens, true + } + return 0, false +} + +// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens +// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]} +func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) { + if len(body) == 0 { + return 0, false + } + + var payload 
struct { + Choices []struct { + Usage struct { + CachedTokens *int `json:"cached_tokens"` + } `json:"usage"` + } `json:"choices"` + } + + if err := common.Unmarshal(body, &payload); err != nil { + return 0, false + } + + // 遍历choices查找cached_tokens + for _, choice := range payload.Choices { + if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 { + return *choice.Usage.CachedTokens, true + } + } + + return 0, false +} + +// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n +func extractLlamaCachedTokensFromBody(body []byte) (int, bool) { + if len(body) == 0 { + return 0, false + } + + var payload struct { + Timings struct { + CachedTokens *int `json:"cache_n"` + } `json:"timings"` + } + + if err := common.Unmarshal(body, &payload); err != nil { + return 0, false + } + + if payload.Timings.CachedTokens == nil { + return 0, false + } + return *payload.Timings.CachedTokens, true +} diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index c73bd8cf..64f622f4 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -114,7 +114,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: - usage, err = openai.OpenaiHandlerWithUsage(c, info, resp) + usage, err = openai.OpenaiImageHandler(c, info, resp) case constant.RelayModeResponses: if info.IsStream { usage, err = openai.OaiResponsesStreamHandler(c, info, resp) diff --git a/relay/helper/openai_image_request_test.go b/relay/helper/openai_image_request_test.go index a0bb46c6..94c90872 100644 --- a/relay/helper/openai_image_request_test.go +++ b/relay/helper/openai_image_request_test.go @@ -15,59 +15,57 @@ import ( "github.com/stretchr/testify/require" ) -// 
TestGetAndValidOpenAIImageRequestMultipartStream verifies reusable image edit parsing. +// TestGetAndValidOpenAIImageRequestMultipartStream verifies multipart image +// edit parsing: the stream field is parsed and validated, and the request body +// stays replayable for the upstream request. func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) { gin.SetMode(gin.TestMode) - var body bytes.Buffer - writer := multipart.NewWriter(&body) - require.NoError(t, writer.WriteField("model", "gpt-image-1")) - require.NoError(t, writer.WriteField("prompt", "edit this image")) - require.NoError(t, writer.WriteField("stream", "true")) - require.NoError(t, writer.WriteField("n", "1")) - part, err := writer.CreateFormFile("image", "input.png") - require.NoError(t, err) - _, err = part.Write([]byte("fake image")) - require.NoError(t, err) - require.NoError(t, writer.Close()) - originalBody := body.String() + newContext := func(t *testing.T, streamValue string, withImage bool) (*gin.Context, string) { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit this image")) + require.NoError(t, writer.WriteField("stream", streamValue)) + if withImage { + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + } + require.NoError(t, writer.Close()) + originalBody := body.String() - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) - c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return c, originalBody + } - req, err := 
GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) - require.NoError(t, err) - require.True(t, req.Stream) - require.True(t, req.IsStream(c)) + t.Run("valid stream value keeps body replayable", func(t *testing.T) { + c, originalBody := newContext(t, "true", true) - bodyAfterValidation, err := io.ReadAll(c.Request.Body) - require.NoError(t, err) - require.Equal(t, originalBody, string(bodyAfterValidation)) + req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) + require.NoError(t, err) + require.NotNil(t, req.Stream) + require.True(t, *req.Stream) + require.True(t, req.IsStream(c)) - form, err := common.ParseMultipartFormReusable(c) - require.NoError(t, err) - require.Equal(t, "true", url.Values(form.Value).Get("stream")) - require.Len(t, form.File["image"], 1) -} - -// TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue verifies stream validation. -func TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue(t *testing.T) { - gin.SetMode(gin.TestMode) - - var body bytes.Buffer - writer := multipart.NewWriter(&body) - require.NoError(t, writer.WriteField("model", "gpt-image-1")) - require.NoError(t, writer.WriteField("stream", "notabool")) - require.NoError(t, writer.Close()) - - recorder := httptest.NewRecorder() - c, _ := gin.CreateTestContext(recorder) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) - c.Request.Header.Set("Content-Type", writer.FormDataContentType()) - - _, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid stream value") + bodyAfterValidation, err := io.ReadAll(c.Request.Body) + require.NoError(t, err) + require.Equal(t, originalBody, string(bodyAfterValidation)) + + form, err := common.ParseMultipartFormReusable(c) + require.NoError(t, err) + require.Equal(t, "true", url.Values(form.Value).Get("stream")) + require.Len(t, form.File["image"], 1) + }) + + 
t.Run("invalid stream value is rejected", func(t *testing.T) { + c, _ := newContext(t, "notabool", false) + + _, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid stream value") + }) } diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index dbc7c8c4..bc0af1bb 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -22,8 +22,8 @@ import ( ) const ( - InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) - DefaultMaxScannerBufferSize = 64 << 20 // 64MB (64*1024*1024) default SSE buffer size + InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) + DefaultMaxScannerBufferSize = 128 << 20 // 128MB (128*1024*1024) default SSE buffer size DefaultPingInterval = 10 * time.Second ) diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index a53b6369..dd866bcb 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -163,7 +163,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq if err != nil { return nil, fmt.Errorf("invalid stream value: %w", err) } - imageRequest.Stream = stream + imageRequest.Stream = common.GetPointer(stream) } if imageValue := formData.Get("image"); imageValue != "" { imageRequest.Image, _ = common.Marshal(imageValue)