feat: implement tiered billing expression evaluation and related functionality
- Added support for tiered billing expressions in the billing system. - Introduced new types and functions for handling billing expressions, including caching and execution. - Updated existing billing logic to accommodate tiered billing scenarios. - Enhanced request handling to support incoming billing expression requests. - Added tests for tiered billing functionality to ensure correctness.
This commit is contained in:
@@ -46,7 +46,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, ioReader)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -124,8 +125,10 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
var requestBody io.Reader = bytes.NewBuffer(jsonData)
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData))
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -18,4 +18,7 @@ type BillingSettler interface {
|
||||
|
||||
// GetPreConsumedQuota 返回实际预扣的额度值(信任用户可能为 0)。
|
||||
GetPreConsumedQuota() int
|
||||
|
||||
// Reserve 将预扣额度补到目标值;若目标值不高于当前预扣额度则不做任何事。
|
||||
Reserve(targetQuota int) error
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
@@ -152,6 +153,11 @@ type RelayInfo struct {
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
// TieredBillingSnapshot is a frozen snapshot of tiered billing rules
|
||||
// captured at pre-consume time. Non-nil only when billing mode is "tiered_expr".
|
||||
TieredBillingSnapshot *billingexpr.BillingSnapshot
|
||||
BillingRequestInput *billingexpr.RequestInput
|
||||
|
||||
Request dto.Request
|
||||
|
||||
// RequestConversionChain records request format conversions in order, e.g.
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
@@ -236,6 +237,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
// Tiered billing early return
|
||||
if ok, tieredQuota, tieredResult := service.TryTieredSettle(relayInfo, billingexpr.TokenParams{
|
||||
P: float64(usage.PromptTokens),
|
||||
C: float64(usage.CompletionTokens),
|
||||
CR: float64(usage.PromptTokensDetails.CachedTokens),
|
||||
CC: float64(usage.PromptTokensDetails.CachedCreationTokens - usage.ClaudeCacheCreation1hTokens),
|
||||
CC1h: float64(usage.ClaudeCacheCreation1hTokens),
|
||||
}); ok {
|
||||
postConsumeQuotaTiered(ctx, relayInfo, usage, tieredQuota, tieredResult, extraContent...)
|
||||
return
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
@@ -506,3 +519,51 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
func postConsumeQuotaTiered(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, quota int, tieredResult *service.TieredResultWrapper, extraContent ...string) {
|
||||
_ = tieredResult // will be used for log enrichment
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
modelName := relayInfo.OriginModelName
|
||||
tokenName := ctx.GetString("token_name")
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
|
||||
totalTokens := usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
if totalTokens == 0 {
|
||||
quota = 0
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("tiered billing: total tokens is 0, userId %d, channelId %d, tokenId %d, model %s, pre-consumed %d",
|
||||
relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
if groupRatio != 0 && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
if err := service.SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling tiered billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := modelName
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
|
||||
other := service.GenerateTieredOtherInfo(ctx, relayInfo, tieredResult)
|
||||
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
ModelName: logModel,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package relay
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -58,7 +59,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
var requestBody io.Reader = bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) {
|
||||
input := billingexpr.RequestInput{}
|
||||
if info != nil {
|
||||
input.Headers = cloneStringMap(info.RequestHeaders)
|
||||
}
|
||||
|
||||
bodyBytes, err := readIncomingBillingExprBody(c)
|
||||
if err != nil {
|
||||
return billingexpr.RequestInput{}, err
|
||||
}
|
||||
input.Body = bodyBytes
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
|
||||
if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) {
|
||||
return nil, nil
|
||||
}
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.Bytes()
|
||||
}
|
||||
|
||||
func isJSONContentType(contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
return strings.HasPrefix(contentType, "application/json")
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return map[string]string{}
|
||||
}
|
||||
dst := make(map[string]string, len(src))
|
||||
for key, value := range src {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
dst[key] = value
|
||||
}
|
||||
return dst
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
body := []byte(`{"service_tier":"fast"}`)
|
||||
ctx.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
ctx.Set(common.KeyRequestBody, body)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
RequestHeaders: map[string]string{"Content-Type": "application/json"},
|
||||
}
|
||||
|
||||
input, err := ResolveIncomingBillingExprRequestInput(ctx, info)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, body, input.Body)
|
||||
require.Equal(t, "application/json", input.Headers["Content-Type"])
|
||||
}
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/billing_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
@@ -50,6 +52,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
// Check if this model uses tiered_expr billing
|
||||
if billing_setting.GetBillingMode(info.OriginModelName) == billing_setting.BillingModeTieredExpr {
|
||||
return modelPriceHelperTiered(c, info, promptTokens, meta, groupRatioInfo)
|
||||
}
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -195,5 +202,73 @@ func ContainPriceOrRatio(modelName string) bool {
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
|
||||
_, ok = billing_setting.GetBillingExpr(modelName)
|
||||
return ok
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {
|
||||
exprStr, ok := billing_setting.GetBillingExpr(info.OriginModelName)
|
||||
if !ok {
|
||||
return types.PriceData{}, fmt.Errorf("model %s is configured as tiered_expr but has no billing expression", info.OriginModelName)
|
||||
}
|
||||
|
||||
estimatedCompletionTokens := 0
|
||||
if meta.MaxTokens != 0 {
|
||||
estimatedCompletionTokens = meta.MaxTokens
|
||||
}
|
||||
|
||||
requestInput, err := ResolveIncomingBillingExprRequestInput(c, info)
|
||||
if err != nil {
|
||||
return types.PriceData{}, err
|
||||
}
|
||||
|
||||
rawQuota, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{
|
||||
P: float64(promptTokens),
|
||||
C: float64(estimatedCompletionTokens),
|
||||
}, requestInput)
|
||||
if err != nil {
|
||||
return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err)
|
||||
}
|
||||
|
||||
preConsumedQuota := billingexpr.QuotaRound(rawQuota * groupRatioInfo.GroupRatio)
|
||||
|
||||
freeModel := false
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || rawQuota == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
}
|
||||
}
|
||||
|
||||
exprHash := billingexpr.ExprHashString(exprStr)
|
||||
snapshot := &billingexpr.BillingSnapshot{
|
||||
BillingMode: billing_setting.BillingModeTieredExpr,
|
||||
ModelName: info.OriginModelName,
|
||||
ExprString: exprStr,
|
||||
ExprHash: exprHash,
|
||||
GroupRatio: groupRatioInfo.GroupRatio,
|
||||
EstimatedPromptTokens: promptTokens,
|
||||
EstimatedCompletionTokens: estimatedCompletionTokens,
|
||||
EstimatedQuotaBeforeGroup: rawQuota,
|
||||
EstimatedQuotaAfterGroup: preConsumedQuota,
|
||||
EstimatedTier: trace.MatchedTier,
|
||||
}
|
||||
info.TieredBillingSnapshot = snapshot
|
||||
info.BillingRequestInput = &requestInput
|
||||
|
||||
priceData := types.PriceData{
|
||||
FreeModel: freeModel,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
QuotaToPreConsume: preConsumedQuota,
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("model_price_helper_tiered result: model=%s preConsume=%d rawQuota=%.2f groupRatio=%.2f tier=%s", info.OriginModelName, preConsumedQuota, rawQuota, groupRatioInfo.GroupRatio, trace.MatchedTier))
|
||||
}
|
||||
|
||||
info.PriceData = priceData
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user