fix: reuse stream scanner buffer in channel handlers (#5225)
This commit is contained in:
@@ -30,7 +30,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := helper.NewStreamScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package cohere
|
package cohere
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -86,7 +85,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
responseText := ""
|
responseText := ""
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := helper.NewStreamScanner(resp.Body)
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
if atEOF && len(data) == 0 {
|
if atEOF && len(data) == 0 {
|
||||||
return 0, nil, nil
|
return 0, nil, nil
|
||||||
@@ -106,6 +105,9 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
dataChan <- data
|
dataChan <- data
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
common.SysLog("error reading stream: " + err.Error())
|
||||||
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := helper.NewStreamScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
id := helper.GetResponseID(c)
|
id := helper.GetResponseID(c)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -12,6 +11,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/relay/helper"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
@@ -397,7 +397,7 @@ func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback f
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 读取流式响应
|
// 读取流式响应
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := helper.NewStreamScanner(response.Body)
|
||||||
successful := false
|
successful := false
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -70,7 +69,7 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
defer service.CloseResponseBodyGracefully(resp)
|
defer service.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := helper.NewStreamScanner(resp.Body)
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
var model = info.UpstreamModelName
|
var model = info.UpstreamModelName
|
||||||
var responseId = common.GetUUID()
|
var responseId = common.GetUUID()
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
|||||||
|
|
||||||
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var responseText string
|
var responseText string
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := helper.NewStreamScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
|
|||||||
|
|
||||||
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var usage *dto.Usage
|
var usage *dto.Usage
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := helper.NewStreamScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
dataChan := make(chan string)
|
dataChan := make(chan string)
|
||||||
metaChan := make(chan string)
|
metaChan := make(chan string)
|
||||||
@@ -180,6 +180,9 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
common.SysLog("error reading stream: " + err.Error())
|
||||||
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|||||||
@@ -34,6 +34,12 @@ func getScannerBufferSize() int {
|
|||||||
return DefaultMaxScannerBufferSize
|
return DefaultMaxScannerBufferSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewStreamScanner(reader io.Reader) *bufio.Scanner {
|
||||||
|
scanner := bufio.NewScanner(reader)
|
||||||
|
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
|
||||||
|
return scanner
|
||||||
|
}
|
||||||
|
|
||||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
|
||||||
|
|
||||||
if resp == nil || dataHandler == nil {
|
if resp == nil || dataHandler == nil {
|
||||||
@@ -54,7 +60,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
|
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
|
||||||
scanner = bufio.NewScanner(resp.Body)
|
scanner = NewStreamScanner(resp.Body)
|
||||||
ticker = time.NewTicker(streamingTimeout)
|
ticker = time.NewTicker(streamingTimeout)
|
||||||
pingTicker *time.Ticker
|
pingTicker *time.Ticker
|
||||||
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
||||||
@@ -104,7 +110,6 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
close(stopChan)
|
close(stopChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
|
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
SetEventStreamHeaders(c)
|
SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package helper
|
package helper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -81,6 +82,22 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
|||||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewStreamScanner_AllowsLargeStreamLine(t *testing.T) {
|
||||||
|
oldBufferMB := constant.StreamScannerMaxBufferMB
|
||||||
|
constant.StreamScannerMaxBufferMB = 1
|
||||||
|
t.Cleanup(func() {
|
||||||
|
constant.StreamScannerMaxBufferMB = oldBufferMB
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := strings.Repeat("x", 128<<10)
|
||||||
|
scanner := NewStreamScanner(strings.NewReader("data: " + payload + "\n"))
|
||||||
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
|
require.True(t, scanner.Scan())
|
||||||
|
assert.Equal(t, "data: "+payload, scanner.Text())
|
||||||
|
require.NoError(t, scanner.Err())
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user