From 30025aeba3e6ec05d72ac32e9ea54cd4aecb88b4 Mon Sep 17 00:00:00 2001 From: Seefs <40468931+seefs001@users.noreply.github.com> Date: Tue, 26 May 2026 12:32:20 +0800 Subject: [PATCH] fix: use actual user id for channel tests (#5109) --- controller/channel-test.go | 40 +++++++++++++++++++----- controller/channel_test_internal_test.go | 11 +++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index b225585e..037b8496 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp return normalized } -func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult { +func resolveChannelTestUserID(c *gin.Context) (int, error) { + if c != nil { + if userID := c.GetInt("id"); userID > 0 { + return userID, nil + } + } + + var rootUser model.User + if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil { + return 0, fmt.Errorf("failed to resolve channel test user: %w", err) + } + if rootUser.Id == 0 { + return 0, errors.New("failed to resolve channel test user") + } + return rootUser.Id, nil +} + +func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult { tik := time.Now() var unsupportedTestChannelTypes = []int{ constant.ChannelTypeMidjourney, @@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, Header: make(http.Header), } - cache, err := model.GetUserCache(1) + cache, err := model.GetUserCache(testUserID) if err != nil { return testResult{ localErr: err, @@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, } } cache.WriteContext(c) - c.Set("id", 1) + c.Set("id", testUserID) //c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) - group, _ := model.GetUserGroup(1, false) + group, _ := model.GetUserGroup(testUserID, false) c.Set("group", group) newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel) @@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 other := buildTestLogOther(c, info, priceData, usage, tieredResult) - model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{ + model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{ ChannelId: channel.Id, PromptTokens: usage.PromptTokens, CompletionTokens: usage.CompletionTokens, @@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) { testModel := c.Query("model") endpointType := c.Query("endpoint_type") isStream, _ := strconv.ParseBool(c.Query("stream")) + testUserID, err := resolveChannelTestUserID(c) + if err != nil { + common.ApiError(c, err) + return + } tik := time.Now() - result := testChannel(channel, testModel, endpointType, isStream) + result := testChannel(channel, testUserID, testModel, endpointType, isStream) if result.localErr != nil { resp := gin.H{ "success": false, @@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func testAllChannels(notify bool) error { + testUserID, err := resolveChannelTestUserID(nil) + if err != nil { + return err + } testAllChannelsLock.Lock() if testAllChannelsRunning { @@ -902,7 +928,7 @@ func testAllChannels(notify bool) error { } isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel)) + result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel)) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() diff --git a/controller/channel_test_internal_test.go b/controller/channel_test_internal_test.go index 9c26d623..02540801 100644 --- a/controller/channel_test_internal_test.go +++ b/controller/channel_test_internal_test.go @@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) { require.Equal(t, "base", other["matched_tier"]) require.NotEmpty(t, other["expr_b64"]) } + +func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ctx.Set("id", 2) + + userID, err := resolveChannelTestUserID(ctx) + + require.NoError(t, err) + require.Equal(t, 2, userID) +}