d2576ddcd3
* fix(openai): support streaming image relay * fix(openai): keep image edit multipart body reusable * test(openai): cover image stream usage details * test(openai): cover image edit fallback stream field * fix(openai): wrap image json fallback as stream * fix(relay): support OpenAI image streaming * fix(openai): record image stream upstream error events * fix(openai): harden image stream relay * fix(openai): return image JSON errors * fix(relay): reset stream status per scanner run * fix(relay): drop upstream credit passthrough * fix(openai): keep image errors minimal * fix(openai): keep image error status from response --------- Co-authored-by: CaIon <i@caion.me>
254 lines
9.0 KiB
Go
254 lines
9.0 KiB
Go
package openai
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
"github.com/QuantumNous/new-api/relay/helper"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
oldTimeout := constant.StreamingTimeout
|
|
constant.StreamingTimeout = 30
|
|
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
|
|
|
body := strings.Join([]string{
|
|
`event: image_generation.partial_image`,
|
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
|
``,
|
|
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
|
|
``,
|
|
`data: [DONE]`,
|
|
``,
|
|
}, "\n")
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: true,
|
|
}
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.Equal(t, 3, usage.PromptTokens)
|
|
require.Equal(t, 4, usage.CompletionTokens)
|
|
require.Equal(t, 7, usage.TotalTokens)
|
|
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
|
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
|
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
|
|
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
|
|
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
|
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
|
}
|
|
|
|
func TestOpenaiImageStreamHandlerForwardsLargeSSELine(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
payload := strings.Repeat("x", helper.DefaultMaxScannerBufferSize+1)
|
|
body := "data: " + payload + "\n\n"
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: true,
|
|
}
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.NotNil(t, usage)
|
|
require.Contains(t, recorder.Body.String(), payload)
|
|
require.NotNil(t, info.StreamStatus)
|
|
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
|
}
|
|
|
|
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: true,
|
|
}
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.Equal(t, 3, usage.PromptTokens)
|
|
require.Equal(t, 4, usage.CompletionTokens)
|
|
require.Equal(t, 7, usage.TotalTokens)
|
|
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
|
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
|
require.Empty(t, recorder.Header().Get("Content-Length"))
|
|
require.Contains(t, recorder.Body.String(), `event: image_generation.completed`)
|
|
require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`)
|
|
require.Contains(t, recorder.Body.String(), `"b64_json":"final"`)
|
|
require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`)
|
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
|
}
|
|
|
|
func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: false,
|
|
}
|
|
|
|
usage, err := OpenaiHandlerWithUsage(c, info, resp)
|
|
require.Nil(t, usage)
|
|
require.NotNil(t, err)
|
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
|
oaiError := err.ToOpenAIError()
|
|
require.Equal(t, "content moderation failed", oaiError.Message)
|
|
require.Equal(t, "upstream_error", oaiError.Type)
|
|
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
|
require.Empty(t, recorder.Body.String())
|
|
}
|
|
|
|
func TestOpenaiImageStreamHandlerReturnsJSONErrorFallback(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
body := `{"error":{"message":"image edit failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: true,
|
|
}
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, usage)
|
|
require.NotNil(t, err)
|
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
|
oaiError := err.ToOpenAIError()
|
|
require.Equal(t, "image edit failed", oaiError.Message)
|
|
require.Empty(t, recorder.Body.String())
|
|
}
|
|
|
|
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
body := strings.Join([]string{
|
|
`event: image_generation.partial_image`,
|
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
|
``,
|
|
`event: error`,
|
|
`data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`,
|
|
``,
|
|
}, "\n")
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: true,
|
|
}
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.NotNil(t, usage)
|
|
require.NotNil(t, info.StreamStatus)
|
|
require.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
|
require.True(t, info.StreamStatus.HasErrors())
|
|
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
|
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
|
require.Contains(t, recorder.Body.String(), `event: error`)
|
|
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
|
}
|
|
|
|
func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) {
|
|
usage := &dto.Usage{
|
|
InputTokens: 5000,
|
|
OutputTokens: 4000,
|
|
InputTokensDetails: &dto.InputTokenDetails{
|
|
CachedCreationTokens: 200,
|
|
ImageTokens: 1000,
|
|
TextTokens: 4000,
|
|
},
|
|
}
|
|
|
|
normalizeOpenAIUsage(usage)
|
|
|
|
require.Equal(t, 5000, usage.PromptTokens)
|
|
require.Equal(t, 4000, usage.CompletionTokens)
|
|
require.Equal(t, 9000, usage.TotalTokens)
|
|
require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens)
|
|
require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens)
|
|
require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens)
|
|
}
|