fix: resolve model owned_by from active channels (#4416)

* fix: resolve model owned_by from active channels

* fix: respect token group when resolving model owners
This commit is contained in:
yyhhyyyyyy
2026-05-21 11:16:17 +08:00
committed by GitHub
parent 6f11d19877
commit 006e801652
5 changed files with 405 additions and 43 deletions
+57
View File
@@ -2,6 +2,7 @@ package model
import (
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
return result, nil
}
func normalizeLookupValues(values []string) []string {
seen := make(map[string]struct{}, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
normalized = append(normalized, value)
}
return normalized
}
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
result := make(map[string]int)
modelNames = normalizeLookupValues(modelNames)
if len(modelNames) == 0 {
return result, nil
}
type row struct {
Model string
ChannelType int
}
var rows []row
query := DB.Table("abilities").
Select("abilities.model as model, channels.type as channel_type").
Joins("JOIN channels ON abilities.channel_id = channels.id").
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
Order("COALESCE(abilities.priority, 0) DESC").
Order("abilities.weight DESC").
Order("abilities.channel_id ASC")
groups = normalizeLookupValues(groups)
if len(groups) > 0 {
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
}
if err := query.Scan(&rows).Error; err != nil {
return nil, err
}
for _, r := range rows {
if _, ok := result[r.Model]; ok {
continue
}
result[r.Model] = r.ChannelType
}
return result, nil
}
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
var models []*Model
db := DB.Model(&Model{})
+141
View File
@@ -0,0 +1,141 @@
package model
import (
"fmt"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/stretchr/testify/require"
)
func clearPreferredOwnerTables(t *testing.T) {
t.Helper()
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
}
func insertPreferredOwnerCandidate(
t *testing.T,
channelID int,
modelName string,
group string,
channelType int,
priority int64,
weight uint,
channelStatus int,
abilityEnabled bool,
) {
t.Helper()
require.NoError(t, DB.Create(&Channel{
Id: channelID,
Type: channelType,
Key: fmt.Sprintf("key-%d", channelID),
Status: channelStatus,
Name: fmt.Sprintf("channel-%d", channelID),
}).Error)
require.NoError(t, DB.Create(&Ability{
Group: group,
Model: modelName,
ChannelId: channelID,
Enabled: abilityEnabled,
Priority: &priority,
Weight: weight,
}).Error)
}
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
const modelName = "gpt-5.4"
tests := []struct {
name string
setup func(t *testing.T)
groups []string
expected int
found bool
}{
{
name: "openai only",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "codex only",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "priority wins",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "weight wins when priority is equal",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeCodex,
found: true,
},
{
name: "channel id stabilizes exact ties",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "group filter excludes other groups",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
},
groups: []string{"default"},
expected: constant.ChannelTypeOpenAI,
found: true,
},
{
name: "disabled candidates are ignored",
setup: func(t *testing.T) {
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
},
groups: []string{"default"},
found: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearPreferredOwnerTables(t)
tt.setup(t)
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
require.NoError(t, err)
got, ok := owners[modelName]
require.Equal(t, tt.found, ok)
if tt.found {
require.Equal(t, tt.expected, got)
}
})
}
}
+2
View File
@@ -40,6 +40,7 @@ func TestMain(m *testing.M) {
&Token{},
&Log{},
&Channel{},
&Ability{},
&TopUp{},
&SubscriptionPlan{},
&SubscriptionOrder{},
@@ -60,6 +61,7 @@ func truncateTables(t *testing.T) {
DB.Exec("DELETE FROM tokens")
DB.Exec("DELETE FROM logs")
DB.Exec("DELETE FROM channels")
DB.Exec("DELETE FROM abilities")
DB.Exec("DELETE FROM top_ups")
DB.Exec("DELETE FROM subscription_orders")
DB.Exec("DELETE FROM subscription_plans")