feat: track upstream request ID and prevent response header override
When proxying through another new-api instance, the upstream X-Oneapi-Request-Id was overwriting the local one in client responses. This adds a new `upstream_request_id` field to the logs table, captures the upstream ID during relay, and filters it from being copied back to the client. Frontend gains search/filter and detail display support.
This commit is contained in:
+19
-8
@@ -35,8 +35,9 @@ type Log struct {
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
|
||||
Other string `json:"other"`
|
||||
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
|
||||
UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
// don't use iota, avoid change log type value
|
||||
@@ -147,6 +148,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
username := c.GetString("username")
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
@@ -177,8 +179,9 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestId: requestId,
|
||||
Other: otherStr,
|
||||
RequestId: requestId,
|
||||
UpstreamRequestId: upstreamRequestId,
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -208,6 +211,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
|
||||
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
username := c.GetString("username")
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
upstreamRequestId := c.GetString(common.UpstreamRequestIdKey)
|
||||
otherStr := common.MapToJsonStr(params.Other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
@@ -238,8 +242,9 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestId: requestId,
|
||||
Other: otherStr,
|
||||
RequestId: requestId,
|
||||
UpstreamRequestId: upstreamRequestId,
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -295,7 +300,7 @@ func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string, upstreamRequestId string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB
|
||||
@@ -315,6 +320,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
}
|
||||
if upstreamRequestId != "" {
|
||||
tx = tx.Where("logs.upstream_request_id = ?", upstreamRequestId)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("logs.created_at >= ?", startTimestamp)
|
||||
}
|
||||
@@ -381,7 +389,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
|
||||
const logSearchCountLimit = 10000
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string, upstreamRequestId string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB.Where("logs.user_id = ?", userId)
|
||||
@@ -402,6 +410,9 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
if requestId != "" {
|
||||
tx = tx.Where("logs.request_id = ?", requestId)
|
||||
}
|
||||
if upstreamRequestId != "" {
|
||||
tx = tx.Where("logs.upstream_request_id = ?", upstreamRequestId)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("logs.created_at >= ?", startTimestamp)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user