fix: use actual user id for channel tests (#5109)
This commit is contained in:
@@ -57,7 +57,24 @@ func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointTyp
|
|||||||
return normalized
|
return normalized
|
||||||
}
|
}
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
|
func resolveChannelTestUserID(c *gin.Context) (int, error) {
|
||||||
|
if c != nil {
|
||||||
|
if userID := c.GetInt("id"); userID > 0 {
|
||||||
|
return userID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var rootUser model.User
|
||||||
|
if err := model.DB.Select("id").Where("role = ?", common.RoleRootUser).First(&rootUser).Error; err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to resolve channel test user: %w", err)
|
||||||
|
}
|
||||||
|
if rootUser.Id == 0 {
|
||||||
|
return 0, errors.New("failed to resolve channel test user")
|
||||||
|
}
|
||||||
|
return rootUser.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testChannel(channel *model.Channel, testUserID int, testModel string, endpointType string, isStream bool) testResult {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
var unsupportedTestChannelTypes = []int{
|
var unsupportedTestChannelTypes = []int{
|
||||||
constant.ChannelTypeMidjourney,
|
constant.ChannelTypeMidjourney,
|
||||||
@@ -143,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
Header: make(http.Header),
|
Header: make(http.Header),
|
||||||
}
|
}
|
||||||
|
|
||||||
cache, err := model.GetUserCache(1)
|
cache, err := model.GetUserCache(testUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
localErr: err,
|
localErr: err,
|
||||||
@@ -151,13 +168,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache.WriteContext(c)
|
cache.WriteContext(c)
|
||||||
c.Set("id", 1)
|
c.Set("id", testUserID)
|
||||||
|
|
||||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
group, _ := model.GetUserGroup(1, false)
|
group, _ := model.GetUserGroup(testUserID, false)
|
||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
|
|
||||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||||
@@ -484,7 +501,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
|
|||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
other := buildTestLogOther(c, info, priceData, usage, tieredResult)
|
||||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
model.RecordConsumeLog(c, testUserID, model.RecordConsumeLogParams{
|
||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
PromptTokens: usage.PromptTokens,
|
PromptTokens: usage.PromptTokens,
|
||||||
CompletionTokens: usage.CompletionTokens,
|
CompletionTokens: usage.CompletionTokens,
|
||||||
@@ -834,8 +851,13 @@ func TestChannel(c *gin.Context) {
|
|||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
endpointType := c.Query("endpoint_type")
|
endpointType := c.Query("endpoint_type")
|
||||||
isStream, _ := strconv.ParseBool(c.Query("stream"))
|
isStream, _ := strconv.ParseBool(c.Query("stream"))
|
||||||
|
testUserID, err := resolveChannelTestUserID(c)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
result := testChannel(channel, testModel, endpointType, isStream)
|
result := testChannel(channel, testUserID, testModel, endpointType, isStream)
|
||||||
if result.localErr != nil {
|
if result.localErr != nil {
|
||||||
resp := gin.H{
|
resp := gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -872,6 +894,10 @@ var testAllChannelsLock sync.Mutex
|
|||||||
var testAllChannelsRunning bool = false
|
var testAllChannelsRunning bool = false
|
||||||
|
|
||||||
func testAllChannels(notify bool) error {
|
func testAllChannels(notify bool) error {
|
||||||
|
testUserID, err := resolveChannelTestUserID(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
testAllChannelsLock.Lock()
|
testAllChannelsLock.Lock()
|
||||||
if testAllChannelsRunning {
|
if testAllChannelsRunning {
|
||||||
@@ -902,7 +928,7 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
result := testChannel(channel, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
result := testChannel(channel, testUserID, "", "", shouldUseStreamForAutomaticChannelTest(channel))
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
|
||||||
|
|||||||
@@ -69,3 +69,14 @@ func TestBuildTestLogOtherInjectsTieredInfo(t *testing.T) {
|
|||||||
require.Equal(t, "base", other["matched_tier"])
|
require.Equal(t, "base", other["matched_tier"])
|
||||||
require.NotEmpty(t, other["expr_b64"])
|
require.NotEmpty(t, other["expr_b64"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResolveChannelTestUserIDUsesRequestUser(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
ctx.Set("id", 2)
|
||||||
|
|
||||||
|
userID, err := resolveChannelTestUserID(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 2, userID)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user