From d2576ddcd31ff752c30b54d1781e802e4021f824 Mon Sep 17 00:00:00 2001 From: gaoren002 <83566620+gaoren002@users.noreply.github.com> Date: Mon, 8 Jun 2026 18:36:17 +0800 Subject: [PATCH] fix(openai): support streaming image relay and image edit for images API (#4608) * fix(openai): support streaming image relay * fix(openai): keep image edit multipart body reusable * test(openai): cover image stream usage details * test(openai): cover image edit fallback stream field * fix(openai): wrap image json fallback as stream * fix(relay): support OpenAI image streaming * fix(openai): record image stream upstream error events * fix(openai): harden image stream relay * fix(openai): return image JSON errors * fix(relay): reset stream status per scanner run * fix(relay): drop upstream credit passthrough * fix(openai): keep image errors minimal * fix(openai): keep image error status from response --------- Co-authored-by: CaIon --- dto/openai_image.go | 12 +- dto/openai_image_test.go | 16 ++ relay/channel/openai/adaptor.go | 16 +- relay/channel/openai/image_edit_test.go | 121 ++++++++++ relay/channel/openai/image_stream_test.go | 253 +++++++++++++++++++ relay/channel/openai/relay-openai.go | 282 ++++++++++++++++++++-- relay/helper/openai_image_request_test.go | 73 ++++++ relay/helper/stream_scanner_test.go | 4 +- relay/helper/valid_request.go | 15 +- 9 files changed, 764 insertions(+), 28 deletions(-) create mode 100644 dto/openai_image_test.go create mode 100644 relay/channel/openai/image_edit_test.go create mode 100644 relay/channel/openai/image_stream_test.go create mode 100644 relay/helper/openai_image_request_test.go diff --git a/dto/openai_image.go b/dto/openai_image.go index fdef12b1..416697e3 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -26,11 +26,11 @@ type ImageRequest struct { OutputFormat json.RawMessage `json:"output_format,omitempty"` OutputCompression json.RawMessage `json:"output_compression,omitempty"` PartialImages json.RawMessage `json:"partial_images,omitempty"` - // Stream bool `json:"stream,omitempty"` - Images json.RawMessage `json:"images,omitempty"` - Mask json.RawMessage `json:"mask,omitempty"` - InputFidelity json.RawMessage `json:"input_fidelity,omitempty"` - Watermark *bool `json:"watermark,omitempty"` + Stream bool `json:"stream,omitempty"` + Images json.RawMessage `json:"images,omitempty"` + Mask json.RawMessage `json:"mask,omitempty"` + InputFidelity json.RawMessage `json:"input_fidelity,omitempty"` + Watermark *bool `json:"watermark,omitempty"` // zhipu 4v WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"` UserId json.RawMessage `json:"user_id,omitempty"` @@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { } func (i *ImageRequest) IsStream(c *gin.Context) bool { - return false + return i.Stream } func (i *ImageRequest) SetModelName(modelName string) { diff --git a/dto/openai_image_test.go b/dto/openai_image_test.go new file mode 100644 index 00000000..27e13637 --- /dev/null +++ b/dto/openai_image_test.go @@ -0,0 +1,16 @@ +package dto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestImageRequestStreamJSON verifies that image requests preserve stream=true. +func TestImageRequestStreamJSON(t *testing.T) { + var req ImageRequest + require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`))) + + require.True(t, req.Stream) + require.True(t, req.IsStream(nil)) +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 26fac8e7..fae2e174 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "net/textproto" + "net/url" "path/filepath" "strings" @@ -439,10 +440,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf // 使用已解析的 multipart 表单,避免重复解析 mf := c.Request.MultipartForm if mf == nil { - if _, err := c.MultipartForm(); err != nil { - return nil, errors.New("failed to parse multipart form") + form, err := common.ParseMultipartFormReusable(c) + if err != nil { + return nil, fmt.Errorf("failed to parse multipart form: %w", err) } - mf = c.Request.MultipartForm + c.Request.MultipartForm = form + c.Request.PostForm = url.Values(form.Value) + mf = form } // 写入所有非文件字段 @@ -625,7 +629,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case relayconstant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - usage, err = OpenaiHandlerWithUsage(c, info, resp) + if info.IsStream { + usage, err = OpenaiImageStreamHandler(c, info, resp) + } else { + usage, err = OpenaiHandlerWithUsage(c, info, resp) + } case relayconstant.RelayModeRerank: usage, err = common_handler.RerankHandler(c, info, resp) case relayconstant.RelayModeResponses: diff --git a/relay/channel/openai/image_edit_test.go b/relay/channel/openai/image_edit_test.go new file mode 100644 index 00000000..b37551b0 --- /dev/null +++ b/relay/channel/openai/image_edit_test.go @@ -0,0 +1,121 @@ +package openai + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" + relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay. +func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit this image")) + require.NoError(t, writer.WriteField("stream", "true")) + require.NoError(t, writer.WriteField("partial_images", "3")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + require.NoError(t, c.Request.ParseMultipartForm(32<<20)) + + info := &relaycommon.RelayInfo{ + RelayMode: relayconstant.RelayModeImagesEdits, + } + request := dto.ImageRequest{ + Model: "gpt-image-1", + Prompt: "edit this image", + Stream: true, + } + + converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) + require.NoError(t, err) + + convertedBody, ok := converted.(*bytes.Buffer) + require.True(t, ok) + + contentType := c.Request.Header.Get("Content-Type") + replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) + replayedRequest.Header.Set("Content-Type", contentType) + require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) + + require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model")) + require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt")) + require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) + require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images")) + require.Len(t, replayedRequest.MultipartForm.File["image"], 1) + + file, err := replayedRequest.MultipartForm.File["image"][0].Open() + require.NoError(t, err) + defer file.Close() + fileBytes, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, []byte("fake image"), fileBytes) +} + +// TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing verifies fallback parsing. +func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit without pre-parsed form")) + require.NoError(t, writer.WriteField("stream", "true")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + + storage, err := common.GetBodyStorage(c) + require.NoError(t, err) + c.Request.Body = io.NopCloser(storage) + c.Request.MultipartForm = nil + c.Request.PostForm = nil + + info := &relaycommon.RelayInfo{ + RelayMode: relayconstant.RelayModeImagesEdits, + } + request := dto.ImageRequest{ + Model: "gpt-image-1", + Prompt: "edit without pre-parsed form", + Stream: true, + } + + converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request) + require.NoError(t, err) + + convertedBody, ok := converted.(*bytes.Buffer) + require.True(t, ok) + replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes())) + replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + require.NoError(t, replayedRequest.ParseMultipartForm(32<<20)) + require.Equal(t, "edit without pre-parsed form", replayedRequest.PostForm.Get("prompt")) + require.Equal(t, "true", replayedRequest.PostForm.Get("stream")) + require.Len(t, replayedRequest.MultipartForm.File["image"], 1) +} diff --git a/relay/channel/openai/image_stream_test.go b/relay/channel/openai/image_stream_test.go new file mode 100644 index 00000000..b060bbc4 --- /dev/null +++ b/relay/channel/openai/image_stream_test.go @@ -0,0 +1,253 @@ +package openai + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/relay/helper" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) + + body := strings.Join([]string{ + `event: image_generation.partial_image`, + `data: {"type":"image_generation.partial_image","b64_json":"partial"}`, + ``, + `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: true, + } + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, err) + require.Equal(t, 3, usage.PromptTokens) + require.Equal(t, 4, usage.CompletionTokens) + require.Equal(t, 7, usage.TotalTokens) + require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens) + require.Equal(t, 1, usage.PromptTokensDetails.TextTokens) + require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`) + require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`) + require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`) + require.Contains(t, recorder.Body.String(), `data: [DONE]`) + require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type")) +} + +func TestOpenaiImageStreamHandlerForwardsLargeSSELine(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + payload := strings.Repeat("x", helper.DefaultMaxScannerBufferSize+1) + body := "data: " + payload + "\n\n" + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: true, + } + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, err) + require.NotNil(t, usage) + require.Contains(t, recorder.Body.String(), payload) + require.NotNil(t, info.StreamStatus) + require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason) +} + +func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}` + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: true, + } + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, err) + require.Equal(t, 3, usage.PromptTokens) + require.Equal(t, 4, usage.CompletionTokens) + require.Equal(t, 7, usage.TotalTokens) + require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens) + require.Equal(t, 1, usage.PromptTokensDetails.TextTokens) + require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type")) + require.Empty(t, recorder.Header().Get("Content-Length")) + require.Contains(t, recorder.Body.String(), `event: image_generation.completed`) + require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`) + require.Contains(t, recorder.Body.String(), `"b64_json":"final"`) + require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`) + require.Contains(t, recorder.Body.String(), `data: [DONE]`) +} + +func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}` + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: false, + } + + usage, err := OpenaiHandlerWithUsage(c, info, resp) + require.Nil(t, usage) + require.NotNil(t, err) + require.Equal(t, http.StatusOK, err.StatusCode) + oaiError := err.ToOpenAIError() + require.Equal(t, "content moderation failed", oaiError.Message) + require.Equal(t, "upstream_error", oaiError.Type) + require.Equal(t, "content_moderation_failed", oaiError.Code) + require.Empty(t, recorder.Body.String()) +} + +func TestOpenaiImageStreamHandlerReturnsJSONErrorFallback(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + body := `{"error":{"message":"image edit failed","type":"upstream_error","code":"content_moderation_failed","status":502}}` + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: true, + } + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, usage) + require.NotNil(t, err) + require.Equal(t, http.StatusOK, err.StatusCode) + oaiError := err.ToOpenAIError() + require.Equal(t, "image edit failed", oaiError.Message) + require.Empty(t, recorder.Body.String()) +} + +func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) { + oldMode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(oldMode) }) + + body := strings.Join([]string{ + `event: image_generation.partial_image`, + `data: {"type":"image_generation.partial_image","b64_json":"partial"}`, + ``, + `event: error`, + `data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`, + ``, + }, "\n") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + IsStream: true, + } + + usage, err := OpenaiImageStreamHandler(c, info, resp) + require.Nil(t, err) + require.NotNil(t, usage) + require.NotNil(t, info.StreamStatus) + require.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) + require.True(t, info.StreamStatus.HasErrors()) + require.Equal(t, 1, info.StreamStatus.TotalErrorCount()) + require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR") + require.Contains(t, recorder.Body.String(), `event: error`) + require.Contains(t, recorder.Body.String(), `stream ID 77`) +} + +func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) { + usage := &dto.Usage{ + InputTokens: 5000, + OutputTokens: 4000, + InputTokensDetails: &dto.InputTokenDetails{ + CachedCreationTokens: 200, + ImageTokens: 1000, + TextTokens: 4000, + }, + } + + normalizeOpenAIUsage(usage) + + require.Equal(t, 5000, usage.PromptTokens) + require.Equal(t, 4000, usage.CompletionTokens) + require.Equal(t, 9000, usage.TotalTokens) + require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens) + require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens) + require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index d6a354f7..8f90eeda 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,10 +1,13 @@ package openai import ( + "bufio" + "encoding/json" "fmt" "io" "net/http" "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" @@ -566,27 +569,278 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) + } + // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) - // Once we've written to the client, we should not return errors anymore - // because the upstream has already consumed resources and returned content - // We should still perform billing even if parsing fails - // format - if usageResp.InputTokens > 0 { - usageResp.PromptTokens += usageResp.InputTokens - } - if usageResp.OutputTokens > 0 { - usageResp.CompletionTokens += usageResp.OutputTokens - } - if usageResp.InputTokensDetails != nil { - usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens - usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens - } + normalizeOpenAIUsage(&usageResp.Usage) applyUsagePostProcessing(info, &usageResp.Usage, responseBody) return &usageResp.Usage, nil } +func normalizeOpenAIUsage(usage *dto.Usage) { + if usage == nil { + return + } + if usage.InputTokens != 0 { + usage.PromptTokens = usage.InputTokens + } + if usage.OutputTokens != 0 { + usage.CompletionTokens = usage.OutputTokens + } + if usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens + usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens + usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens + usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens + usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens + } + if usage.TotalTokens == 0 { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } +} + +func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + if resp == nil || resp.Body == nil { + logger.LogError(c, "invalid image stream response") + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) + } + + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return OpenaiHandlerWithUsage(c, info, resp) + } + if !strings.Contains(contentType, "text/event-stream") { + return OpenaiImageJSONAsStreamHandler(c, info, resp) + } + defer service.CloseResponseBodyGracefully(resp) + + usage := &dto.Usage{} + var lastStreamData []byte + + helper.SetEventStreamHeaders(c) + if info != nil && info.StreamStatus == nil { + info.StreamStatus = relaycommon.NewStreamStatus() + } + + reader := bufio.NewReader(resp.Body) + currentEvent := "" + var readErr error + for { + line, err := reader.ReadString('\n') + if err != nil { + readErr = err + if len(line) == 0 { + break + } + } + line = strings.TrimSuffix(line, "\n") + line = strings.TrimSuffix(line, "\r") + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "[DONE]" { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) + } + } else if data != "" { + if info != nil { + info.SetFirstResponseTime() + info.ReceivedResponseCount++ + } + lastStreamData = common.StringToByteSlice(data) + if info != nil && info.StreamStatus != nil && isOpenAIImageStreamErrorEvent(currentEvent, lastStreamData) { + info.StreamStatus.RecordError(extractOpenAIImageStreamErrorMessage(lastStreamData)) + } + var usageResp dto.SimpleResponse + if err := common.Unmarshal(lastStreamData, &usageResp); err == nil { + normalizeOpenAIUsage(&usageResp.Usage) + if service.ValidUsage(&usageResp.Usage) { + usage = &usageResp.Usage + } + } + } + } + if _, err := c.Writer.Write(append([]byte(line), '\n')); err != nil { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + } + return usage, nil + } + if line == "" { + if err := helper.FlushWriter(c); err != nil { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + } + return usage, nil + } + currentEvent = "" + } + if readErr != nil { + break + } + } + if info != nil && info.StreamStatus != nil { + if readErr != nil && readErr != io.EOF { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, readErr) + } else if info.StreamStatus.HasErrors() { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("upstream image stream returned error event")) + } else if info.StreamStatus.EndReason == relaycommon.StreamEndReasonNone { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil) + } + } + _ = helper.FlushWriter(c) + + applyUsagePostProcessing(info, usage, lastStreamData) + return usage, nil +} + +func isOpenAIImageStreamErrorEvent(eventName string, data []byte) bool { + if strings.EqualFold(strings.TrimSpace(eventName), "error") { + return true + } + if !json.Valid(data) { + return false + } + var payload struct { + Type string `json:"type"` + Error json.RawMessage `json:"error"` + } + if err := common.Unmarshal(data, &payload); err != nil { + return false + } + payloadType := strings.ToLower(strings.TrimSpace(payload.Type)) + return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0 +} + +func extractOpenAIImageStreamErrorMessage(data []byte) string { + if len(data) == 0 || !json.Valid(data) { + return "upstream image stream returned error event" + } + var payload struct { + Message string `json:"message"` + Error json.RawMessage `json:"error"` + } + if err := common.Unmarshal(data, &payload); err != nil { + return "upstream image stream returned error event" + } + if msg := strings.TrimSpace(payload.Message); msg != "" { + return msg + } + if len(payload.Error) > 0 { + var nested struct { + Message string `json:"message"` + } + if err := common.Unmarshal(payload.Error, &nested); err == nil { + if msg := strings.TrimSpace(nested.Message); msg != "" { + return msg + } + } + if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" { + return msg + } + } + return "upstream image stream returned error event" +} + +func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) + } + + var imageResp dto.ImageResponse + if err := common.Unmarshal(responseBody, &imageResp); err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + var usageResp dto.SimpleResponse + _ = common.Unmarshal(responseBody, &usageResp) + if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) + } + normalizeOpenAIUsage(&usageResp.Usage) + applyUsagePostProcessing(info, &usageResp.Usage, responseBody) + + helper.SetEventStreamHeaders(c) + c.Status(http.StatusOK) + + created := imageResp.Created + if created == 0 { + created = time.Now().Unix() + } + if info != nil { + info.SetFirstResponseTime() + } + for _, image := range imageResp.Data { + payload := map[string]any{ + "type": "image_generation.completed", + "created_at": created, + } + if image.Url != "" { + payload["url"] = image.Url + } + if image.B64Json != "" { + payload["b64_json"] = image.B64Json + } + if image.RevisedPrompt != "" { + payload["revised_prompt"] = image.RevisedPrompt + } + if service.ValidUsage(&usageResp.Usage) { + payload["usage"] = usageResp.Usage + } + if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + } + return &usageResp.Usage, nil + } + } + if err := writeOpenaiImageStreamDone(c); err != nil { + if info != nil && info.StreamStatus != nil { + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err) + } + return &usageResp.Usage, nil + } + if info != nil { + info.ReceivedResponseCount += len(imageResp.Data) + if info.StreamStatus == nil { + info.StreamStatus = relaycommon.NewStreamStatus() + } + info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil) + } + return &usageResp.Usage, nil +} + +func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error { + data, err := common.Marshal(payload) + if err != nil { + return err + } + if eventName != "" { + if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { + return err + } + } + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil { + return err + } + return helper.FlushWriter(c) +} + +func writeOpenaiImageStreamDone(c *gin.Context) error { + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + return err + } + return helper.FlushWriter(c) +} + func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) { if info == nil || usage == nil { return diff --git a/relay/helper/openai_image_request_test.go b/relay/helper/openai_image_request_test.go new file mode 100644 index 00000000..a0bb46c6 --- /dev/null +++ b/relay/helper/openai_image_request_test.go @@ -0,0 +1,73 @@ +package helper + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/QuantumNous/new-api/common" + relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// TestGetAndValidOpenAIImageRequestMultipartStream verifies reusable image edit parsing. +func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("prompt", "edit this image")) + require.NoError(t, writer.WriteField("stream", "true")) + require.NoError(t, writer.WriteField("n", "1")) + part, err := writer.CreateFormFile("image", "input.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake image")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + originalBody := body.String() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + + req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) + require.NoError(t, err) + require.True(t, req.Stream) + require.True(t, req.IsStream(c)) + + bodyAfterValidation, err := io.ReadAll(c.Request.Body) + require.NoError(t, err) + require.Equal(t, originalBody, string(bodyAfterValidation)) + + form, err := common.ParseMultipartFormReusable(c) + require.NoError(t, err) + require.Equal(t, "true", url.Values(form.Value).Get("stream")) + require.Len(t, form.File["image"], 1) +} + +// TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue verifies stream validation. +func TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-1")) + require.NoError(t, writer.WriteField("stream", "notabool")) + require.NoError(t, writer.Close()) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body) + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + + _, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid stream value") +} diff --git a/relay/helper/stream_scanner_test.go b/relay/helper/stream_scanner_test.go index d1577266..904cd6d1 100644 --- a/relay/helper/stream_scanner_test.go +++ b/relay/helper/stream_scanner_test.go @@ -631,7 +631,7 @@ func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) { assert.NotNil(t, info.StreamStatus) } -func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) { +func TestStreamScannerHandler_StreamStatus_ReplacesPreInitialized(t *testing.T) { t.Parallel() body := buildSSEBody(5) @@ -643,7 +643,7 @@ func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) - assert.Equal(t, 1, info.StreamStatus.TotalErrorCount()) + assert.Equal(t, 0, info.StreamStatus.TotalErrorCount()) } func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 2581b281..a53b6369 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "math" + "net/url" + "strconv" "strings" "github.com/QuantumNous/new-api/common" @@ -144,16 +146,25 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq switch relayMode { case relayconstant.RelayModeImagesEdits: if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { - _, err := c.MultipartForm() + form, err := common.ParseMultipartFormReusable(c) if err != nil { return nil, fmt.Errorf("failed to parse image edit form request: %w", err) } - formData := c.Request.PostForm + formData := url.Values(form.Value) + c.Request.MultipartForm = form + c.Request.PostForm = formData imageRequest.Prompt = formData.Get("prompt") imageRequest.Model = formData.Get("model") imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n")))) imageRequest.Quality = formData.Get("quality") imageRequest.Size = formData.Get("size") + if streamValue := strings.TrimSpace(formData.Get("stream")); streamValue != "" { + stream, err := strconv.ParseBool(streamValue) + if err != nil { + return nil, fmt.Errorf("invalid stream value: %w", err) + } + imageRequest.Stream = stream + } if imageValue := formData.Get("image"); imageValue != "" { imageRequest.Image, _ = common.Marshal(imageValue) }