fix: resolve model owned_by from active channels (#4416)
* fix: resolve model owned_by from active channels * fix: respect token group when resolving model owners
This commit is contained in:
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
|
||||
@@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeLookupValues(values []string) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
normalized = append(normalized, value)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
|
||||
result := make(map[string]int)
|
||||
modelNames = normalizeLookupValues(modelNames)
|
||||
if len(modelNames) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type row struct {
|
||||
Model string
|
||||
ChannelType int
|
||||
}
|
||||
var rows []row
|
||||
|
||||
query := DB.Table("abilities").
|
||||
Select("abilities.model as model, channels.type as channel_type").
|
||||
Joins("JOIN channels ON abilities.channel_id = channels.id").
|
||||
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
|
||||
Order("COALESCE(abilities.priority, 0) DESC").
|
||||
Order("abilities.weight DESC").
|
||||
Order("abilities.channel_id ASC")
|
||||
|
||||
groups = normalizeLookupValues(groups)
|
||||
if len(groups) > 0 {
|
||||
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
|
||||
}
|
||||
|
||||
if err := query.Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
if _, ok := result[r.Model]; ok {
|
||||
continue
|
||||
}
|
||||
result[r.Model] = r.ChannelType
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||
var models []*Model
|
||||
db := DB.Model(&Model{})
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func clearPreferredOwnerTables(t *testing.T) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
|
||||
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
|
||||
}
|
||||
|
||||
func insertPreferredOwnerCandidate(
|
||||
t *testing.T,
|
||||
channelID int,
|
||||
modelName string,
|
||||
group string,
|
||||
channelType int,
|
||||
priority int64,
|
||||
weight uint,
|
||||
channelStatus int,
|
||||
abilityEnabled bool,
|
||||
) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Create(&Channel{
|
||||
Id: channelID,
|
||||
Type: channelType,
|
||||
Key: fmt.Sprintf("key-%d", channelID),
|
||||
Status: channelStatus,
|
||||
Name: fmt.Sprintf("channel-%d", channelID),
|
||||
}).Error)
|
||||
require.NoError(t, DB.Create(&Ability{
|
||||
Group: group,
|
||||
Model: modelName,
|
||||
ChannelId: channelID,
|
||||
Enabled: abilityEnabled,
|
||||
Priority: &priority,
|
||||
Weight: weight,
|
||||
}).Error)
|
||||
}
|
||||
|
||||
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
|
||||
const modelName = "gpt-5.4"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T)
|
||||
groups []string
|
||||
expected int
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "openai only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "codex only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "priority wins",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "weight wins when priority is equal",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "channel id stabilizes exact ties",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "group filter excludes other groups",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "disabled candidates are ignored",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearPreferredOwnerTables(t)
|
||||
tt.setup(t)
|
||||
|
||||
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, ok := owners[modelName]
|
||||
require.Equal(t, tt.found, ok)
|
||||
if tt.found {
|
||||
require.Equal(t, tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,7 @@ func TestMain(m *testing.M) {
|
||||
&Token{},
|
||||
&Log{},
|
||||
&Channel{},
|
||||
&Ability{},
|
||||
&TopUp{},
|
||||
&SubscriptionPlan{},
|
||||
&SubscriptionOrder{},
|
||||
@@ -60,6 +61,7 @@ func truncateTables(t *testing.T) {
|
||||
DB.Exec("DELETE FROM tokens")
|
||||
DB.Exec("DELETE FROM logs")
|
||||
DB.Exec("DELETE FROM channels")
|
||||
DB.Exec("DELETE FROM abilities")
|
||||
DB.Exec("DELETE FROM top_ups")
|
||||
DB.Exec("DELETE FROM subscription_orders")
|
||||
DB.Exec("DELETE FROM subscription_plans")
|
||||
|
||||
Reference in New Issue
Block a user