fix: enable channel table server-side sorting (#4600)

This commit is contained in:
yyhhyyyyyy
2026-05-06 18:27:36 +08:00
committed by GitHub
parent f8cf9c57c4
commit dc8deb0c24
9 changed files with 136 additions and 35 deletions
+71 -19
View File
@@ -16,6 +16,7 @@ import (
"github.com/samber/lo"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type Channel struct {
@@ -67,6 +68,66 @@ type ChannelInfo struct {
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
type ChannelSortOptions struct {
SortBy string
SortOrder string
IDSort bool
}
var channelSortColumns = map[string]string{
"id": "id",
"name": "name",
"priority": "priority",
"balance": "balance",
"response_time": "response_time",
"test_time": "test_time",
}
func NewChannelSortOptions(sortBy string, sortOrder string, idSort bool) ChannelSortOptions {
normalizedSortBy := strings.ToLower(strings.TrimSpace(sortBy))
normalizedSortOrder := strings.ToLower(strings.TrimSpace(sortOrder))
if _, ok := channelSortColumns[normalizedSortBy]; !ok {
normalizedSortBy = ""
normalizedSortOrder = ""
} else if normalizedSortOrder != "asc" {
normalizedSortOrder = "desc"
}
return ChannelSortOptions{
SortBy: normalizedSortBy,
SortOrder: normalizedSortOrder,
IDSort: idSort,
}
}
func (options ChannelSortOptions) Apply(query *gorm.DB) *gorm.DB {
if columnName, ok := channelSortColumns[options.SortBy]; ok {
return query.Order(clause.OrderByColumn{
Column: clause.Column{Name: columnName},
Desc: options.SortOrder != "asc",
})
}
if options.IDSort {
return query.Order(clause.OrderByColumn{
Column: clause.Column{Name: "id"},
Desc: true,
})
}
return query.Order(clause.OrderByColumn{
Column: clause.Column{Name: "priority"},
Desc: true,
})
}
func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) ChannelSortOptions {
if len(sortOptions) == 0 {
return NewChannelSortOptions("", "", idSort)
}
options := sortOptions[0]
options.IDSort = options.IDSort || idSort
return options
}
// Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) {
return common.Marshal(&c)
@@ -260,28 +321,22 @@ func (channel *Channel) SaveWithoutKey() error {
return DB.Omit("key").Save(channel).Error
}
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
var channels []*Channel
var err error
order := "priority desc"
if idSort {
order = "id desc"
}
order := resolveChannelSortOptions(idSort, sortOptions)
if selectAll {
err = DB.Order(order).Find(&channels).Error
err = order.Apply(DB).Find(&channels).Error
} else {
err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
err = order.Apply(DB).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
return channels, err
}
func GetChannelsByTag(tag string, idSort bool, selectAll bool) ([]*Channel, error) {
func GetChannelsByTag(tag string, idSort bool, selectAll bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
var channels []*Channel
order := "priority desc"
if idSort {
order = "id desc"
}
query := DB.Where("tag = ?", tag).Order(order)
order := resolveChannelSortOptions(idSort, sortOptions)
query := order.Apply(DB.Where("tag = ?", tag))
if !selectAll {
query = query.Omit("key")
}
@@ -289,7 +344,7 @@ func GetChannelsByTag(tag string, idSort bool, selectAll bool) ([]*Channel, erro
return channels, err
}
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
func SearchChannels(keyword string, group string, model string, idSort bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
var channels []*Channel
modelsCol := "`models`"
@@ -304,10 +359,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
baseURLCol = `"base_url"`
}
order := "priority desc"
if idSort {
order = "id desc"
}
order := resolveChannelSortOptions(idSort, sortOptions)
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit("key")
@@ -331,7 +383,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
// 执行查询
err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
if err != nil {
return nil, err
}