fix: use actual user id for channel tests (#5109)

This commit is contained in:
Seefs
2026-05-26 12:32:20 +08:00
committed by GitHub
parent c91ba0c4eb
commit 30025aeba3
2 changed files with 44 additions and 7 deletions
+33 -7
View File
@@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
return normalized return normalized
} }
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult { func resolveChannelTestUserID(c *gin.Context) (int, error) {
if c != nil {
if userID := c.GetInt("id"); userID > 0 {
return userID, nil
}
}
var rootUser model.User
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
}
if rootUser.Id == 0 {
return 0, errors.New("failed to resolve channel test user")
}
return rootUser.Id, nil
}
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now() tik := time.Now()
var unsupportedTestChannelTypes = []int{ var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourney,
@@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
Header: make(http.Header), Header: make(http.Header),
} }
cache, err := model.GetUserCache(1) cache, err := model.GetUserCache(testUserID)
if err != nil { if err != nil {
return testResult{ return testResult{
localErr: err, localErr: err,
@@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
} }
} }
cache.WriteContext(c) cache.WriteContext(c)
c.Set("id", 1) c.Set("id", testUserID)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key) //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false) group, _ := model.GetUserGroup(testUserID, false)
c.Set("group", group) c.Set("group", group)
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel) newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
@@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
other := buildTestLogOther(c, info, priceData, usage, tieredResult) other := buildTestLogOther(c, info, priceData, usage, tieredResult)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{ model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
ChannelId: channel.Id, ChannelId: channel.Id,
PromptTokens: usage.PromptTokens, PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens, CompletionTokens: usage.CompletionTokens,
@@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) {
testModel := c.Query("model") testModel := c.Query("model")
endpointType := c.Query("endpoint_type") endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream")) isStream, _ := strconv.ParseBool(c.Query("stream"))
testUserID, err := resolveChannelTestUserID(c)
if err != nil {
common.ApiError(c, err)
return
}
tik := time.Now() tik := time.Now()
result := testChannel(channel, testModel, endpointType, isStream) result := testChannel(channel, testUserID, testModel, endpointType, isStream)
if result.localErr != nil { if result.localErr != nil {
resp := gin.H{ resp := gin.H{
"success": false, "success": false,
@@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error { func testAllChannels(notify bool) error {
testUserID, err := resolveChannelTestUserID(nil)
if err != nil {
return err
}
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
if testAllChannelsRunning { if testAllChannelsRunning {
@@ -902,7 +928,7 @@ func testAllChannels(notify bool) error {
} }
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel)) result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
+11
View File
@@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
require.Equal(t, "base", other["matched_tier"]) require.Equal(t, "base", other["matched_tier"])
require.NotEmpty(t, other["expr_b64"]) require.NotEmpty(t, other["expr_b64"])
} }
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Set("id", 2)
userID, err := resolveChannelTestUserID(ctx)
require.NoError(t, err)
require.Equal(t, 2, userID)
}