package openai import ( "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/helper" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 30 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) body := strings.Join([]string{ `event: image_generation.partial_image`, `data: {"type":"image_generation.partial_image","b64_json":"partial"}`, ``, `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`, ``, `data: [DONE]`, ``, }, "\n") recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: http.Header{"Content-Type": []string{"text/event-stream"}}, } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, IsStream: true, } usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) require.Equal(t, 3, usage.PromptTokens) require.Equal(t, 4, usage.CompletionTokens) require.Equal(t, 7, usage.TotalTokens) require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens) require.Equal(t, 1, usage.PromptTokensDetails.TextTokens) require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`) require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`) require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`) require.Contains(t, recorder.Body.String(), `data: [DONE]`) require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type")) } func TestOpenaiImageStreamHandlerForwardsLargeSSELine(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) payload := strings.Repeat("x", helper.DefaultMaxScannerBufferSize+1) body := "data: " + payload + "\n\n" recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: http.Header{"Content-Type": []string{"text/event-stream"}}, } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, IsStream: true, } usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) require.NotNil(t, usage) require.Contains(t, recorder.Body.String(), payload) require.NotNil(t, info.StreamStatus) require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason) } func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}` recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: http.Header{"Content-Type": []string{"application/json"}}, } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, IsStream: true, } usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) require.Equal(t, 3, usage.PromptTokens) require.Equal(t, 4, usage.CompletionTokens) require.Equal(t, 7, usage.TotalTokens) require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens) require.Equal(t, 1, usage.PromptTokensDetails.TextTokens) require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type")) require.Empty(t, recorder.Header().Get("Content-Length")) require.Contains(t, recorder.Body.String(), `event: image_generation.completed`) require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`) require.Contains(t, recorder.Body.String(), `"b64_json":"final"`) require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`) require.Contains(t, recorder.Body.String(), `data: [DONE]`) } func TestOpenaiHandlerWithUsageReturnsImageJSONError(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}` recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: http.Header{"Content-Type": []string{"application/json"}}, } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, IsStream: false, } usage, err := OpenaiHandlerWithUsage(c, info, resp) require.Nil(t, usage) require.NotNil(t, err) require.Equal(t, http.StatusOK, err.StatusCode) oaiError := err.ToOpenAIError() require.Equal(t, "content moderation failed", oaiError.Message) require.Equal(t, "upstream_error", oaiError.Type) require.Equal(t, "content_moderation_failed", oaiError.Code) require.Empty(t, recorder.Body.String()) } func TestOpenaiImageStreamHandlerReturnsJSONErrorFallback(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) body := `{"error":{"message":"image edit failed","type":"upstream_error","code":"content_moderation_failed","status":502}}` recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: http.Header{"Content-Type": []string{"application/json"}}, } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, IsStream: true, } usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, usage) require.NotNil(t, err) require.Equal(t, http.StatusOK, err.StatusCode) oaiError := err.ToOpenAIError() require.Equal(t, "image edit failed", oaiError.Message) require.Empty(t, recorder.Body.String()) } func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) { oldMode := gin.Mode() gin.SetMode(gin.TestMode) t.Cleanup(func() { gin.SetMode(oldMode) }) body := strings.Join([]string{ `event: image_generation.partial_image`, `data: {"type":"image_generation.partial_image","b64_json":"partial"}`, ``, `event: error`, `data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`, ``, }, "\n") recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: http.Header{"Content-Type": []string{"text/event-stream"}}, } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, IsStream: true, } usage, err := OpenaiImageStreamHandler(c, info, resp) require.Nil(t, err) require.NotNil(t, usage) require.NotNil(t, info.StreamStatus) require.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) require.True(t, info.StreamStatus.HasErrors()) require.Equal(t, 1, info.StreamStatus.TotalErrorCount()) require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR") require.Contains(t, recorder.Body.String(), `event: error`) require.Contains(t, recorder.Body.String(), `stream ID 77`) } func TestNormalizeOpenAIUsageMapsImageTokenDetailsWithoutDoubleCounting(t *testing.T) { usage := &dto.Usage{ InputTokens: 5000, OutputTokens: 4000, InputTokensDetails: &dto.InputTokenDetails{ CachedCreationTokens: 200, ImageTokens: 1000, TextTokens: 4000, }, } normalizeOpenAIUsage(usage) require.Equal(t, 5000, usage.PromptTokens) require.Equal(t, 4000, usage.CompletionTokens) require.Equal(t, 9000, usage.TotalTokens) require.Equal(t, 200, usage.PromptTokensDetails.CachedCreationTokens) require.Equal(t, 1000, usage.PromptTokensDetails.ImageTokens) require.Equal(t, 4000, usage.PromptTokensDetails.TextTokens) }