feat: implement tiered billing expression evaluation and related functionality
- Added support for tiered billing expressions in the billing system. - Introduced new types and functions for handling billing expressions, including caching and execution. - Updated existing billing logic to accommodate tiered billing scenarios. - Enhanced request handling to support incoming billing expression requests. - Added tests for tiered billing functionality to ensure correctness.
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ResolveIncomingBillingExprRequestInput(c *gin.Context, info *relaycommon.RelayInfo) (billingexpr.RequestInput, error) {
|
||||
input := billingexpr.RequestInput{}
|
||||
if info != nil {
|
||||
input.Headers = cloneStringMap(info.RequestHeaders)
|
||||
}
|
||||
|
||||
bodyBytes, err := readIncomingBillingExprBody(c)
|
||||
if err != nil {
|
||||
return billingexpr.RequestInput{}, err
|
||||
}
|
||||
input.Body = bodyBytes
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func readIncomingBillingExprBody(c *gin.Context) ([]byte, error) {
|
||||
if c == nil || c.Request == nil || !isJSONContentType(c.Request.Header.Get("Content-Type")) {
|
||||
return nil, nil
|
||||
}
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.Bytes()
|
||||
}
|
||||
|
||||
func isJSONContentType(contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
return strings.HasPrefix(contentType, "application/json")
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return map[string]string{}
|
||||
}
|
||||
dst := make(map[string]string, len(src))
|
||||
for key, value := range src {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
dst[key] = value
|
||||
}
|
||||
return dst
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
body := []byte(`{"service_tier":"fast"}`)
|
||||
ctx.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
ctx.Set(common.KeyRequestBody, body)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
RequestHeaders: map[string]string{"Content-Type": "application/json"},
|
||||
}
|
||||
|
||||
input, err := ResolveIncomingBillingExprRequestInput(ctx, info)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, body, input.Body)
|
||||
require.Equal(t, "application/json", input.Headers["Content-Type"])
|
||||
}
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/pkg/billingexpr"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/billing_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
@@ -50,6 +52,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
// Check if this model uses tiered_expr billing
|
||||
if billing_setting.GetBillingMode(info.OriginModelName) == billing_setting.BillingModeTieredExpr {
|
||||
return modelPriceHelperTiered(c, info, promptTokens, meta, groupRatioInfo)
|
||||
}
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -195,5 +202,73 @@ func ContainPriceOrRatio(modelName string) bool {
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
|
||||
_, ok = billing_setting.GetBillingExpr(modelName)
|
||||
return ok
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {
|
||||
exprStr, ok := billing_setting.GetBillingExpr(info.OriginModelName)
|
||||
if !ok {
|
||||
return types.PriceData{}, fmt.Errorf("model %s is configured as tiered_expr but has no billing expression", info.OriginModelName)
|
||||
}
|
||||
|
||||
estimatedCompletionTokens := 0
|
||||
if meta.MaxTokens != 0 {
|
||||
estimatedCompletionTokens = meta.MaxTokens
|
||||
}
|
||||
|
||||
requestInput, err := ResolveIncomingBillingExprRequestInput(c, info)
|
||||
if err != nil {
|
||||
return types.PriceData{}, err
|
||||
}
|
||||
|
||||
rawQuota, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{
|
||||
P: float64(promptTokens),
|
||||
C: float64(estimatedCompletionTokens),
|
||||
}, requestInput)
|
||||
if err != nil {
|
||||
return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err)
|
||||
}
|
||||
|
||||
preConsumedQuota := billingexpr.QuotaRound(rawQuota * groupRatioInfo.GroupRatio)
|
||||
|
||||
freeModel := false
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || rawQuota == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
}
|
||||
}
|
||||
|
||||
exprHash := billingexpr.ExprHashString(exprStr)
|
||||
snapshot := &billingexpr.BillingSnapshot{
|
||||
BillingMode: billing_setting.BillingModeTieredExpr,
|
||||
ModelName: info.OriginModelName,
|
||||
ExprString: exprStr,
|
||||
ExprHash: exprHash,
|
||||
GroupRatio: groupRatioInfo.GroupRatio,
|
||||
EstimatedPromptTokens: promptTokens,
|
||||
EstimatedCompletionTokens: estimatedCompletionTokens,
|
||||
EstimatedQuotaBeforeGroup: rawQuota,
|
||||
EstimatedQuotaAfterGroup: preConsumedQuota,
|
||||
EstimatedTier: trace.MatchedTier,
|
||||
}
|
||||
info.TieredBillingSnapshot = snapshot
|
||||
info.BillingRequestInput = &requestInput
|
||||
|
||||
priceData := types.PriceData{
|
||||
FreeModel: freeModel,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
QuotaToPreConsume: preConsumedQuota,
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("model_price_helper_tiered result: model=%s preConsume=%d rawQuota=%.2f groupRatio=%.2f tier=%s", info.OriginModelName, preConsumedQuota, rawQuota, groupRatioInfo.GroupRatio, trace.MatchedTier))
|
||||
}
|
||||
|
||||
info.PriceData = priceData
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user