182 lines
5.1 KiB
Go
182 lines
5.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/common/limiter"
|
|
"one-api/setting"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/go-redis/redis/v8"
|
|
)
|
|
|
|
// Key prefixes used to namespace per-user rate-limit state.
const (
	// ModelRequestRateLimitCountMark prefixes the in-memory key that tracks
	// all requests of a user (see memoryRateLimitHandler).
	ModelRequestRateLimitCountMark = "MRRL"
	// ModelRequestRateLimitSuccessCountMark prefixes the key that tracks a
	// user's successful requests, in both the Redis and in-memory limiters.
	ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
|
|
|
|
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
|
|
if maxCount == 0 {
|
|
return true, nil
|
|
}
|
|
|
|
length, err := rdb.LLen(ctx, key).Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if length < int64(maxCount) {
|
|
return true, nil
|
|
}
|
|
|
|
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
|
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
nowTimeStr := time.Now().Format(timeFormat)
|
|
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
subTime := nowTime.Sub(oldTime).Seconds()
|
|
if int64(subTime) < duration {
|
|
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
|
|
if maxCount == 0 {
|
|
return
|
|
}
|
|
|
|
now := time.Now().Format(timeFormat)
|
|
rdb.LPush(ctx, key, now)
|
|
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
|
|
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
|
}
|
|
|
|
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
userId := strconv.Itoa(c.GetInt("id"))
|
|
ctx := context.Background()
|
|
rdb := common.RDB
|
|
|
|
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
|
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
|
if err != nil {
|
|
fmt.Println("检查成功请求数限制失败:", err.Error())
|
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
|
return
|
|
}
|
|
if !allowed {
|
|
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
|
|
return
|
|
}
|
|
|
|
totalKey := fmt.Sprintf("rateLimit:%s", userId)
|
|
tb := limiter.New(ctx, rdb)
|
|
allowed, err = tb.Allow(
|
|
ctx,
|
|
totalKey,
|
|
limiter.WithCapacity(int64(totalMaxCount)*duration),
|
|
limiter.WithRate(int64(totalMaxCount)),
|
|
limiter.WithRequested(duration),
|
|
)
|
|
|
|
if err != nil {
|
|
fmt.Println("检查总请求数限制失败:", err.Error())
|
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
|
return
|
|
}
|
|
|
|
if !allowed {
|
|
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
|
}
|
|
|
|
c.Next()
|
|
|
|
if c.Writer.Status() < 400 {
|
|
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
|
|
}
|
|
}
|
|
}
|
|
|
|
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
|
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
|
|
|
|
return func(c *gin.Context) {
|
|
userId := strconv.Itoa(c.GetInt("id"))
|
|
totalKey := ModelRequestRateLimitCountMark + userId
|
|
successKey := ModelRequestRateLimitSuccessCountMark + userId
|
|
|
|
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
checkKey := successKey + "_check"
|
|
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
|
|
c.Status(http.StatusTooManyRequests)
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
|
|
if c.Writer.Status() < 400 {
|
|
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ModelRequestRateLimit() func(c *gin.Context) {
|
|
return func(c *gin.Context) {
|
|
if !setting.ModelRequestRateLimitEnabled {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
|
|
|
group := c.GetString("token_group")
|
|
if group == "" {
|
|
group = c.GetString("group")
|
|
}
|
|
if group == "" {
|
|
group = "default"
|
|
}
|
|
|
|
finalTotalCount := setting.ModelRequestRateLimitCount
|
|
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
|
|
foundGroupLimit := false
|
|
|
|
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
|
|
if found {
|
|
finalTotalCount = groupTotalCount
|
|
finalSuccessCount = groupSuccessCount
|
|
foundGroupLimit = true
|
|
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
|
}
|
|
|
|
if !foundGroupLimit {
|
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
|
}
|
|
|
|
if common.RedisEnabled {
|
|
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
|
} else {
|
|
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
|
}
|
|
}
|
|
}
|