fix(openai): support streaming image relay and image edit for images API (#4608)
* fix(openai): support streaming image relay * fix(openai): keep image edit multipart body reusable * test(openai): cover image stream usage details * test(openai): cover image edit fallback stream field * fix(openai): wrap image json fallback as stream * fix(relay): support OpenAI image streaming * fix(openai): record image stream upstream error events * fix(openai): harden image stream relay * fix(openai): return image JSON errors * fix(relay): reset stream status per scanner run * fix(relay): drop upstream credit passthrough * fix(openai): keep image errors minimal * fix(openai): keep image error status from response --------- Co-authored-by: CaIon <i@caion.me>
This commit is contained in:
+2
-2
@@ -26,7 +26,7 @@ type ImageRequest struct {
|
|||||||
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||||
// Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Images json.RawMessage `json:"images,omitempty"`
|
Images json.RawMessage `json:"images,omitempty"`
|
||||||
Mask json.RawMessage `json:"mask,omitempty"`
|
Mask json.RawMessage `json:"mask,omitempty"`
|
||||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||||
@@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||||
return false
|
return i.Stream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *ImageRequest) SetModelName(modelName string) {
|
func (i *ImageRequest) SetModelName(modelName string) {
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestImageRequestStreamJSON verifies that image requests preserve stream=true.
|
||||||
|
func TestImageRequestStreamJSON(t *testing.T) {
|
||||||
|
var req ImageRequest
|
||||||
|
require.NoError(t, req.UnmarshalJSON([]byte(`{"model":"gpt-image-1","prompt":"draw a cat","stream":true}`)))
|
||||||
|
|
||||||
|
require.True(t, req.Stream)
|
||||||
|
require.True(t, req.IsStream(nil))
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -439,10 +440,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
// 使用已解析的 multipart 表单,避免重复解析
|
// 使用已解析的 multipart 表单,避免重复解析
|
||||||
mf := c.Request.MultipartForm
|
mf := c.Request.MultipartForm
|
||||||
if mf == nil {
|
if mf == nil {
|
||||||
if _, err := c.MultipartForm(); err != nil {
|
form, err := common.ParseMultipartFormReusable(c)
|
||||||
return nil, errors.New("failed to parse multipart form")
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse multipart form: %w", err)
|
||||||
}
|
}
|
||||||
mf = c.Request.MultipartForm
|
c.Request.MultipartForm = form
|
||||||
|
c.Request.PostForm = url.Values(form.Value)
|
||||||
|
mf = form
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入所有非文件字段
|
// 写入所有非文件字段
|
||||||
@@ -625,7 +629,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
|
if info.IsStream {
|
||||||
|
usage, err = OpenaiImageStreamHandler(c, info, resp)
|
||||||
|
} else {
|
||||||
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
||||||
|
}
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
usage, err = common_handler.RerankHandler(c, info, resp)
|
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||||
case relayconstant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConvertImageEditRequestKeepsValidMultipartStreamFields verifies multipart replay.
|
||||||
|
func TestConvertImageEditRequestKeepsValidMultipartStreamFields(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "edit this image"))
|
||||||
|
require.NoError(t, writer.WriteField("stream", "true"))
|
||||||
|
require.NoError(t, writer.WriteField("partial_images", "3"))
|
||||||
|
part, err := writer.CreateFormFile("image", "input.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write([]byte("fake image"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||||
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
RelayMode: relayconstant.RelayModeImagesEdits,
|
||||||
|
}
|
||||||
|
request := dto.ImageRequest{
|
||||||
|
Model: "gpt-image-1",
|
||||||
|
Prompt: "edit this image",
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
convertedBody, ok := converted.(*bytes.Buffer)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
||||||
|
replayedRequest.Header.Set("Content-Type", contentType)
|
||||||
|
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
||||||
|
|
||||||
|
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
|
||||||
|
require.Equal(t, "edit this image", replayedRequest.PostForm.Get("prompt"))
|
||||||
|
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
||||||
|
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
|
||||||
|
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
||||||
|
|
||||||
|
file, err := replayedRequest.MultipartForm.File["image"][0].Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer file.Close()
|
||||||
|
fileBytes, err := io.ReadAll(file)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("fake image"), fileBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing verifies fallback parsing.
|
||||||
|
func TestConvertImageEditRequestParsesReusableMultipartWhenFormIsMissing(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "edit without pre-parsed form"))
|
||||||
|
require.NoError(t, writer.WriteField("stream", "true"))
|
||||||
|
part, err := writer.CreateFormFile("image", "input.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write([]byte("fake image"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||||
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
storage, err := common.GetBodyStorage(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
c.Request.Body = io.NopCloser(storage)
|
||||||
|
c.Request.MultipartForm = nil
|
||||||
|
c.Request.PostForm = nil
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
RelayMode: relayconstant.RelayModeImagesEdits,
|
||||||
|
}
|
||||||
|
request := dto.ImageRequest{
|
||||||
|
Model: "gpt-image-1",
|
||||||
|
Prompt: "edit without pre-parsed form",
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
convertedBody, ok := converted.(*bytes.Buffer)
|
||||||
|
require.True(t, ok)
|
||||||
|
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
||||||
|
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
||||||
|
require.Equal(t, "edit without pre-parsed form", replayedRequest.PostForm.Get("prompt"))
|
||||||
|
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
||||||
|
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
||||||
|
}
|
||||||
@@ -0,0 +1,253 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/relay/helper"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
||||||
|
oldMode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
oldTimeout := constant.StreamingTimeout
|
||||||
|
constant.StreamingTimeout = 30
|
||||||
|
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||||
|
|
||||||
|
body := strings.Join([]string{
|
||||||
|
`event: image_generation.partial_image`,
|
||||||
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||||
|
``,
|
||||||
|
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
|
||||||
|
``,
|
||||||
|
`data: [DONE]`,
|
||||||
|
``,
|
||||||
|
}, "\n")
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, usage.PromptTokens)
|
||||||
|
require.Equal(t, 4, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 7, usage.TotalTokens)
|
||||||
|
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
||||||
|
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
||||||
|
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||||
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenaiImageStreamHandlerForwardsLargeSSELine(t *testing.T) {
|
||||||
|
oldMode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
payload := strings.Repeat("x", helper.DefaultMaxScannerBufferSize+1)
|
||||||
|
body := "data: " + payload + "\n\n"
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Contains(t, recorder.Body.String(), payload)
|
||||||
|
require.NotNil(t, info.StreamStatus)
|
||||||
|
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
||||||
|
oldMode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, usage.PromptTokens)
|
||||||
|
require.Equal(t, 4, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 7, usage.TotalTokens)
|
||||||
|
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
||||||
|
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
||||||
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||||
|
require.Empty(t, recorder.Header().Get("Content-Length"))
|
||||||
|
require.Contains(t, recorder.Body.String(), `event: image_generation.completed`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `"b64_json":"final"`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) {
|
||||||
|
oldMode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := OpenaiHandlerWithUsage(c, info, resp)
|
||||||
|
require.Nil(t, usage)
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||||
|
oaiError := err.ToOpenAIError()
|
||||||
|
require.Equal(t, "content moderation failed", oaiError.Message)
|
||||||
|
require.Equal(t, "upstream_error", oaiError.Type)
|
||||||
|
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
||||||
|
require.Empty(t, recorder.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenaiImageStreamHandlerReturnsJSONErrorFallback(t *testing.T) {
|
||||||
|
oldMode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
body := `{"error":{"message":"image edit failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
|
require.Nil(t, usage)
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||||
|
oaiError := err.ToOpenAIError()
|
||||||
|
require.Equal(t, "image edit failed", oaiError.Message)
|
||||||
|
require.Empty(t, recorder.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
||||||
|
oldMode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||||
|
|
||||||
|
body := strings.Join([]string{
|
||||||
|
`event: image_generation.partial_image`,
|
||||||
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||||
|
``,
|
||||||
|
`event: error`,
|
||||||
|
`data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`,
|
||||||
|
``,
|
||||||
|
}, "\n")
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
}
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||||
|
IsStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.NotNil(t, info.StreamStatus)
|
||||||
|
require.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
||||||
|
require.True(t, info.StreamStatus.HasErrors())
|
||||||
|
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||||
|
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
||||||
|
require.Contains(t, recorder.Body.String(), `event: error`)
|
||||||
|
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) {
|
||||||
|
usage := &dto.Usage{
|
||||||
|
InputTokens: 5000,
|
||||||
|
OutputTokens: 4000,
|
||||||
|
InputTokensDetails: &dto.InputTokenDetails{
|
||||||
|
CachedCreationTokens: 200,
|
||||||
|
ImageTokens: 1000,
|
||||||
|
TextTokens: 4000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizeOpenAIUsage(usage)
|
||||||
|
|
||||||
|
require.Equal(t, 5000, usage.PromptTokens)
|
||||||
|
require.Equal(t, 4000, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 9000, usage.TotalTokens)
|
||||||
|
require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens)
|
||||||
|
require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens)
|
||||||
|
require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens)
|
||||||
|
}
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
@@ -566,27 +569,278 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
|||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||||
|
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
// Once we've written to the client, we should not return errors anymore
|
normalizeOpenAIUsage(&usageResp.Usage)
|
||||||
// because the upstream has already consumed resources and returned content
|
|
||||||
// We should still perform billing even if parsing fails
|
|
||||||
// format
|
|
||||||
if usageResp.InputTokens > 0 {
|
|
||||||
usageResp.PromptTokens += usageResp.InputTokens
|
|
||||||
}
|
|
||||||
if usageResp.OutputTokens > 0 {
|
|
||||||
usageResp.CompletionTokens += usageResp.OutputTokens
|
|
||||||
}
|
|
||||||
if usageResp.InputTokensDetails != nil {
|
|
||||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
|
||||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
|
||||||
}
|
|
||||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||||
return &usageResp.Usage, nil
|
return &usageResp.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIUsage(usage *dto.Usage) {
|
||||||
|
if usage == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if usage.InputTokens != 0 {
|
||||||
|
usage.PromptTokens = usage.InputTokens
|
||||||
|
}
|
||||||
|
if usage.OutputTokens != 0 {
|
||||||
|
usage.CompletionTokens = usage.OutputTokens
|
||||||
|
}
|
||||||
|
if usage.InputTokensDetails != nil {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens
|
||||||
|
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
|
||||||
|
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
|
||||||
|
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
|
||||||
|
}
|
||||||
|
if usage.TotalTokens == 0 {
|
||||||
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
|
if resp == nil || resp.Body == nil {
|
||||||
|
logger.LogError(c, "invalid image stream response")
|
||||||
|
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return OpenaiHandlerWithUsage(c, info, resp)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentType, "text/event-stream") {
|
||||||
|
return OpenaiImageJSONAsStreamHandler(c, info, resp)
|
||||||
|
}
|
||||||
|
defer service.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
|
usage := &dto.Usage{}
|
||||||
|
var lastStreamData []byte
|
||||||
|
|
||||||
|
helper.SetEventStreamHeaders(c)
|
||||||
|
if info != nil && info.StreamStatus == nil {
|
||||||
|
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
currentEvent := ""
|
||||||
|
var readErr error
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
readErr = err
|
||||||
|
if len(line) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
line = strings.TrimSuffix(line, "\n")
|
||||||
|
line = strings.TrimSuffix(line, "\r")
|
||||||
|
if strings.HasPrefix(line, "event:") {
|
||||||
|
currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||||
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
if data == "[DONE]" {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||||
|
}
|
||||||
|
} else if data != "" {
|
||||||
|
if info != nil {
|
||||||
|
info.SetFirstResponseTime()
|
||||||
|
info.ReceivedResponseCount++
|
||||||
|
}
|
||||||
|
lastStreamData = common.StringToByteSlice(data)
|
||||||
|
if info != nil && info.StreamStatus != nil && isOpenAIImageStreamErrorEvent(currentEvent, lastStreamData) {
|
||||||
|
info.StreamStatus.RecordError(extractOpenAIImageStreamErrorMessage(lastStreamData))
|
||||||
|
}
|
||||||
|
var usageResp dto.SimpleResponse
|
||||||
|
if err := common.Unmarshal(lastStreamData, &usageResp); err == nil {
|
||||||
|
normalizeOpenAIUsage(&usageResp.Usage)
|
||||||
|
if service.ValidUsage(&usageResp.Usage) {
|
||||||
|
usage = &usageResp.Usage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := c.Writer.Write(append([]byte(line), '\n')); err != nil {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||||
|
}
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
if line == "" {
|
||||||
|
if err := helper.FlushWriter(c); err != nil {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||||
|
}
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
currentEvent = ""
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
if readErr != nil && readErr != io.EOF {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, readErr)
|
||||||
|
} else if info.StreamStatus.HasErrors() {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonHandlerStop, fmt.Errorf("upstream image stream returned error event"))
|
||||||
|
} else if info.StreamStatus.EndReason == relaycommon.StreamEndReasonNone {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = helper.FlushWriter(c)
|
||||||
|
|
||||||
|
applyUsagePostProcessing(info, usage, lastStreamData)
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAIImageStreamErrorEvent(eventName string, data []byte) bool {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(eventName), "error") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !json.Valid(data) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Error json.RawMessage `json:"error"`
|
||||||
|
}
|
||||||
|
if err := common.Unmarshal(data, &payload); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
payloadType := strings.ToLower(strings.TrimSpace(payload.Type))
|
||||||
|
return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIImageStreamErrorMessage(data []byte) string {
|
||||||
|
if len(data) == 0 || !json.Valid(data) {
|
||||||
|
return "upstream image stream returned error event"
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error json.RawMessage `json:"error"`
|
||||||
|
}
|
||||||
|
if err := common.Unmarshal(data, &payload); err != nil {
|
||||||
|
return "upstream image stream returned error event"
|
||||||
|
}
|
||||||
|
if msg := strings.TrimSpace(payload.Message); msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
if len(payload.Error) > 0 {
|
||||||
|
var nested struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
if err := common.Unmarshal(payload.Error, &nested); err == nil {
|
||||||
|
if msg := strings.TrimSpace(nested.Message); msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "upstream image stream returned error event"
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
|
defer service.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var imageResp dto.ImageResponse
|
||||||
|
if err := common.Unmarshal(responseBody, &imageResp); err != nil {
|
||||||
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageResp dto.SimpleResponse
|
||||||
|
_ = common.Unmarshal(responseBody, &usageResp)
|
||||||
|
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||||
|
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||||
|
}
|
||||||
|
normalizeOpenAIUsage(&usageResp.Usage)
|
||||||
|
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||||
|
|
||||||
|
helper.SetEventStreamHeaders(c)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
created := imageResp.Created
|
||||||
|
if created == 0 {
|
||||||
|
created = time.Now().Unix()
|
||||||
|
}
|
||||||
|
if info != nil {
|
||||||
|
info.SetFirstResponseTime()
|
||||||
|
}
|
||||||
|
for _, image := range imageResp.Data {
|
||||||
|
payload := map[string]any{
|
||||||
|
"type": "image_generation.completed",
|
||||||
|
"created_at": created,
|
||||||
|
}
|
||||||
|
if image.Url != "" {
|
||||||
|
payload["url"] = image.Url
|
||||||
|
}
|
||||||
|
if image.B64Json != "" {
|
||||||
|
payload["b64_json"] = image.B64Json
|
||||||
|
}
|
||||||
|
if image.RevisedPrompt != "" {
|
||||||
|
payload["revised_prompt"] = image.RevisedPrompt
|
||||||
|
}
|
||||||
|
if service.ValidUsage(&usageResp.Usage) {
|
||||||
|
payload["usage"] = usageResp.Usage
|
||||||
|
}
|
||||||
|
if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||||
|
}
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writeOpenaiImageStreamDone(c); err != nil {
|
||||||
|
if info != nil && info.StreamStatus != nil {
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||||
|
}
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
if info != nil {
|
||||||
|
info.ReceivedResponseCount += len(imageResp.Data)
|
||||||
|
if info.StreamStatus == nil {
|
||||||
|
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||||
|
}
|
||||||
|
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||||
|
}
|
||||||
|
return &usageResp.Usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error {
|
||||||
|
data, err := common.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if eventName != "" {
|
||||||
|
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return helper.FlushWriter(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenaiImageStreamDone(c *gin.Context) error {
|
||||||
|
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return helper.FlushWriter(c)
|
||||||
|
}
|
||||||
|
|
||||||
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||||
if info == nil || usage == nil {
|
if info == nil || usage == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetAndValidOpenAIImageRequestMultipartStream verifies reusable image edit parsing.
|
||||||
|
func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "edit this image"))
|
||||||
|
require.NoError(t, writer.WriteField("stream", "true"))
|
||||||
|
require.NoError(t, writer.WriteField("n", "1"))
|
||||||
|
part, err := writer.CreateFormFile("image", "input.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write([]byte("fake image"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
originalBody := body.String()
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||||
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, req.Stream)
|
||||||
|
require.True(t, req.IsStream(c))
|
||||||
|
|
||||||
|
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, originalBody, string(bodyAfterValidation))
|
||||||
|
|
||||||
|
form, err := common.ParseMultipartFormReusable(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
|
||||||
|
require.Len(t, form.File["image"], 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue verifies stream validation.
|
||||||
|
func TestGetAndValidOpenAIImageRequestMultipartStreamInvalidValue(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||||
|
require.NoError(t, writer.WriteField("stream", "notabool"))
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||||
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
_, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid stream value")
|
||||||
|
}
|
||||||
@@ -631,7 +631,7 @@ func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
|
|||||||
assert.NotNil(t, info.StreamStatus)
|
assert.NotNil(t, info.StreamStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
func TestStreamScannerHandler_StreamStatus_ReplacesPreInitialized(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
body := buildSSEBody(5)
|
body := buildSSEBody(5)
|
||||||
@@ -643,7 +643,7 @@ func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
|||||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||||
|
|
||||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||||
assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
assert.Equal(t, 0, info.StreamStatus.TotalErrorCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
@@ -144,16 +146,25 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
|||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesEdits:
|
||||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
_, err := c.MultipartForm()
|
form, err := common.ParseMultipartFormReusable(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||||
}
|
}
|
||||||
formData := c.Request.PostForm
|
formData := url.Values(form.Value)
|
||||||
|
c.Request.MultipartForm = form
|
||||||
|
c.Request.PostForm = formData
|
||||||
imageRequest.Prompt = formData.Get("prompt")
|
imageRequest.Prompt = formData.Get("prompt")
|
||||||
imageRequest.Model = formData.Get("model")
|
imageRequest.Model = formData.Get("model")
|
||||||
imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n"))))
|
imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n"))))
|
||||||
imageRequest.Quality = formData.Get("quality")
|
imageRequest.Quality = formData.Get("quality")
|
||||||
imageRequest.Size = formData.Get("size")
|
imageRequest.Size = formData.Get("size")
|
||||||
|
if streamValue := strings.TrimSpace(formData.Get("stream")); streamValue != "" {
|
||||||
|
stream, err := strconv.ParseBool(streamValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid stream value: %w", err)
|
||||||
|
}
|
||||||
|
imageRequest.Stream = stream
|
||||||
|
}
|
||||||
if imageValue := formData.Get("image"); imageValue != "" {
|
if imageValue := formData.Get("image"); imageValue != "" {
|
||||||
imageRequest.Image, _ = common.Marshal(imageValue)
|
imageRequest.Image, _ = common.Marshal(imageValue)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user