59a93cf5c7
Route OpenAI image streaming through shared stream handling, split image/realtime/usage helpers for maintainability, and include the related image request and rate limit updates.
174 lines
7.1 KiB
Go
174 lines
7.1 KiB
Go
package openai
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func newImageTestContext(t *testing.T, body, contentType string, isStream bool) (*gin.Context, *httptest.ResponseRecorder, *http.Response, *relaycommon.RelayInfo) {
|
|
t.Helper()
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{contentType}},
|
|
}
|
|
info := &relaycommon.RelayInfo{
|
|
ChannelMeta: &relaycommon.ChannelMeta{},
|
|
IsStream: isStream,
|
|
}
|
|
return c, recorder, resp, info
|
|
}
|
|
|
|
// TestOpenaiImageStreamHandlerForwardsSSEAndUsage covers the core SSE path:
|
|
// chunks are forwarded with rebuilt event lines, usage is extracted and
|
|
// normalized (input_tokens -> prompt_tokens with details), and [DONE] is
|
|
// re-emitted to the client.
|
|
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
oldTimeout := constant.StreamingTimeout
|
|
constant.StreamingTimeout = 30
|
|
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
|
|
|
body := strings.Join([]string{
|
|
`event: image_generation.partial_image`,
|
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
|
``,
|
|
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
|
|
``,
|
|
`data: [DONE]`,
|
|
``,
|
|
}, "\n")
|
|
|
|
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.Equal(t, 3, usage.PromptTokens)
|
|
require.Equal(t, 4, usage.CompletionTokens)
|
|
require.Equal(t, 7, usage.TotalTokens)
|
|
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
|
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
|
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
|
|
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
|
|
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
|
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
|
}
|
|
|
|
// TestOpenaiImageStreamHandlerWrapsJSONResponse covers the non-SSE fallback:
|
|
// a JSON upstream response is wrapped into pseudo-SSE completed events.
|
|
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
|
|
|
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.Equal(t, 3, usage.PromptTokens)
|
|
require.Equal(t, 4, usage.CompletionTokens)
|
|
require.Equal(t, 7, usage.TotalTokens)
|
|
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
|
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
|
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
|
require.Empty(t, recorder.Header().Get("Content-Length"))
|
|
require.Contains(t, recorder.Body.String(), `event: image_generation.completed`)
|
|
require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`)
|
|
require.Contains(t, recorder.Body.String(), `"b64_json":"final"`)
|
|
require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`)
|
|
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
|
}
|
|
|
|
// TestOpenaiImageHandlersReturnJSONError covers JSON error responses for both
|
|
// entry points: the non-streaming handler and the stream handler's non-SSE
|
|
// fallback. Neither must leak the error body to the client.
|
|
func TestOpenaiImageHandlersReturnJSONError(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
|
|
|
t.Run("non-streaming handler", func(t *testing.T) {
|
|
c, recorder, resp, info := newImageTestContext(t, body, "application/json", false)
|
|
|
|
usage, err := OpenaiImageHandler(c, info, resp)
|
|
require.Nil(t, usage)
|
|
require.NotNil(t, err)
|
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
|
oaiError := err.ToOpenAIError()
|
|
require.Equal(t, "content moderation failed", oaiError.Message)
|
|
require.Equal(t, "upstream_error", oaiError.Type)
|
|
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
|
require.Empty(t, recorder.Body.String())
|
|
})
|
|
|
|
t.Run("stream handler JSON fallback", func(t *testing.T) {
|
|
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, usage)
|
|
require.NotNil(t, err)
|
|
require.Equal(t, http.StatusOK, err.StatusCode)
|
|
require.Equal(t, "content moderation failed", err.ToOpenAIError().Message)
|
|
require.Empty(t, recorder.Body.String())
|
|
})
|
|
}
|
|
|
|
// TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent verifies that an error
|
|
// event inside the SSE stream is recorded as a soft error while the payload is
|
|
// still forwarded to the client.
|
|
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
|
oldMode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(oldMode) })
|
|
|
|
oldTimeout := constant.StreamingTimeout
|
|
constant.StreamingTimeout = 30
|
|
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
|
|
|
body := strings.Join([]string{
|
|
`event: image_generation.partial_image`,
|
|
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
|
``,
|
|
`event: error`,
|
|
`data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`,
|
|
``,
|
|
}, "\n")
|
|
|
|
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
|
|
|
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
|
require.Nil(t, err)
|
|
require.NotNil(t, usage)
|
|
require.NotNil(t, info.StreamStatus)
|
|
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
|
require.True(t, info.StreamStatus.HasErrors())
|
|
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
|
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
|
// The scanner strips the upstream "event: error" line; the event name is
|
|
// rebuilt from the JSON "type" field (upstream_error). The error message
|
|
// is still forwarded in the data: payload (stream ID 77).
|
|
require.Contains(t, recorder.Body.String(), `event: upstream_error`)
|
|
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
|
}
|