refactor: centralize logging and update resource initialization
This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
This commit is contained in:
@@ -5,7 +5,7 @@ import (
|
|||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter {
|
|||||||
// 预加载脚本
|
// 预加载脚本
|
||||||
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
|
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
|
common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
|
||||||
}
|
}
|
||||||
instance = &RedisLimiter{
|
instance = &RedisLimiter{
|
||||||
client: r,
|
client: r,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -486,8 +485,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
func AutomaticallyUpdateChannels(frequency int) {
|
func AutomaticallyUpdateChannels(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
logger.SysLog("updating all channels")
|
common.SysLog("updating all channels")
|
||||||
_ = updateAllChannelsBalance()
|
_ = updateAllChannelsBalance()
|
||||||
logger.SysLog("channels update done")
|
common.SysLog("channels update done")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
@@ -133,8 +132,17 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
newAPIError: newAPIError,
|
newAPIError: newAPIError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
request := buildTestRequest(testModel)
|
||||||
|
|
||||||
info := relaycommon.GenRelayInfo(c)
|
info, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, request, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info, nil)
|
err = helper.ModelMappedHelper(c, info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -144,7 +152,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
request.Model = testModel
|
||||||
|
|
||||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
@@ -156,13 +166,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
|
||||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||||
logInfo := *info
|
logInfo := *info
|
||||||
logInfo.ApiKey = ""
|
logInfo.ApiKey = ""
|
||||||
logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
|
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
@@ -280,7 +289,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
Group: info.UsingGroup,
|
Group: info.UsingGroup,
|
||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: nil,
|
localErr: nil,
|
||||||
@@ -462,13 +471,13 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
|
|
||||||
func AutomaticallyTestChannels(frequency int) {
|
func AutomaticallyTestChannels(frequency int) {
|
||||||
if frequency <= 0 {
|
if frequency <= 0 {
|
||||||
logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
logger.SysLog("testing all channels")
|
common.SysLog("testing all channels")
|
||||||
_ = testAllChannels(false)
|
_ = testAllChannels(false)
|
||||||
logger.SysLog("channel test finished")
|
common.SysLog("channel test finished")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||||
@@ -98,6 +99,6 @@ func MigrateConsoleSetting(c *gin.Context) {
|
|||||||
|
|
||||||
// 重新加载 OptionMap
|
// 重新加载 OptionMap
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
logger.SysLog("console setting migrated")
|
common.SysLog("console setting migrated")
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -48,7 +47,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(err.Error())
|
common.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@@ -64,7 +63,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||||
res2, err := client.Do(req)
|
res2, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(err.Error())
|
common.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
|
|||||||
+3
-1
@@ -93,7 +93,9 @@ func init() {
|
|||||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
ChannelType: i,
|
||||||
|
}}
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
|
|||||||
+5
-6
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
@@ -59,7 +58,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(err.Error())
|
common.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@@ -70,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oidcResponse.AccessToken == "" {
|
if oidcResponse.AccessToken == "" {
|
||||||
logger.SysError("OIDC 获取 Token 失败,请检查设置!")
|
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,12 +80,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||||
res2, err := client.Do(req)
|
res2, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(err.Error())
|
common.SysLog(err.Error())
|
||||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
if res2.StatusCode != http.StatusOK {
|
if res2.StatusCode != http.StatusOK {
|
||||||
logger.SysError("OIDC 获取用户信息失败!请检查设置!")
|
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||||
logger.SysError("OIDC 获取用户信息为空!请检查设置!")
|
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||||
}
|
}
|
||||||
return &oidcUser, nil
|
return &oidcUser, nil
|
||||||
|
|||||||
@@ -56,5 +56,5 @@ func Playground(c *gin.Context) {
|
|||||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
Relay(c)
|
Relay(c, types.RelayFormatOpenAI)
|
||||||
}
|
}
|
||||||
|
|||||||
+26
-36
@@ -104,26 +104,6 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//includeUsage := true
|
|
||||||
//// 判断用户是否需要返回使用情况
|
|
||||||
//if textRequest.StreamOptions != nil {
|
|
||||||
// includeUsage = textRequest.StreamOptions.IncludeUsage
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 如果不支持StreamOptions,将StreamOptions设置为nil
|
|
||||||
//if !relayInfo.SupportStreamOptions || !textRequest.Stream {
|
|
||||||
// textRequest.StreamOptions = nil
|
|
||||||
//} else {
|
|
||||||
// // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
|
||||||
// if constant.ForceStreamOption {
|
|
||||||
// textRequest.StreamOptions = &dto.StreamOptions{
|
|
||||||
// IncludeUsage: true,
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//relayInfo.ShouldIncludeUsage = includeUsage
|
|
||||||
|
|
||||||
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
||||||
@@ -178,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
|
|
||||||
switch relayFormat {
|
switch relayFormat {
|
||||||
case types.RelayFormatOpenAIRealtime:
|
case types.RelayFormatOpenAIRealtime:
|
||||||
newAPIError = relay.WssHelper(c, ws)
|
newAPIError = relay.WssHelper(c, relayInfo)
|
||||||
case types.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
||||||
case types.RelayFormatGemini:
|
case types.RelayFormatGemini:
|
||||||
@@ -324,35 +304,45 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
relayMode := c.GetInt("relay_mode")
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
||||||
var err *dto.MidjourneyResponse
|
|
||||||
switch relayMode {
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
||||||
|
"type": "upstream_error",
|
||||||
|
"code": 4,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var mjErr *dto.MidjourneyResponse
|
||||||
|
switch relayInfo.RelayMode {
|
||||||
case relayconstant.RelayModeMidjourneyNotify:
|
case relayconstant.RelayModeMidjourneyNotify:
|
||||||
err = relay.RelayMidjourneyNotify(c)
|
mjErr = relay.RelayMidjourneyNotify(c)
|
||||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
||||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
||||||
case relayconstant.RelayModeSwapFace:
|
case relayconstant.RelayModeSwapFace:
|
||||||
err = relay.RelaySwapFace(c)
|
mjErr = relay.RelaySwapFace(c, relayInfo)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
||||||
}
|
}
|
||||||
//err = relayMidjourneySubmit(c, relayMode)
|
//err = relayMidjourneySubmit(c, relayMode)
|
||||||
log.Println(err)
|
log.Println(mjErr)
|
||||||
if err != nil {
|
if mjErr != nil {
|
||||||
statusCode := http.StatusBadRequest
|
statusCode := http.StatusBadRequest
|
||||||
if err.Code == 30 {
|
if mjErr.Code == 30 {
|
||||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
statusCode = http.StatusTooManyRequests
|
statusCode = http.StatusTooManyRequests
|
||||||
}
|
}
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
||||||
"type": "upstream_error",
|
"type": "upstream_error",
|
||||||
"code": err.Code,
|
"code": mjErr.Code,
|
||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+9
-9
@@ -26,7 +26,7 @@ func UpdateTaskBulk() {
|
|||||||
//imageModel := "midjourney"
|
//imageModel := "midjourney"
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(15) * time.Second)
|
time.Sleep(time.Duration(15) * time.Second)
|
||||||
logger.SysLog("任务进度轮询开始")
|
common.SysLog("任务进度轮询开始")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||||
@@ -66,7 +66,7 @@ func UpdateTaskBulk() {
|
|||||||
|
|
||||||
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
||||||
}
|
}
|
||||||
logger.SysLog("任务进度轮询完成")
|
common.SysLog("任务进度轮询完成")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
default:
|
default:
|
||||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||||
logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -100,14 +100,14 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
}
|
}
|
||||||
channel, err := model.CacheGetChannel(channelId)
|
channel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||||
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
||||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -119,7 +119,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"ids": taskIds,
|
"ids": taskIds,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
@@ -129,7 +129,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||||
@@ -139,7 +139,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !responseItems.IsSuccess() {
|
if !responseItems.IsSuccess() {
|
||||||
logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
|
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
"one-api/logger"
|
||||||
@@ -37,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if errUpdate != nil {
|
if errUpdate != nil {
|
||||||
logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
}
|
}
|
||||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -112,7 +113,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
task.StartTime = now
|
task.StartTime = now
|
||||||
}
|
}
|
||||||
case model.TaskStatusSuccess:
|
case model.TaskStatusSuccess:
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
if task.FinishTime == 0 {
|
if task.FinishTime == 0 {
|
||||||
task.FinishTime = now
|
task.FinishTime = now
|
||||||
}
|
}
|
||||||
@@ -140,7 +141,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
task.Progress = taskResult.Progress
|
task.Progress = taskResult.Progress
|
||||||
}
|
}
|
||||||
if err := task.Update(); err != nil {
|
if err := task.Update(); err != nil {
|
||||||
logger.SysError("UpdateVideoTask task error: " + err.Error())
|
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
+1
-2
@@ -3,7 +3,6 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -103,7 +102,7 @@ func AddToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成令牌失败",
|
"message": "生成令牌失败",
|
||||||
})
|
})
|
||||||
logger.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
|
|||||||
+6
-7
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -71,7 +70,7 @@ func Setup2FA(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成2FA密钥失败",
|
"message": "生成2FA密钥失败",
|
||||||
})
|
})
|
||||||
logger.SysError("生成TOTP密钥失败: " + err.Error())
|
common.SysLog("生成TOTP密钥失败: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,7 +81,7 @@ func Setup2FA(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成备用码失败",
|
"message": "生成备用码失败",
|
||||||
})
|
})
|
||||||
logger.SysError("生成备用码失败: " + err.Error())
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +115,7 @@ func Setup2FA(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "保存备用码失败",
|
"message": "保存备用码失败",
|
||||||
})
|
})
|
||||||
logger.SysError("保存备用码失败: " + err.Error())
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,7 +294,7 @@ func Get2FAStatus(c *gin.Context) {
|
|||||||
// 获取剩余备用码数量
|
// 获取剩余备用码数量
|
||||||
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("获取备用码数量失败: " + err.Error())
|
common.SysLog("获取备用码数量失败: " + err.Error())
|
||||||
} else {
|
} else {
|
||||||
status["backup_codes_remaining"] = backupCount
|
status["backup_codes_remaining"] = backupCount
|
||||||
}
|
}
|
||||||
@@ -369,7 +368,7 @@ func RegenerateBackupCodes(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成备用码失败",
|
"message": "生成备用码失败",
|
||||||
})
|
})
|
||||||
logger.SysError("生成备用码失败: " + err.Error())
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,7 +378,7 @@ func RegenerateBackupCodes(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "保存备用码失败",
|
"message": "保存备用码失败",
|
||||||
})
|
})
|
||||||
logger.SysError("保存备用码失败: " + err.Error())
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+3
-3
@@ -193,7 +193,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "数据库错误,请稍后重试",
|
"message": "数据库错误,请稍后重试",
|
||||||
})
|
})
|
||||||
logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if exist {
|
if exist {
|
||||||
@@ -236,7 +236,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成默认令牌失败",
|
"message": "生成默认令牌失败",
|
||||||
})
|
})
|
||||||
logger.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 生成默认令牌
|
// 生成默认令牌
|
||||||
@@ -343,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成失败",
|
"message": "生成失败",
|
||||||
})
|
})
|
||||||
logger.SysError("failed to generate key: " + err.Error())
|
common.SysLog("failed to generate key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.SetAccessToken(key)
|
user.SetAccessToken(key)
|
||||||
|
|||||||
@@ -332,9 +332,9 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type MessageImageUrl struct {
|
type MessageImageUrl struct {
|
||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
Detail string `json:"detail"`
|
Detail string `json:"detail"`
|
||||||
//MimeType string
|
MimeType string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MessageImageUrl) IsRemoteImage() bool {
|
func (m *MessageImageUrl) IsRemoteImage() bool {
|
||||||
|
|||||||
@@ -9,3 +9,16 @@ type Request interface {
|
|||||||
GetTokenCountMeta() *types.TokenCountMeta
|
GetTokenCountMeta() *types.TokenCountMeta
|
||||||
IsStream(c *gin.Context) bool
|
IsStream(c *gin.Context) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BaseRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
TokenType: types.TokenTypeTokenizer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -36,22 +36,22 @@ func main() {
|
|||||||
|
|
||||||
err := InitResources()
|
err := InitResources()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to initialize resources: " + err.Error())
|
common.FatalLog("failed to initialize resources: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.SysLog("New API " + common.Version + " started")
|
common.SysLog("New API " + common.Version + " started")
|
||||||
if os.Getenv("GIN_MODE") != "debug" {
|
if os.Getenv("GIN_MODE") != "debug" {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
logger.SysLog("running in debug mode")
|
common.SysLog("running in debug mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := model.CloseDB()
|
err := model.CloseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to close database: " + err.Error())
|
common.FatalLog("failed to close database: " + err.Error())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -60,18 +60,18 @@ func main() {
|
|||||||
common.MemoryCacheEnabled = true
|
common.MemoryCacheEnabled = true
|
||||||
}
|
}
|
||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
logger.SysLog("memory cache enabled")
|
common.SysLog("memory cache enabled")
|
||||||
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||||
|
|
||||||
// Add panic recovery and retry for InitChannelCache
|
// Add panic recovery and retry for InitChannelCache
|
||||||
func() {
|
func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||||
// Retry once
|
// Retry once
|
||||||
_, _, fixErr := model.FixAbility()
|
_, _, fixErr := model.FixAbility()
|
||||||
if fixErr != nil {
|
if fixErr != nil {
|
||||||
logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -90,14 +90,14 @@ func main() {
|
|||||||
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
|
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
|
||||||
}
|
}
|
||||||
go controller.AutomaticallyUpdateChannels(frequency)
|
go controller.AutomaticallyUpdateChannels(frequency)
|
||||||
}
|
}
|
||||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
||||||
}
|
}
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
@@ -111,7 +111,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
common.BatchUpdateEnabled = true
|
common.BatchUpdateEnabled = true
|
||||||
logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
model.InitBatchUpdater()
|
model.InitBatchUpdater()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,13 +120,13 @@ func main() {
|
|||||||
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
|
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
|
||||||
})
|
})
|
||||||
go common.Monitor()
|
go common.Monitor()
|
||||||
logger.SysLog("pprof enabled")
|
common.SysLog("pprof enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.New()
|
server := gin.New()
|
||||||
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
||||||
logger.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||||
@@ -156,7 +156,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
err = server.Run(":" + port)
|
err = server.Run(":" + port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to start HTTP server: " + err.Error())
|
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,8 +165,8 @@ func InitResources() error {
|
|||||||
// This is a placeholder function for future resource initialization
|
// This is a placeholder function for future resource initialization
|
||||||
err := godotenv.Load(".env")
|
err := godotenv.Load(".env")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
||||||
logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载环境变量
|
// 加载环境变量
|
||||||
@@ -184,7 +184,7 @@ func InitResources() error {
|
|||||||
// Initialize SQL Database
|
// Initialize SQL Database
|
||||||
err = model.InitDB()
|
err = model.InitDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to initialize database: " + err.Error())
|
common.FatalLog("failed to initialize database: " + err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
logger.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||||
logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type turnstileCheckResponse struct {
|
type turnstileCheckResponse struct {
|
||||||
@@ -38,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
"remoteip": {c.ClientIP()},
|
"remoteip": {c.ClientIP()},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
@@ -50,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
var res turnstileCheckResponse
|
var res turnstileCheckResponse
|
||||||
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
|||||||
+4
-5
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -295,13 +294,13 @@ func FixAbility() (int, int, error) {
|
|||||||
if common.UsingSQLite {
|
if common.UsingSQLite {
|
||||||
err := DB.Exec("DELETE FROM abilities").Error
|
err := DB.Exec("DELETE FROM abilities").Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -321,7 +320,7 @@ func FixAbility() (int, int, error) {
|
|||||||
// Delete all abilities of this channel
|
// Delete all abilities of this channel
|
||||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||||
failCount += len(chunk)
|
failCount += len(chunk)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -329,7 +328,7 @@ func FixAbility() (int, int, error) {
|
|||||||
for _, channel := range chunk {
|
for _, channel := range chunk {
|
||||||
err = channel.AddAbilities(nil)
|
err = channel.AddAbilities(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||||
failCount++
|
failCount++
|
||||||
} else {
|
} else {
|
||||||
successCount++
|
successCount++
|
||||||
|
|||||||
+13
-14
@@ -9,7 +9,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -210,7 +209,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
|||||||
if channel.OtherInfo != "" {
|
if channel.OtherInfo != "" {
|
||||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to unmarshal other info: " + err.Error())
|
common.SysLog("failed to unmarshal other info: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return otherInfo
|
return otherInfo
|
||||||
@@ -219,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
|||||||
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
||||||
otherInfoBytes, err := json.Marshal(otherInfo)
|
otherInfoBytes, err := json.Marshal(otherInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to marshal other info: " + err.Error())
|
common.SysLog("failed to marshal other info: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.OtherInfo = string(otherInfoBytes)
|
channel.OtherInfo = string(otherInfoBytes)
|
||||||
@@ -489,7 +488,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|||||||
ResponseTime: int(responseTime),
|
ResponseTime: int(responseTime),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update response time: " + err.Error())
|
common.SysLog("failed to update response time: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -499,7 +498,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
|
|||||||
Balance: balance,
|
Balance: balance,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update balance: " + err.Error())
|
common.SysLog("failed to update balance: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -615,7 +614,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
if shouldUpdateAbilities {
|
if shouldUpdateAbilities {
|
||||||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update ability status: " + err.Error())
|
common.SysLog("failed to update ability status: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -643,7 +642,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
}
|
}
|
||||||
err = channel.Save()
|
err = channel.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update channel status: " + err.Error())
|
common.SysLog("failed to update channel status: " + err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -705,7 +704,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
|
|||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
err = channel.UpdateAbilities(nil)
|
err = channel.UpdateAbilities(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update abilities: " + err.Error())
|
common.SysLog("failed to update abilities: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -729,7 +728,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
|
|||||||
func updateChannelUsedQuota(id int, quota int) {
|
func updateChannelUsedQuota(id int, quota int) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update channel used quota: " + err.Error())
|
common.SysLog("failed to update channel used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -822,7 +821,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
|||||||
if channel.Setting != nil && *channel.Setting != "" {
|
if channel.Setting != nil && *channel.Setting != "" {
|
||||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysLog("failed to unmarshal setting: " + err.Error())
|
||||||
channel.Setting = nil // 清空设置以避免后续错误
|
channel.Setting = nil // 清空设置以避免后续错误
|
||||||
_ = channel.Save() // 保存修改
|
_ = channel.Save() // 保存修改
|
||||||
}
|
}
|
||||||
@@ -833,7 +832,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
|||||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||||
settingBytes, err := common.Marshal(setting)
|
settingBytes, err := common.Marshal(setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to marshal setting: " + err.Error())
|
common.SysLog("failed to marshal setting: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||||||
@@ -844,7 +843,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
|||||||
if channel.OtherSettings != "" {
|
if channel.OtherSettings != "" {
|
||||||
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
|
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysLog("failed to unmarshal setting: " + err.Error())
|
||||||
channel.OtherSettings = "{}" // 清空设置以避免后续错误
|
channel.OtherSettings = "{}" // 清空设置以避免后续错误
|
||||||
_ = channel.Save() // 保存修改
|
_ = channel.Save() // 保存修改
|
||||||
}
|
}
|
||||||
@@ -855,7 +854,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
|||||||
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
|
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
|
||||||
settingBytes, err := common.Marshal(setting)
|
settingBytes, err := common.Marshal(setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to marshal setting: " + err.Error())
|
common.SysLog("failed to marshal setting: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.OtherSettings = string(settingBytes)
|
channel.OtherSettings = string(settingBytes)
|
||||||
@@ -866,7 +865,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
|
|||||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to unmarshal param override: " + err.Error())
|
common.SysLog("failed to unmarshal param override: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return paramOverride
|
return paramOverride
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -85,13 +84,13 @@ func InitChannelCache() {
|
|||||||
}
|
}
|
||||||
channelsIDM = newChannelId2channel
|
channelsIDM = newChannelId2channel
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
logger.SysLog("channels synced from database")
|
common.SysLog("channels synced from database")
|
||||||
}
|
}
|
||||||
|
|
||||||
func SyncChannelCache(frequency int) {
|
func SyncChannelCache(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
logger.SysLog("syncing channels from database")
|
common.SysLog("syncing channels from database")
|
||||||
InitChannelCache()
|
InitChannelCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -88,7 +88,7 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to record log: " + err.Error())
|
common.SysLog("failed to record log: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+15
-16
@@ -5,7 +5,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/logger"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -85,7 +84,7 @@ func createRootAccountIfNeed() error {
|
|||||||
var user User
|
var user User
|
||||||
//if user.Status != common.UserStatusEnabled {
|
//if user.Status != common.UserStatusEnabled {
|
||||||
if err := DB.First(&user).Error; err != nil {
|
if err := DB.First(&user).Error; err != nil {
|
||||||
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||||
hashedPassword, err := common.Password2Hash("123456")
|
hashedPassword, err := common.Password2Hash("123456")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -109,7 +108,7 @@ func CheckSetup() {
|
|||||||
if setup == nil {
|
if setup == nil {
|
||||||
// No setup record exists, check if we have a root user
|
// No setup record exists, check if we have a root user
|
||||||
if RootUserExists() {
|
if RootUserExists() {
|
||||||
logger.SysLog("system is not initialized, but root user exists")
|
common.SysLog("system is not initialized, but root user exists")
|
||||||
// Create setup record
|
// Create setup record
|
||||||
newSetup := Setup{
|
newSetup := Setup{
|
||||||
Version: common.Version,
|
Version: common.Version,
|
||||||
@@ -117,16 +116,16 @@ func CheckSetup() {
|
|||||||
}
|
}
|
||||||
err := DB.Create(&newSetup).Error
|
err := DB.Create(&newSetup).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog("failed to create setup record: " + err.Error())
|
common.SysLog("failed to create setup record: " + err.Error())
|
||||||
}
|
}
|
||||||
constant.Setup = true
|
constant.Setup = true
|
||||||
} else {
|
} else {
|
||||||
logger.SysLog("system is not initialized and no root user exists")
|
common.SysLog("system is not initialized and no root user exists")
|
||||||
constant.Setup = false
|
constant.Setup = false
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Setup record exists, system is initialized
|
// Setup record exists, system is initialized
|
||||||
logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
||||||
constant.Setup = true
|
constant.Setup = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -139,7 +138,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|||||||
if dsn != "" {
|
if dsn != "" {
|
||||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
logger.SysLog("using PostgreSQL as database")
|
common.SysLog("using PostgreSQL as database")
|
||||||
if !isLog {
|
if !isLog {
|
||||||
common.UsingPostgreSQL = true
|
common.UsingPostgreSQL = true
|
||||||
} else {
|
} else {
|
||||||
@@ -153,7 +152,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(dsn, "local") {
|
if strings.HasPrefix(dsn, "local") {
|
||||||
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||||
if !isLog {
|
if !isLog {
|
||||||
common.UsingSQLite = true
|
common.UsingSQLite = true
|
||||||
} else {
|
} else {
|
||||||
@@ -164,7 +163,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Use MySQL
|
// Use MySQL
|
||||||
logger.SysLog("using MySQL as database")
|
common.SysLog("using MySQL as database")
|
||||||
// check parseTime
|
// check parseTime
|
||||||
if !strings.Contains(dsn, "parseTime") {
|
if !strings.Contains(dsn, "parseTime") {
|
||||||
if strings.Contains(dsn, "?") {
|
if strings.Contains(dsn, "?") {
|
||||||
@@ -183,7 +182,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Use SQLite
|
// Use SQLite
|
||||||
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||||
common.UsingSQLite = true
|
common.UsingSQLite = true
|
||||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
@@ -217,11 +216,11 @@ func InitDB() (err error) {
|
|||||||
if common.UsingMySQL {
|
if common.UsingMySQL {
|
||||||
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||||
}
|
}
|
||||||
logger.SysLog("database migration started")
|
common.SysLog("database migration started")
|
||||||
err = migrateDB()
|
err = migrateDB()
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
logger.FatalLog(err)
|
common.FatalLog(err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -254,11 +253,11 @@ func InitLogDB() (err error) {
|
|||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
logger.SysLog("database migration started")
|
common.SysLog("database migration started")
|
||||||
err = migrateLOGDB()
|
err = migrateLOGDB()
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
logger.FatalLog(err)
|
common.FatalLog(err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -355,7 +354,7 @@ func migrateDBFast() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.SysLog("database migrated")
|
common.SysLog("database migrated")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,6 +503,6 @@ func PingDB() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
lastPingTime = time.Now()
|
lastPingTime = time.Now()
|
||||||
logger.SysLog("Database pinged successfully")
|
common.SysLog("Database pinged successfully")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-3
@@ -2,7 +2,6 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/config"
|
"one-api/setting/config"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
@@ -151,7 +150,7 @@ func loadOptionsFromDatabase() {
|
|||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
err := updateOptionMap(option.Key, option.Value)
|
err := updateOptionMap(option.Key, option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update option map: " + err.Error())
|
common.SysLog("failed to update option map: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -159,7 +158,7 @@ func loadOptionsFromDatabase() {
|
|||||||
func SyncOptions(frequency int) {
|
func SyncOptions(frequency int) {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
logger.SysLog("syncing options from database")
|
common.SysLog("syncing options from database")
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-2
@@ -3,7 +3,6 @@ package model
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/logger"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -93,7 +92,7 @@ func updatePricing() {
|
|||||||
//modelRatios := common.GetModelRatios()
|
//modelRatios := common.GetModelRatios()
|
||||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 预加载模型元数据与供应商一次,避免循环查询
|
// 预加载模型元数据与供应商一次,避免循环查询
|
||||||
|
|||||||
+9
-10
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
@@ -92,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
token.Status = common.TokenStatusExpired
|
token.Status = common.TokenStatusExpired
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update token status" + err.Error())
|
common.SysLog("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return token, errors.New("该令牌已过期")
|
return token, errors.New("该令牌已过期")
|
||||||
@@ -103,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
token.Status = common.TokenStatusExhausted
|
token.Status = common.TokenStatusExhausted
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update token status" + err.Error())
|
common.SysLog("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
keyPrefix := key[:3]
|
keyPrefix := key[:3]
|
||||||
@@ -135,7 +134,7 @@ func GetTokenById(id int) (*Token, error) {
|
|||||||
if shouldUpdateRedis(true, err) {
|
if shouldUpdateRedis(true, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := cacheSetToken(token); err != nil {
|
if err := cacheSetToken(token); err != nil {
|
||||||
logger.SysError("failed to update user status cache: " + err.Error())
|
common.SysLog("failed to update user status cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -148,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
|||||||
if shouldUpdateRedis(fromDB, err) && token != nil {
|
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := cacheSetToken(*token); err != nil {
|
if err := cacheSetToken(*token); err != nil {
|
||||||
logger.SysError("failed to update user status cache: " + err.Error())
|
common.SysLog("failed to update user status cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -179,7 +178,7 @@ func (token *Token) Update() (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheSetToken(*token)
|
err := cacheSetToken(*token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update token cache: " + err.Error())
|
common.SysLog("failed to update token cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -195,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheSetToken(*token)
|
err := cacheSetToken(*token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update token cache: " + err.Error())
|
common.SysLog("failed to update token cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -210,7 +209,7 @@ func (token *Token) Delete() (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheDeleteToken(token.Key)
|
err := cacheDeleteToken(token.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to delete token cache: " + err.Error())
|
common.SysLog("failed to delete token cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -270,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheIncrTokenQuota(key, int64(quota))
|
err := cacheIncrTokenQuota(key, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to increase token quota: " + err.Error())
|
common.SysLog("failed to increase token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -300,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheDecrTokenQuota(key, int64(quota))
|
err := cacheDecrTokenQuota(key, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to decrease token quota: " + err.Error())
|
common.SysLog("failed to decrease token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+4
-5
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -244,7 +243,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
|||||||
if !common.ValidateTOTPCode(t.Secret, code) {
|
if !common.ValidateTOTPCode(t.Secret, code) {
|
||||||
// 增加失败次数
|
// 增加失败次数
|
||||||
if err := t.IncrementFailedAttempts(); err != nil {
|
if err := t.IncrementFailedAttempts(); err != nil {
|
||||||
logger.SysError("更新2FA失败次数失败: " + err.Error())
|
common.SysLog("更新2FA失败次数失败: " + err.Error())
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@@ -256,7 +255,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
|||||||
t.LastUsedAt = &now
|
t.LastUsedAt = &now
|
||||||
|
|
||||||
if err := t.Update(); err != nil {
|
if err := t.Update(); err != nil {
|
||||||
logger.SysError("更新2FA使用记录失败: " + err.Error())
|
common.SysLog("更新2FA使用记录失败: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -278,7 +277,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
|||||||
if !valid {
|
if !valid {
|
||||||
// 增加失败次数
|
// 增加失败次数
|
||||||
if err := t.IncrementFailedAttempts(); err != nil {
|
if err := t.IncrementFailedAttempts(); err != nil {
|
||||||
logger.SysError("更新2FA失败次数失败: " + err.Error())
|
common.SysLog("更新2FA失败次数失败: " + err.Error())
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@@ -290,7 +289,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
|||||||
t.LastUsedAt = &now
|
t.LastUsedAt = &now
|
||||||
|
|
||||||
if err := t.Update(); err != nil {
|
if err := t.Update(); err != nil {
|
||||||
logger.SysError("更新2FA使用记录失败: " + err.Error())
|
common.SysLog("更新2FA使用记录失败: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|||||||
+4
-5
@@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -25,12 +24,12 @@ func UpdateQuotaData() {
|
|||||||
// recover
|
// recover
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
if common.DataExportEnabled {
|
if common.DataExportEnabled {
|
||||||
logger.SysLog("正在更新数据看板数据...")
|
common.SysLog("正在更新数据看板数据...")
|
||||||
SaveQuotaDataCache()
|
SaveQuotaDataCache()
|
||||||
}
|
}
|
||||||
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
|
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
|
||||||
@@ -92,7 +91,7 @@ func SaveQuotaDataCache() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
CacheQuotaData = make(map[string]*QuotaData)
|
CacheQuotaData = make(map[string]*QuotaData)
|
||||||
logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
|
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
|
||||||
}
|
}
|
||||||
|
|
||||||
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
|
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
|
||||||
@@ -103,7 +102,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int,
|
|||||||
"token_used": gorm.Expr("token_used + ?", tokenUsed),
|
"token_used": gorm.Expr("token_used + ?", tokenUsed),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
|
common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+12
-12
@@ -76,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting {
|
|||||||
if user.Setting != "" {
|
if user.Setting != "" {
|
||||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysLog("failed to unmarshal setting: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return setting
|
return setting
|
||||||
@@ -85,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting {
|
|||||||
func (user *User) SetSetting(setting dto.UserSetting) {
|
func (user *User) SetSetting(setting dto.UserSetting) {
|
||||||
settingBytes, err := json.Marshal(setting)
|
settingBytes, err := json.Marshal(setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to marshal setting: " + err.Error())
|
common.SysLog("failed to marshal setting: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.Setting = string(settingBytes)
|
user.Setting = string(settingBytes)
|
||||||
@@ -518,7 +518,7 @@ func IsAdmin(userId int) bool {
|
|||||||
var user User
|
var user User
|
||||||
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("no such user " + err.Error())
|
common.SysLog("no such user " + err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return user.Role >= common.RoleAdminUser
|
return user.Role >= common.RoleAdminUser
|
||||||
@@ -573,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
|||||||
if shouldUpdateRedis(fromDB, err) {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserQuotaCache(id, quota); err != nil {
|
if err := updateUserQuotaCache(id, quota); err != nil {
|
||||||
logger.SysError("failed to update user quota cache: " + err.Error())
|
common.SysLog("failed to update user quota cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -611,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
|||||||
if shouldUpdateRedis(fromDB, err) {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserGroupCache(id, group); err != nil {
|
if err := updateUserGroupCache(id, group); err != nil {
|
||||||
logger.SysError("failed to update user group cache: " + err.Error())
|
common.SysLog("failed to update user group cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -640,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
|
|||||||
if shouldUpdateRedis(fromDB, err) {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserSettingCache(id, setting); err != nil {
|
if err := updateUserSettingCache(id, setting); err != nil {
|
||||||
logger.SysError("failed to update user setting cache: " + err.Error())
|
common.SysLog("failed to update user setting cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -670,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheIncrUserQuota(id, int64(quota))
|
err := cacheIncrUserQuota(id, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to increase user quota: " + err.Error())
|
common.SysLog("failed to increase user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if !db && common.BatchUpdateEnabled {
|
if !db && common.BatchUpdateEnabled {
|
||||||
@@ -695,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheDecrUserQuota(id, int64(quota))
|
err := cacheDecrUserQuota(id, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to decrease user quota: " + err.Error())
|
common.SysLog("failed to decrease user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if common.BatchUpdateEnabled {
|
if common.BatchUpdateEnabled {
|
||||||
@@ -751,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
|||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update user used quota and request count: " + err.Error())
|
common.SysLog("failed to update user used quota and request count: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -768,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
|
|||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update user used quota: " + err.Error())
|
common.SysLog("failed to update user used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserRequestCount(id int, count int) {
|
func updateUserRequestCount(id int, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update user request count: " + err.Error())
|
common.SysLog("failed to update user request count: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -786,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
|||||||
if shouldUpdateRedis(fromDB, err) {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserNameCache(id, username); err != nil {
|
if err := updateUserNameCache(id, username); err != nil {
|
||||||
logger.SysError("failed to update user name cache: " + err.Error())
|
common.SysLog("failed to update user name cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-3
@@ -5,7 +5,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -38,7 +37,7 @@ func (user *UserBase) GetSetting() dto.UserSetting {
|
|||||||
if user.Setting != "" {
|
if user.Setting != "" {
|
||||||
err := common.Unmarshal([]byte(user.Setting), &setting)
|
err := common.Unmarshal([]byte(user.Setting), &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysLog("failed to unmarshal setting: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return setting
|
return setting
|
||||||
@@ -79,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
|
|||||||
if shouldUpdateRedis(fromDB, err) && user != nil {
|
if shouldUpdateRedis(fromDB, err) && user != nil {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserCache(*user); err != nil {
|
if err := updateUserCache(*user); err != nil {
|
||||||
logger.SysError("failed to update user status cache: " + err.Error())
|
common.SysLog("failed to update user status cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+4
-5
@@ -3,7 +3,6 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -66,7 +65,7 @@ func batchUpdate() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.SysLog("batch update started")
|
common.SysLog("batch update started")
|
||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateLocks[i].Lock()
|
batchUpdateLocks[i].Lock()
|
||||||
store := batchUpdateStores[i]
|
store := batchUpdateStores[i]
|
||||||
@@ -78,12 +77,12 @@ func batchUpdate() {
|
|||||||
case BatchUpdateTypeUserQuota:
|
case BatchUpdateTypeUserQuota:
|
||||||
err := increaseUserQuota(key, value)
|
err := increaseUserQuota(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to batch update user quota: " + err.Error())
|
common.SysLog("failed to batch update user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeTokenQuota:
|
case BatchUpdateTypeTokenQuota:
|
||||||
err := increaseTokenQuota(key, value)
|
err := increaseTokenQuota(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to batch update token quota: " + err.Error())
|
common.SysLog("failed to batch update token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeUsedQuota:
|
case BatchUpdateTypeUsedQuota:
|
||||||
updateUserUsedQuota(key, value)
|
updateUserUsedQuota(key, value)
|
||||||
@@ -94,7 +93,7 @@ func batchUpdate() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.SysLog("batch update finished")
|
common.SysLog("batch update finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordExist(err error) (bool, error) {
|
func RecordExist(err error) (bool, error) {
|
||||||
|
|||||||
@@ -34,20 +34,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
var fullRequestURL string
|
var fullRequestURL string
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
|
||||||
default:
|
default:
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl)
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||||
case constant.RelayModeCompletions:
|
case constant.RelayModeCompletions:
|
||||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
||||||
default:
|
default:
|
||||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +118,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
"one-api/logger"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -22,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
|||||||
imageRequest.Input.Prompt = request.Prompt
|
imageRequest.Input.Prompt = request.Prompt
|
||||||
imageRequest.Model = request.Model
|
imageRequest.Model = request.Model
|
||||||
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
||||||
imageRequest.Parameters.N = request.N
|
imageRequest.Parameters.N = int(request.N)
|
||||||
imageRequest.ResponseFormat = request.ResponseFormat
|
imageRequest.ResponseFormat = request.ResponseFormat
|
||||||
|
|
||||||
return &imageRequest
|
return &imageRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
|
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
|
||||||
|
|
||||||
var aliResponse AliResponse
|
var aliResponse AliResponse
|
||||||
|
|
||||||
@@ -43,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
|||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("updateTask client.Do err: " + err.Error())
|
common.SysLog("updateTask client.Do err: " + err.Error())
|
||||||
return &aliResponse, err, nil
|
return &aliResponse, err, nil
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -53,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
|||||||
var response AliResponse
|
var response AliResponse
|
||||||
err = json.Unmarshal(responseBody, &response)
|
err = json.Unmarshal(responseBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("updateTask NewDecoder err: " + err.Error())
|
common.SysLog("updateTask NewDecoder err: " + err.Error())
|
||||||
return &aliResponse, err, nil
|
return &aliResponse, err, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -150,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
|||||||
var aliResponse AliResponse
|
var aliResponse AliResponse
|
||||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if aliResponse.Usage.OutputTokens != 0 {
|
if aliResponse.Usage.OutputTokens != 0 {
|
||||||
@@ -163,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
|||||||
lastResponseText = aliResponse.Output.Text
|
lastResponseText = aliResponse.Output.Text
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
common.SysLog("error marshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
default:
|
default:
|
||||||
suffix += strings.ToLower(info.UpstreamModelName)
|
suffix += strings.ToLower(info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
|
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix)
|
||||||
var accessToken string
|
var accessToken string
|
||||||
var err error
|
var err error
|
||||||
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -119,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|||||||
var baiduResponse BaiduChatStreamResponse
|
var baiduResponse BaiduChatStreamResponse
|
||||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if baiduResponse.Usage.TotalTokens != 0 {
|
if baiduResponse.Usage.TotalTokens != 0 {
|
||||||
@@ -130,7 +129,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|||||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||||
err = helper.ObjectData(c, response)
|
err = helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error sending stream response: " + err.Error())
|
common.SysLog("error sending stream response: " + err.Error())
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -45,15 +45,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeChatCompletions:
|
||||||
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesEdits:
|
||||||
return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||||
|
|||||||
@@ -53,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if a.RequestMode == RequestModeMessage {
|
if a.RequestMode == RequestModeMessage {
|
||||||
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
|||||||
for _, toolCall := range message.ParseToolCalls() {
|
for _, toolCall := range message.ParseToolCalls() {
|
||||||
inputObj := make(map[string]any)
|
inputObj := make(map[string]any)
|
||||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
||||||
logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
||||||
@@ -610,13 +610,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
var claudeResponse dto.ClaudeResponse
|
var claudeResponse dto.ClaudeResponse
|
||||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == types.RelayFormatClaude {
|
||||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
@@ -629,7 +629,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||||
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
|
|
||||||
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
||||||
@@ -654,21 +654,20 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
|||||||
}
|
}
|
||||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
logger.SysError("claude response usage is not complete, maybe upstream error")
|
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||||
}
|
}
|
||||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == types.RelayFormatClaude {
|
||||||
//
|
//
|
||||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||||
|
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||||
err := helper.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("send final response failed: " + err.Error())
|
common.SysLog("send final response failed: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
@@ -722,14 +721,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
var responseData []byte
|
var responseData []byte
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
openaiResponse.Usage = *claudeInfo.Usage
|
openaiResponse.Usage = *claudeInfo.Usage
|
||||||
responseData, err = json.Marshal(openaiResponse)
|
responseData, err = json.Marshal(openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
responseData = data
|
responseData = data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,13 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeChatCompletions:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
|
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
|
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil
|
||||||
case constant.RelayModeResponses:
|
case constant.RelayModeResponses:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil
|
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -119,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
var cohereResp CohereResponse
|
var cohereResp CohereResponse
|
||||||
err := json.Unmarshal([]byte(data), &cohereResp)
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
var openaiResp dto.ChatCompletionsStreamResponse
|
var openaiResp dto.ChatCompletionsStreamResponse
|
||||||
@@ -154,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
}
|
}
|
||||||
jsonStr, err := json.Marshal(openaiResp)
|
jsonStr, err := json.Marshal(openaiResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
common.SysLog("error marshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func (a *Adaptor) GetModelList() []string {
|
|||||||
|
|
||||||
// GetRequestURL implements channel.Adaptor.
|
// GetRequestURL implements channel.Adaptor.
|
||||||
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
|
||||||
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init implements channel.Adaptor.
|
// Init implements channel.Adaptor.
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -155,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
|||||||
var chatData CozeChatResponseData
|
var chatData CozeChatResponseData
|
||||||
err := json.Unmarshal([]byte(data), &chatData)
|
err := json.Unmarshal([]byte(data), &chatData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
|
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
|||||||
var messageData CozeChatV3MessageDetail
|
var messageData CozeChatV3MessageDetail
|
||||||
err := json.Unmarshal([]byte(data), &messageData)
|
err := json.Unmarshal([]byte(data), &messageData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
|
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var content string
|
var content string
|
||||||
err = json.Unmarshal(messageData.Content, &content)
|
err = json.Unmarshal(messageData.Content, &content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
|
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
|||||||
var errorData CozeError
|
var errorData CozeError
|
||||||
err := json.Unmarshal([]byte(data), &errorData)
|
err := json.Unmarshal([]byte(data), &errorData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
|
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
||||||
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
|
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
|
||||||
|
|
||||||
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
||||||
// 将 conversationId和chatId作为参数发送get请求
|
// 将 conversationId和chatId作为参数发送get请求
|
||||||
@@ -259,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
|
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
|
||||||
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
|
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
|
||||||
|
|
||||||
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
||||||
req, err := http.NewRequest("GET", requestURL, nil)
|
req, err := http.NewRequest("GET", requestURL, nil)
|
||||||
|
|||||||
@@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
fimBaseUrl := info.BaseUrl
|
fimBaseUrl := info.ChannelBaseUrl
|
||||||
if !strings.HasSuffix(info.BaseUrl, "/beta") {
|
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||||
fimBaseUrl += "/beta"
|
fimBaseUrl += "/beta"
|
||||||
}
|
}
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeCompletions:
|
case constant.RelayModeCompletions:
|
||||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
switch a.BotType {
|
switch a.BotType {
|
||||||
case BotTypeWorkFlow:
|
case BotTypeWorkFlow:
|
||||||
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil
|
||||||
case BotTypeCompletion:
|
case BotTypeCompletion:
|
||||||
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil
|
||||||
case BotTypeAgent:
|
case BotTypeAgent:
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -23,7 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
|
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
|
||||||
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
|
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl)
|
||||||
switch media.Type {
|
switch media.Type {
|
||||||
case dto.ContentTypeImageURL:
|
case dto.ContentTypeImageURL:
|
||||||
// Decode base64 data
|
// Decode base64 data
|
||||||
@@ -37,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
// Decode base64 string
|
// Decode base64 string
|
||||||
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
|
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to decode base64: " + err.Error())
|
common.SysLog("failed to decode base64: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create temporary file
|
// Create temporary file
|
||||||
tempFile, err := os.CreateTemp("", "dify-upload-*")
|
tempFile, err := os.CreateTemp("", "dify-upload-*")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to create temp file: " + err.Error())
|
common.SysLog("failed to create temp file: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer tempFile.Close()
|
defer tempFile.Close()
|
||||||
@@ -52,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
|
|
||||||
// Write decoded data to temp file
|
// Write decoded data to temp file
|
||||||
if _, err := tempFile.Write(decodedData); err != nil {
|
if _, err := tempFile.Write(decodedData); err != nil {
|
||||||
logger.SysError("failed to write to temp file: " + err.Error())
|
common.SysLog("failed to write to temp file: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
|
|
||||||
// Add user field
|
// Add user field
|
||||||
if err := writer.WriteField("user", user); err != nil {
|
if err := writer.WriteField("user", user); err != nil {
|
||||||
logger.SysError("failed to add user field: " + err.Error())
|
common.SysLog("failed to add user field: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
// Create form file
|
// Create form file
|
||||||
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
|
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to create form file: " + err.Error())
|
common.SysLog("failed to create form file: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy file content to form
|
// Copy file content to form
|
||||||
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
|
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
|
||||||
logger.SysError("failed to copy file content: " + err.Error())
|
common.SysLog("failed to copy file content: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
writer.Close()
|
writer.Close()
|
||||||
@@ -89,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
// Create HTTP request
|
// Create HTTP request
|
||||||
req, err := http.NewRequest("POST", uploadUrl, body)
|
req, err := http.NewRequest("POST", uploadUrl, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to create request: " + err.Error())
|
common.SysLog("failed to create request: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
client := service.GetHttpClient()
|
client := service.GetHttpClient()
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to send request: " + err.Error())
|
common.SysLog("failed to send request: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -110,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
logger.SysError("failed to decode response: " + err.Error())
|
common.SysLog("failed to decode response: " + err.Error())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
|||||||
var difyResponse DifyChunkChatCompletionResponse
|
var difyResponse DifyChunkChatCompletionResponse
|
||||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||||
@@ -240,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
|||||||
}
|
}
|
||||||
err = helper.ObjectData(c, openaiResponse)
|
err = helper.ObjectData(c, openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
|
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
|
||||||
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||||
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
|
return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||||
@@ -118,7 +118,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if info.IsGeminiBatchEmbedding {
|
if info.IsGeminiBatchEmbedding {
|
||||||
action = "batchEmbedContents"
|
action = "batchEmbedContents"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
@@ -128,7 +128,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
info.DisablePing = true
|
info.DisablePing = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -994,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|||||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||||
err := handleFinalStream(c, info, response)
|
err := handleFinalStream(c, info, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("send final response failed: " + err.Error())
|
common.SysLog("send final response failed: " + err.Error())
|
||||||
}
|
}
|
||||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||||
// helper.Done(c)
|
// helper.Done(c)
|
||||||
@@ -1042,19 +1042,19 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
|||||||
fullTextResponse.Usage = usage
|
fullTextResponse.Usage = usage
|
||||||
|
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
responseBody, err = common.Marshal(fullTextResponse)
|
responseBody, err = common.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
|
claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
|
||||||
claudeRespStr, err := common.Marshal(claudeResp)
|
claudeRespStr, err := common.Marshal(claudeResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
responseBody = claudeRespStr
|
responseBody = claudeRespStr
|
||||||
case relaycommon.RelayFormatGemini:
|
case types.RelayFormatGemini:
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
|
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -45,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
return "", errors.New("invalid relay mode")
|
return "", errors.New("invalid relay mode")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,5 +6,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
|
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
|
||||||
suffix = "embeddings"
|
suffix = "embeddings"
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
|
fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix)
|
||||||
return fullRequestURL, nil
|
return fullRequestURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,19 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil
|
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||||
default:
|
default:
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||||
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,10 +89,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
adaptor := openai.Adaptor{}
|
adaptor := openai.Adaptor{}
|
||||||
return adaptor.DoResponse(c, resp, info)
|
return adaptor.DoResponse(c, resp, info)
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -48,14 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == types.RelayFormatClaude {
|
||||||
return info.BaseUrl + "/v1/chat/completions", nil
|
return info.ChannelBaseUrl + "/v1/chat/completions", nil
|
||||||
}
|
}
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
return info.BaseUrl + "/api/embed", nil
|
return info.ChannelBaseUrl + "/api/embed", nil
|
||||||
default:
|
default:
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -105,14 +105,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
if strings.HasPrefix(info.ChannelBaseUrl, "https://") {
|
||||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://")
|
||||||
baseUrl = "wss://" + baseUrl
|
baseUrl = "wss://" + baseUrl
|
||||||
info.BaseUrl = baseUrl
|
info.ChannelBaseUrl = baseUrl
|
||||||
} else if strings.HasPrefix(info.BaseUrl, "http://") {
|
} else if strings.HasPrefix(info.ChannelBaseUrl, "http://") {
|
||||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
|
baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://")
|
||||||
baseUrl = "ws://" + baseUrl
|
baseUrl = "ws://" + baseUrl
|
||||||
info.BaseUrl = baseUrl
|
info.ChannelBaseUrl = baseUrl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch info.ChannelType {
|
switch info.ChannelType {
|
||||||
@@ -126,7 +126,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||||
|
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == types.RelayFormatClaude {
|
||||||
task = strings.TrimPrefix(task, "messages")
|
task = strings.TrimPrefix(task, "messages")
|
||||||
task = "chat/completions" + task
|
task = "chat/completions" + task
|
||||||
}
|
}
|
||||||
@@ -136,7 +136,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
responsesApiVersion := "preview"
|
responsesApiVersion := "preview"
|
||||||
|
|
||||||
subUrl := "/openai/v1/responses"
|
subUrl := "/openai/v1/responses"
|
||||||
if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") {
|
if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") {
|
||||||
subUrl = "/openai/responses"
|
subUrl = "/openai/responses"
|
||||||
responsesApiVersion = apiVersion
|
responsesApiVersion = apiVersion
|
||||||
}
|
}
|
||||||
@@ -146,7 +146,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
|
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
model_ := info.UpstreamModelName
|
model_ := info.UpstreamModelName
|
||||||
@@ -159,18 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||||
}
|
}
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
|
||||||
case constant.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
return minimax.GetRequestURL(info)
|
return minimax.GetRequestURL(info)
|
||||||
case constant.ChannelTypeCustom:
|
case constant.ChannelTypeCustom:
|
||||||
url := info.BaseUrl
|
url := info.ChannelBaseUrl
|
||||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||||
return url, nil
|
return url, nil
|
||||||
default:
|
default:
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -22,11 +23,11 @@ func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string
|
|||||||
info.SendResponseCount++
|
info.SendResponseCount++
|
||||||
|
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
return handleClaudeFormat(c, data, info)
|
return handleClaudeFormat(c, data, info)
|
||||||
case relaycommon.RelayFormatGemini:
|
case types.RelayFormatGemini:
|
||||||
return handleGeminiFormat(c, data, info)
|
return handleGeminiFormat(c, data, info)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -111,14 +112,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
|
|||||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||||
// 一次性解析失败,逐个解析
|
// 一次性解析失败,逐个解析
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
for _, item := range streamItems {
|
for _, item := range streamItems {
|
||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||||
logger.SysError("error processing stream response: " + err.Error())
|
common.SysLog("error processing stream response: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -147,7 +148,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
|
|||||||
var streamResponses []dto.CompletionsStreamResponse
|
var streamResponses []dto.CompletionsStreamResponse
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||||
// 一次性解析失败,逐个解析
|
// 一次性解析失败,逐个解析
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
for _, item := range streamItems {
|
for _, item := range streamItems {
|
||||||
var streamResponse dto.CompletionsStreamResponse
|
var streamResponse dto.CompletionsStreamResponse
|
||||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||||
@@ -202,7 +203,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
|||||||
usage *dto.Usage, containStreamUsage bool) {
|
usage *dto.Usage, containStreamUsage bool) {
|
||||||
|
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||||
response.SetSystemFingerprint(systemFingerprint)
|
response.SetSystemFingerprint(systemFingerprint)
|
||||||
@@ -210,11 +211,11 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
|||||||
}
|
}
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
info.ClaudeConvertInfo.Done = true
|
info.ClaudeConvertInfo.Done = true
|
||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,10 +226,10 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
|||||||
_ = helper.ClaudeData(c, *resp)
|
_ = helper.ClaudeData(c, *resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
case relaycommon.RelayFormatGemini:
|
case types.RelayFormatGemini:
|
||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +247,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
|||||||
|
|
||||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling gemini response: " + err.Error())
|
common.SysLog("error marshalling gemini response: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
if lastStreamData != "" {
|
if lastStreamData != "" {
|
||||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error handling stream format: " + err.Error())
|
common.SysLog("error handling stream format: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(data) > 0 {
|
if len(data) > 0 {
|
||||||
@@ -147,7 +147,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
|
logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
if info.RelayFormat == types.RelayFormatOpenAI {
|
||||||
if shouldSendLastResp {
|
if shouldSendLastResp {
|
||||||
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||||
}
|
}
|
||||||
@@ -211,7 +211,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
if forceFormat {
|
if forceFormat {
|
||||||
responseBody, err = common.Marshal(simpleResponse)
|
responseBody, err = common.Marshal(simpleResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -220,14 +220,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
case relaycommon.RelayFormatClaude:
|
case types.RelayFormatClaude:
|
||||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||||
claudeRespStr, err := common.Marshal(claudeResp)
|
claudeRespStr, err := common.Marshal(claudeResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
responseBody = claudeRespStr
|
responseBody = claudeRespStr
|
||||||
case relaycommon.RelayFormatGemini:
|
case types.RelayFormatGemini:
|
||||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||||
geminiRespStr, err := common.Marshal(geminiResp)
|
geminiRespStr, err := common.Marshal(geminiResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -59,7 +58,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
|||||||
go func() {
|
go func() {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error reading stream response: " + err.Error())
|
common.SysLog("error reading stream response: " + err.Error())
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -67,7 +66,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
|||||||
var palmResponse PaLMChatResponse
|
var palmResponse PaLMChatResponse
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -79,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
|||||||
}
|
}
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
common.SysLog("error marshalling stream response: " + err.Error())
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||||
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ type TaskAdaptor struct {
|
|||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
a.baseURL = info.BaseUrl
|
a.baseURL = info.ChannelBaseUrl
|
||||||
|
|
||||||
// apiKey format: "access_key|secret_key"
|
// apiKey format: "access_key|secret_key"
|
||||||
keyParts := strings.Split(info.ApiKey, "|")
|
keyParts := strings.Split(info.ApiKey, "|")
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ type TaskAdaptor struct {
|
|||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
a.baseURL = info.BaseUrl
|
a.baseURL = info.ChannelBaseUrl
|
||||||
a.apiKey = info.ApiKey
|
a.apiKey = info.ApiKey
|
||||||
|
|
||||||
// apiKey format: "access_key|secret_key"
|
// apiKey format: "access_key|secret_key"
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -60,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||||
baseURL := info.BaseUrl
|
baseURL := info.ChannelBaseUrl
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
|
||||||
return fullRequestURL, nil
|
return fullRequestURL, nil
|
||||||
}
|
}
|
||||||
@@ -140,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
|||||||
|
|
||||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("Get Task error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ type TaskAdaptor struct {
|
|||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
a.baseURL = info.BaseUrl
|
a.baseURL = info.ChannelBaseUrl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -107,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
|||||||
var tencentResponse TencentChatResponse
|
var tencentResponse TencentChatResponse
|
||||||
err := json.Unmarshal([]byte(data), &tencentResponse)
|
err := json.Unmarshal([]byte(data), &tencentResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,12 +117,12 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
|||||||
|
|
||||||
err = helper.ObjectData(c, response)
|
err = helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
logger.SysError("error reading stream: " + err.Error())
|
common.SysLog("error reading stream: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|||||||
@@ -188,17 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeChatCompletions:
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesEdits:
|
||||||
return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
xaiRequest := ImageRequest{
|
xaiRequest := ImageRequest{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Prompt: request.Prompt,
|
Prompt: request.Prompt,
|
||||||
N: request.N,
|
N: int(request.N),
|
||||||
ResponseFormat: request.ResponseFormat,
|
ResponseFormat: request.ResponseFormat,
|
||||||
}
|
}
|
||||||
return xaiRequest, nil
|
return xaiRequest, nil
|
||||||
@@ -49,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
@@ -48,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||||
err := json.Unmarshal([]byte(data), &xAIResp)
|
err := json.Unmarshal([]byte(data), &xAIResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||||
err = helper.ObjectData(c, openaiResponse)
|
err = helper.ObjectData(c, openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -144,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
|||||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
common.SysLog("error marshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
@@ -219,20 +218,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
|||||||
for {
|
for {
|
||||||
_, msg, err := conn.ReadMessage()
|
_, msg, err := conn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error reading stream response: " + err.Error())
|
common.SysLog("error reading stream response: " + err.Error())
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
var response XunfeiChatResponse
|
var response XunfeiChatResponse
|
||||||
err = json.Unmarshal(msg, &response)
|
err = json.Unmarshal(msg, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
dataChan <- response
|
dataChan <- response
|
||||||
if response.Payload.Choices.Status == 2 {
|
if response.Payload.Choices.Status == 2 {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error closing websocket connection: " + err.Error())
|
common.SysLog("error closing websocket connection: " + err.Error())
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -283,6 +282,6 @@ func getAPIVersion(c *gin.Context, modelName string) string {
|
|||||||
return apiVersion
|
return apiVersion
|
||||||
}
|
}
|
||||||
apiVersion = "v1.1"
|
apiVersion = "v1.1"
|
||||||
logger.SysLog("api_version not found, using default: " + apiVersion)
|
common.SysLog("api_version not found, using default: " + apiVersion)
|
||||||
return apiVersion
|
return apiVersion
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
method = "sse-invoke"
|
method = "sse-invoke"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -40,7 +39,7 @@ func getZhipuToken(apikey string) string {
|
|||||||
|
|
||||||
split := strings.Split(apikey, ".")
|
split := strings.Split(apikey, ".")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
logger.SysError("invalid zhipu key: " + apikey)
|
common.SysLog("invalid zhipu key: " + apikey)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|||||||
response := streamResponseZhipu2OpenAI(data)
|
response := streamResponseZhipu2OpenAI(data)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
common.SysLog("error marshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
@@ -197,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|||||||
var zhipuResponse ZhipuStreamMetaResponse
|
var zhipuResponse ZhipuStreamMetaResponse
|
||||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
common.SysLog("error marshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
usage = zhipuUsage
|
usage = zhipuUsage
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
|
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl)
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
return fmt.Sprintf("%s/embeddings", baseUrl), nil
|
return fmt.Sprintf("%s/embeddings", baseUrl), nil
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ type ChannelMeta struct {
|
|||||||
ChannelOtherSettings dto.ChannelOtherSettings
|
ChannelOtherSettings dto.ChannelOtherSettings
|
||||||
UpstreamModelName string
|
UpstreamModelName string
|
||||||
IsModelMapped bool
|
IsModelMapped bool
|
||||||
|
SupportStreamOptions bool // 是否支持流式选项
|
||||||
}
|
}
|
||||||
|
|
||||||
type RelayInfo struct {
|
type RelayInfo struct {
|
||||||
@@ -86,9 +87,9 @@ type RelayInfo struct {
|
|||||||
RelayMode int
|
RelayMode int
|
||||||
OriginModelName string
|
OriginModelName string
|
||||||
//RecodeModelName string
|
//RecodeModelName string
|
||||||
RequestURLPath string
|
RequestURLPath string
|
||||||
PromptTokens int
|
PromptTokens int
|
||||||
SupportStreamOptions bool
|
//SupportStreamOptions bool
|
||||||
ShouldIncludeUsage bool
|
ShouldIncludeUsage bool
|
||||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||||
ClientWs *websocket.Conn
|
ClientWs *websocket.Conn
|
||||||
@@ -135,6 +136,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
|||||||
ParamOverride: paramOverride,
|
ParamOverride: paramOverride,
|
||||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||||
IsModelMapped: false,
|
IsModelMapped: false,
|
||||||
|
SupportStreamOptions: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||||
@@ -146,6 +148,10 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
channelMeta.ChannelOtherSettings = channelOtherSettings
|
channelMeta.ChannelOtherSettings = channelOtherSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if streamSupportedChannels[channelMeta.ChannelType] {
|
||||||
|
channelMeta.SupportStreamOptions = true
|
||||||
|
}
|
||||||
info.ChannelMeta = channelMeta
|
info.ChannelMeta = channelMeta
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,6 +274,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
|||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isStream := false
|
||||||
|
|
||||||
|
if request != nil {
|
||||||
|
isStream = request.IsStream(c)
|
||||||
|
}
|
||||||
|
|
||||||
// firstResponseTime = time.Now() - 1 second
|
// firstResponseTime = time.Now() - 1 second
|
||||||
|
|
||||||
info := &RelayInfo{
|
info := &RelayInfo{
|
||||||
@@ -289,7 +301,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
|||||||
isFirstResponse: true,
|
isFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
IsStream: request.IsStream(c),
|
IsStream: isStream,
|
||||||
|
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
FirstResponseTime: startTime.Add(-time.Second),
|
FirstResponseTime: startTime.Add(-time.Second),
|
||||||
@@ -339,6 +351,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
|
|||||||
return GenRelayInfoResponses(c, request), nil
|
return GenRelayInfoResponses(c, request), nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("request is not a OpenAIResponsesRequest")
|
return nil, errors.New("request is not a OpenAIResponsesRequest")
|
||||||
|
case types.RelayFormatTask:
|
||||||
|
return genBaseRelayInfo(c, nil), nil
|
||||||
|
case types.RelayFormatMjProxy:
|
||||||
|
return genBaseRelayInfo(c, nil), nil
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("invalid relay format")
|
return nil, errors.New("invalid relay format")
|
||||||
}
|
}
|
||||||
@@ -367,11 +383,15 @@ type TaskRelayInfo struct {
|
|||||||
ConsumeQuota bool
|
ConsumeQuota bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
|
||||||
info := &TaskRelayInfo{
|
relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||||
RelayInfo: GenRelayInfo(c),
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return info
|
info := &TaskRelayInfo{
|
||||||
|
RelayInfo: relayInfo,
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskSubmitReq struct {
|
type TaskSubmitReq struct {
|
||||||
|
|||||||
+23
-23
@@ -53,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
var imageRatio float64
|
var imageRatio float64
|
||||||
var cacheCreationRatio float64
|
var cacheCreationRatio float64
|
||||||
if !usePrice {
|
if !usePrice {
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
|
||||||
if meta.MaxTokens != 0 {
|
if meta.MaxTokens != 0 {
|
||||||
preConsumedTokens = promptTokens + meta.MaxTokens
|
preConsumedTokens += meta.MaxTokens
|
||||||
}
|
}
|
||||||
var success bool
|
var success bool
|
||||||
var matchName string
|
var matchName string
|
||||||
@@ -102,27 +102,27 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||||
//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
|
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
|
||||||
// groupRatioInfo := HandleGroupRatio(c, info)
|
groupRatioInfo := HandleGroupRatio(c, info)
|
||||||
//
|
|
||||||
// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||||
// // 如果没有配置价格,则使用默认价格
|
// 如果没有配置价格,则使用默认价格
|
||||||
// if !success {
|
if !success {
|
||||||
// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
|
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
|
||||||
// if !ok {
|
if !ok {
|
||||||
// modelPrice = 0.1
|
modelPrice = 0.1
|
||||||
// } else {
|
} else {
|
||||||
// modelPrice = defaultPrice
|
modelPrice = defaultPrice
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||||
// priceData := types.PerCallPriceData{
|
priceData := types.PerCallPriceData{
|
||||||
// ModelPrice: modelPrice,
|
ModelPrice: modelPrice,
|
||||||
// Quota: quota,
|
Quota: quota,
|
||||||
// GroupRatioInfo: groupRatioInfo,
|
GroupRatioInfo: groupRatioInfo,
|
||||||
// }
|
}
|
||||||
// return priceData
|
return priceData
|
||||||
//}
|
}
|
||||||
|
|
||||||
func ContainPriceOrRatio(modelName string) bool {
|
func ContainPriceOrRatio(modelName string) bool {
|
||||||
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
|||||||
case types.RelayFormatOpenAIAudio:
|
case types.RelayFormatOpenAIAudio:
|
||||||
request, err = GetAndValidAudioRequest(c, relayMode)
|
request, err = GetAndValidAudioRequest(c, relayMode)
|
||||||
case types.RelayFormatOpenAIRealtime:
|
case types.RelayFormatOpenAIRealtime:
|
||||||
// nothing to do, no request body
|
request = &dto.BaseRequest{}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported relay format: %s", format)
|
return nil, fmt.Errorf("unsupported relay format: %s", format)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
@@ -171,13 +170,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
||||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
//group := c.GetString("group")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
|
||||||
var swapFaceRequest dto.SwapFaceRequest
|
var swapFaceRequest dto.SwapFaceRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -188,9 +181,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
}
|
}
|
||||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||||
|
|
||||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
priceData := helper.ModelPriceHelperPerCall(c, info)
|
||||||
|
|
||||||
userQuota, err := model.GetUserQuota(userId, false)
|
userQuota, err := model.GetUserQuota(info.UserId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
@@ -213,32 +206,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||||
other := service.GenerateMjOtherInfo(priceData)
|
other := service.GenerateMjOtherInfo(priceData)
|
||||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||||
ChannelId: channelId,
|
ChannelId: info.ChannelId,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
Quota: priceData.Quota,
|
Quota: priceData.Quota,
|
||||||
Content: logContent,
|
Content: logContent,
|
||||||
TokenId: tokenId,
|
TokenId: info.TokenId,
|
||||||
UserQuota: userQuota,
|
Group: info.UsingGroup,
|
||||||
Group: relayInfo.UsingGroup,
|
|
||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
|
||||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
midjResponse := &mjResp.Response
|
midjResponse := &mjResp.Response
|
||||||
midjourneyTask := &model.Midjourney{
|
midjourneyTask := &model.Midjourney{
|
||||||
UserId: userId,
|
UserId: info.UserId,
|
||||||
Code: midjResponse.Code,
|
Code: midjResponse.Code,
|
||||||
Action: constant.MjActionSwapFace,
|
Action: constant.MjActionSwapFace,
|
||||||
MjId: midjResponse.Result,
|
MjId: midjResponse.Result,
|
||||||
@@ -246,7 +238,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
PromptEn: "",
|
PromptEn: "",
|
||||||
Description: midjResponse.Description,
|
Description: midjResponse.Description,
|
||||||
State: "",
|
State: "",
|
||||||
SubmitTime: startTime,
|
SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
|
||||||
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||||
FinishTime: 0,
|
FinishTime: 0,
|
||||||
ImageUrl: "",
|
ImageUrl: "",
|
||||||
@@ -370,14 +362,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
||||||
|
|
||||||
//tokenId := c.GetInt("token_id")
|
|
||||||
//channelType := c.GetInt("channel")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
group := c.GetString("group")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
|
||||||
consumeQuota := true
|
consumeQuota := true
|
||||||
var midjRequest dto.MidjourneyRequest
|
var midjRequest dto.MidjourneyRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
@@ -385,35 +370,35 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||||
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||||||
if mjErr != nil {
|
if mjErr != nil {
|
||||||
return mjErr
|
return mjErr
|
||||||
}
|
}
|
||||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
|
||||||
}
|
}
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||||
midjRequest.Action = constant.MjActionVideo
|
midjRequest.Action = constant.MjActionVideo
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||||
if midjRequest.Prompt == "" {
|
if midjRequest.Prompt == "" {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
||||||
}
|
}
|
||||||
midjRequest.Action = constant.MjActionImagine
|
midjRequest.Action = constant.MjActionImagine
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||||
midjRequest.Action = constant.MjActionDescribe
|
midjRequest.Action = constant.MjActionDescribe
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
||||||
midjRequest.Action = constant.MjActionEdits
|
midjRequest.Action = constant.MjActionEdits
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||||
midjRequest.Action = constant.MjActionShorten
|
midjRequest.Action = constant.MjActionShorten
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||||
midjRequest.Action = constant.MjActionBlend
|
midjRequest.Action = constant.MjActionBlend
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
||||||
midjRequest.Action = constant.MjActionUpload
|
midjRequest.Action = constant.MjActionUpload
|
||||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||||
mjId := ""
|
mjId := ""
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
|
||||||
if midjRequest.TaskId == "" {
|
if midjRequest.TaskId == "" {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||||
} else if midjRequest.Action == "" {
|
} else if midjRequest.Action == "" {
|
||||||
@@ -423,7 +408,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
//action = midjRequest.Action
|
//action = midjRequest.Action
|
||||||
mjId = midjRequest.TaskId
|
mjId = midjRequest.TaskId
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||||||
if midjRequest.Content == "" {
|
if midjRequest.Content == "" {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||||||
}
|
}
|
||||||
@@ -433,13 +418,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
mjId = params.TaskId
|
mjId = params.TaskId
|
||||||
midjRequest.Action = params.Action
|
midjRequest.Action = params.Action
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
|
||||||
//if midjRequest.MaskBase64 == "" {
|
//if midjRequest.MaskBase64 == "" {
|
||||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
||||||
//}
|
//}
|
||||||
mjId = midjRequest.TaskId
|
mjId = midjRequest.TaskId
|
||||||
midjRequest.Action = constant.MjActionModal
|
midjRequest.Action = constant.MjActionModal
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||||
midjRequest.Action = constant.MjActionVideo
|
midjRequest.Action = constant.MjActionVideo
|
||||||
if midjRequest.TaskId == "" {
|
if midjRequest.TaskId == "" {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||||
@@ -449,12 +434,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
mjId = midjRequest.TaskId
|
mjId = midjRequest.TaskId
|
||||||
}
|
}
|
||||||
|
|
||||||
originTask := model.GetByMJId(userId, mjId)
|
originTask := model.GetByMJId(relayInfo.UserId, mjId)
|
||||||
if originTask == nil {
|
if originTask == nil {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||||
if setting.MjActionCheckSuccessEnabled {
|
if setting.MjActionCheckSuccessEnabled {
|
||||||
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -497,7 +482,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
|
|
||||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||||
|
|
||||||
userQuota, err := model.GetUserQuota(userId, false)
|
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
@@ -522,24 +507,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||||
other := service.GenerateMjOtherInfo(priceData)
|
other := service.GenerateMjOtherInfo(priceData)
|
||||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||||
ChannelId: channelId,
|
ChannelId: relayInfo.ChannelId,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
Quota: priceData.Quota,
|
Quota: priceData.Quota,
|
||||||
Content: logContent,
|
Content: logContent,
|
||||||
TokenId: relayInfo.TokenId,
|
TokenId: relayInfo.TokenId,
|
||||||
UserQuota: userQuota,
|
Group: relayInfo.UsingGroup,
|
||||||
Group: group,
|
|
||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
|
||||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -551,7 +535,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||||||
// other: 提交错误,description为错误描述
|
// other: 提交错误,description为错误描述
|
||||||
midjourneyTask := &model.Midjourney{
|
midjourneyTask := &model.Midjourney{
|
||||||
UserId: userId,
|
UserId: relayInfo.UserId,
|
||||||
Code: midjResponse.Code,
|
Code: midjResponse.Code,
|
||||||
Action: midjRequest.Action,
|
Action: midjRequest.Action,
|
||||||
MjId: midjResponse.Result,
|
MjId: midjResponse.Result,
|
||||||
@@ -573,7 +557,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
//无实例账号自动禁用渠道(No available account instance)
|
//无实例账号自动禁用渠道(No available account instance)
|
||||||
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
|
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("get_channel_null: " + err.Error())
|
common.SysLog("get_channel_null: " + err.Error())
|
||||||
}
|
}
|
||||||
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||||||
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|
||||||
@@ -44,6 +44,26 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
|||||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
includeUsage := true
|
||||||
|
// 判断用户是否需要返回使用情况
|
||||||
|
if textRequest.StreamOptions != nil {
|
||||||
|
includeUsage = textRequest.StreamOptions.IncludeUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||||
|
if !info.SupportStreamOptions || !textRequest.Stream {
|
||||||
|
textRequest.StreamOptions = nil
|
||||||
|
} else {
|
||||||
|
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||||
|
if constant.ForceStreamOption {
|
||||||
|
textRequest.StreamOptions = &dto.StreamOptions{
|
||||||
|
IncludeUsage: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info.ShouldIncludeUsage = includeUsage
|
||||||
|
|
||||||
adaptor := GetAdaptor(info.ApiType)
|
adaptor := GetAdaptor(info.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||||
|
|||||||
+7
-5
@@ -10,7 +10,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
@@ -28,7 +27,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
if platform == "" {
|
if platform == "" {
|
||||||
platform = GetTaskPlatform(c)
|
platform = GetTaskPlatform(c)
|
||||||
}
|
}
|
||||||
relayInfo := relaycommon.GenTaskRelayInfo(c)
|
|
||||||
|
relayInfo, err := relaycommon.GenTaskRelayInfo(c)
|
||||||
|
if err != nil {
|
||||||
|
return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
adaptor := GetTaskAdaptor(platform)
|
adaptor := GetTaskAdaptor(platform)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
@@ -98,7 +101,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
c.Set("channel_id", originTask.ChannelId)
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
|
||||||
relayInfo.BaseUrl = channel.GetBaseURL()
|
relayInfo.ChannelBaseUrl = channel.GetBaseURL()
|
||||||
relayInfo.ChannelId = originTask.ChannelId
|
relayInfo.ChannelId = originTask.ChannelId
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -128,7 +131,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
|
|
||||||
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
@@ -150,7 +153,6 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
Quota: quota,
|
Quota: quota,
|
||||||
Content: logContent,
|
Content: logContent,
|
||||||
TokenId: relayInfo.TokenId,
|
TokenId: relayInfo.TokenId,
|
||||||
UserQuota: userQuota,
|
|
||||||
Group: relayInfo.UsingGroup,
|
Group: relayInfo.UsingGroup,
|
||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
|
|||||||
+10
-34
@@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
@@ -12,58 +11,35 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
|
func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||||
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
info.InitChannelMeta(c)
|
||||||
|
|
||||||
err := helper.ModelMappedHelper(c, relayInfo, nil)
|
adaptor := GetAdaptor(info.ApiType)
|
||||||
if err != nil {
|
|
||||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
|
||||||
}
|
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
|
||||||
}
|
|
||||||
|
|
||||||
// pre-consume quota 预消耗配额
|
|
||||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
|
||||||
if newAPIError != nil {
|
|
||||||
return newAPIError
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if newAPIError != nil {
|
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(info)
|
||||||
//var requestBody io.Reader
|
//var requestBody io.Reader
|
||||||
//firstWssRequest, _ := c.Get("first_wss_request")
|
//firstWssRequest, _ := c.Get("first_wss_request")
|
||||||
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
|
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
resp, err := adaptor.DoRequest(c, info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
relayInfo.TargetWs = resp.(*websocket.Conn)
|
info.TargetWs = resp.(*websocket.Conn)
|
||||||
defer relayInfo.TargetWs.Close()
|
defer info.TargetWs.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, nil, info)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
}
|
}
|
||||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "")
|
||||||
userQuota, priceData, "")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
+3
-3
@@ -3,12 +3,12 @@ package router
|
|||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/logger"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||||
@@ -19,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
|||||||
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
||||||
if common.IsMasterNode && frontendBaseUrl != "" {
|
if common.IsMasterNode && frontendBaseUrl != "" {
|
||||||
frontendBaseUrl = ""
|
frontendBaseUrl = ""
|
||||||
logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
|
common.SysLog("FRONTEND_BASE_URL is ignored on master node")
|
||||||
}
|
}
|
||||||
if frontendBaseUrl == "" {
|
if frontendBaseUrl == "" {
|
||||||
SetWebRouter(router, buildFS, indexPage)
|
SetWebRouter(router, buildFS, indexPage)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
|||||||
|
|
||||||
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
||||||
if setting.EnableWorker() {
|
if setting.EnableWorker() {
|
||||||
logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
||||||
req := &WorkerRequest{
|
req := &WorkerRequest{
|
||||||
URL: originUrl,
|
URL: originUrl,
|
||||||
Key: setting.WorkerValidKey,
|
Key: setting.WorkerValidKey,
|
||||||
}
|
}
|
||||||
return DoWorkerRequest(req)
|
return DoWorkerRequest(req)
|
||||||
} else {
|
} else {
|
||||||
logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
||||||
return http.Get(originUrl)
|
return http.Get(originUrl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-3
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -59,7 +58,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
|
|||||||
lowerText := strings.ToLower(text)
|
lowerText := strings.ToLower(text)
|
||||||
if !strings.HasPrefix(lowerText, "get file base64 from url") {
|
if !strings.HasPrefix(lowerText, "get file base64 from url") {
|
||||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||||
logger.SysLog(fmt.Sprintf("error: %s", text))
|
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||||
text = "请求上游地址失败"
|
text = "请求上游地址失败"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -139,7 +138,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
|
|||||||
text := err.Error()
|
text := err.Error()
|
||||||
lowerText := strings.ToLower(text)
|
lowerText := strings.ToLower(text)
|
||||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||||
logger.SysLog(fmt.Sprintf("error: %s", text))
|
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||||
text = "请求上游地址失败"
|
text = "请求上游地址失败"
|
||||||
}
|
}
|
||||||
//避免暴露内部错误
|
//避免暴露内部错误
|
||||||
|
|||||||
+5
-5
@@ -8,8 +8,8 @@ import (
|
|||||||
"image"
|
"image"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/logger"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/image/webp"
|
"golang.org/x/image/webp"
|
||||||
@@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
|||||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||||
response, err := DoDownloadRequest(imageUrl)
|
response, err := DoDownloadRequest(imageUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||||
return image.Config{}, "", err
|
return image.Config{}, "", err
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
|||||||
|
|
||||||
var readData []byte
|
var readData []byte
|
||||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||||
logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||||
|
|
||||||
// 从response.Body读取更多的数据直到达到当前的限制
|
// 从response.Body读取更多的数据直到达到当前的限制
|
||||||
additionalData := make([]byte, limit-int64(len(readData)))
|
additionalData := make([]byte, limit-int64(len(readData)))
|
||||||
@@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
|||||||
config, format, err := image.DecodeConfig(reader)
|
config, format, err := image.DecodeConfig(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||||
logger.SysLog(err.Error())
|
common.SysLog(err.Error())
|
||||||
config, err = webp.DecodeConfig(reader)
|
config, err = webp.DecodeConfig(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||||
logger.SysLog(err.Error())
|
common.SysLog(err.Error())
|
||||||
}
|
}
|
||||||
format = "webp"
|
format = "webp"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -78,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
|
func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
|
||||||
other := make(map[string]interface{})
|
other := make(map[string]interface{})
|
||||||
other["model_price"] = priceData.ModelPrice
|
other["model_price"] = priceData.ModelPrice
|
||||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -213,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
resp, err := GetHttpClient().Do(req)
|
resp, err := GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("do request failed: " + err.Error())
|
common.SysLog("do request failed: " + err.Error())
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
}
|
}
|
||||||
statusCode := resp.StatusCode
|
statusCode := resp.StatusCode
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/logger"
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -19,7 +20,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
|
|||||||
|
|
||||||
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error return pre-consumed quota: " + err.Error())
|
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -32,9 +31,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
|||||||
var tokenEncoderMutex sync.RWMutex
|
var tokenEncoderMutex sync.RWMutex
|
||||||
|
|
||||||
func InitTokenEncoders() {
|
func InitTokenEncoders() {
|
||||||
logger.SysLog("initializing token encoders")
|
common.SysLog("initializing token encoders")
|
||||||
defaultTokenEncoder = codec.NewCl100kBase()
|
defaultTokenEncoder = codec.NewCl100kBase()
|
||||||
logger.SysLog("token encoders initialized")
|
common.SysLog("token encoders initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenEncoder(model string) tokenizer.Codec {
|
func getTokenEncoder(model string) tokenizer.Codec {
|
||||||
@@ -158,7 +157,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|||||||
if strings.HasPrefix(fileMeta.Data, "http") {
|
if strings.HasPrefix(fileMeta.Data, "http") {
|
||||||
config, format, err = DecodeUrlImageData(fileMeta.Data)
|
config, format, err = DecodeUrlImageData(fileMeta.Data)
|
||||||
} else {
|
} else {
|
||||||
logger.SysLog(fmt.Sprintf("decoding image"))
|
common.SysLog(fmt.Sprintf("decoding image"))
|
||||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
|
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -248,6 +247,11 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|||||||
if meta == nil {
|
if meta == nil {
|
||||||
return 0, errors.New("token count meta is nil")
|
return 0, errors.New("token count meta is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||||
tkm := CountTextToken(meta.CombineText, model)
|
tkm := CountTextToken(meta.CombineText, model)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/logger"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -13,7 +12,7 @@ func NotifyRootUser(t string, subject string, content string) {
|
|||||||
user := model.GetRootUser().ToBaseUser()
|
user := model.GetRootUser().ToBaseUser()
|
||||||
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
|
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,7 +25,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
|
|||||||
// Check notification limit
|
// Check notification limit
|
||||||
canSend, err := CheckNotificationLimit(userId, data.Type)
|
canSend, err := CheckNotificationLimit(userId, data.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !canSend {
|
if !canSend {
|
||||||
@@ -38,14 +37,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
|
|||||||
// check setting email
|
// check setting email
|
||||||
userEmail = userSetting.NotificationEmail
|
userEmail = userSetting.NotificationEmail
|
||||||
if userEmail == "" {
|
if userEmail == "" {
|
||||||
logger.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
|
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return sendEmailNotify(userEmail, data)
|
return sendEmailNotify(userEmail, data)
|
||||||
case dto.NotifyTypeWebhook:
|
case dto.NotifyTypeWebhook:
|
||||||
webhookURLStr := userSetting.WebhookUrl
|
webhookURLStr := userSetting.WebhookUrl
|
||||||
if webhookURLStr == "" {
|
if webhookURLStr == "" {
|
||||||
logger.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
|
common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -2,7 +2,7 @@ package setting
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Chats = []map[string]string{
|
var Chats = []map[string]string{
|
||||||
@@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error {
|
|||||||
func Chats2JsonString() string {
|
func Chats2JsonString() string {
|
||||||
jsonBytes, err := json.Marshal(Chats)
|
jsonBytes, err := json.Marshal(Chats)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling chats: " + err.Error())
|
common.SysLog("error marshalling chats: " + err.Error())
|
||||||
return "[]"
|
return "[]"
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -57,7 +57,7 @@ func (cm *ConfigManager) LoadFromDB(options map[string]string) error {
|
|||||||
// 如果找到配置项,则更新配置
|
// 如果找到配置项,则更新配置
|
||||||
if len(configMap) > 0 {
|
if len(configMap) > 0 {
|
||||||
if err := updateConfigFromMap(config, configMap); err != nil {
|
if err := updateConfigFromMap(config, configMap); err != nil {
|
||||||
logger.SysError("failed to update config " + name + ": " + err.Error())
|
common.SysError("failed to update config " + name + ": " + err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string {
|
|||||||
|
|
||||||
jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
|
jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling model ratio: " + err.Error())
|
common.SysLog("error marshalling model ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package ratio_setting
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ func CacheRatio2JSONString() string {
|
|||||||
defer cacheRatioMapMutex.RUnlock()
|
defer cacheRatioMapMutex.RUnlock()
|
||||||
jsonBytes, err := json.Marshal(cacheRatioMap)
|
jsonBytes, err := json.Marshal(cacheRatioMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling cache ratio: " + err.Error())
|
common.SysLog("error marshalling cache ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package ratio_setting
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ func GroupRatio2JSONString() string {
|
|||||||
|
|
||||||
jsonBytes, err := json.Marshal(groupRatio)
|
jsonBytes, err := json.Marshal(groupRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling model ratio: " + err.Error())
|
common.SysLog("error marshalling model ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
@@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 {
|
|||||||
|
|
||||||
ratio, ok := groupRatio[name]
|
ratio, ok := groupRatio[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.SysError("group ratio not found: " + name)
|
common.SysLog("group ratio not found: " + name)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return ratio
|
return ratio
|
||||||
@@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string {
|
|||||||
|
|
||||||
jsonBytes, err := json.Marshal(GroupGroupRatio)
|
jsonBytes, err := json.Marshal(GroupGroupRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling group-group ratio: " + err.Error())
|
common.SysLog("error marshalling group-group ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package setting
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/logger"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string {
|
|||||||
|
|
||||||
jsonBytes, err := json.Marshal(userUsableGroups)
|
jsonBytes, err := json.Marshal(userUsableGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling user groups: " + err.Error())
|
common.SysLog("error marshalling user groups: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,4 +12,7 @@ const (
|
|||||||
RelayFormatOpenAIRealtime = "openai_realtime"
|
RelayFormatOpenAIRealtime = "openai_realtime"
|
||||||
RelayFormatRerank = "rerank"
|
RelayFormatRerank = "rerank"
|
||||||
RelayFormatEmbedding = "embedding"
|
RelayFormatEmbedding = "embedding"
|
||||||
|
|
||||||
|
RelayFormatTask = "task"
|
||||||
|
RelayFormatMjProxy = "mj_proxy"
|
||||||
)
|
)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user