feat: multi-feature update

This commit is contained in:
chaos
2026-06-15 06:16:16 +08:00
parent 6f415428d3
commit 04d30f9dd1
58 changed files with 4610 additions and 419 deletions
+8
View File
@@ -34,6 +34,7 @@ type CustomOAuthProviderResponse struct {
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
PKCEEnabled bool `json:"pkce_enabled"`
AccessPolicy string `json:"access_policy"`
AccessDeniedMessage string `json:"access_denied_message"`
}
@@ -64,6 +65,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
EmailField: p.EmailField,
WellKnown: p.WellKnown,
AuthStyle: p.AuthStyle,
PKCEEnabled: p.PKCEEnabled,
AccessPolicy: p.AccessPolicy,
AccessDeniedMessage: p.AccessDeniedMessage,
}
@@ -129,6 +131,7 @@ type CreateCustomOAuthProviderRequest struct {
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
PKCEEnabled bool `json:"pkce_enabled"`
AccessPolicy string `json:"access_policy"`
AccessDeniedMessage string `json:"access_denied_message"`
}
@@ -247,6 +250,7 @@ func CreateCustomOAuthProvider(c *gin.Context) {
EmailField: req.EmailField,
WellKnown: req.WellKnown,
AuthStyle: req.AuthStyle,
PKCEEnabled: req.PKCEEnabled,
AccessPolicy: req.AccessPolicy,
AccessDeniedMessage: req.AccessDeniedMessage,
}
@@ -284,6 +288,7 @@ type UpdateCustomOAuthProviderRequest struct {
EmailField string `json:"email_field"`
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
PKCEEnabled *bool `json:"pkce_enabled"` // Optional: if nil, keep existing
AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
}
@@ -374,6 +379,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.AuthStyle != nil {
provider.AuthStyle = *req.AuthStyle
}
if req.PKCEEnabled != nil {
provider.PKCEEnabled = *req.PKCEEnabled
}
if req.AccessPolicy != nil {
provider.AccessPolicy = *req.AccessPolicy
}
+206
View File
@@ -0,0 +1,206 @@
package controller
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const hhhlMisskeyHost = "https://dc.hhhl.cc"
// HHHLAuthorize adapts Misskey MiAuth to the OAuth authorization endpoint shape.
func HHHLAuthorize(c *gin.Context) {
redirectURI := strings.TrimSpace(c.Query("redirect_uri"))
if redirectURI == "" {
c.String(http.StatusBadRequest, "missing redirect_uri")
return
}
state := c.Query("state")
sessionID := uuid.NewString()
callbackURL := fmt.Sprintf(
"%s/api/hhhl/callback?r=%s&s=%s&sid=%s",
strings.TrimRight(system_setting.ServerAddress, "/"),
url.QueryEscape(redirectURI),
url.QueryEscape(state),
url.QueryEscape(sessionID),
)
miAuthURL := fmt.Sprintf(
"%s/miauth/%s?name=NewAPI%%E7%%99%%BB%%E5%%BD%%95&callback=%s&permission=read:account",
hhhlMisskeyHost,
url.PathEscape(sessionID),
url.QueryEscape(callbackURL),
)
c.Redirect(http.StatusFound, miAuthURL)
}
// HHHLCallback returns the MiAuth session id as an OAuth authorization code.
// Wrap in pkce.{base64json} format so the generic OAuth provider forwards it correctly.
func HHHLCallback(c *gin.Context) {
redirectURI := strings.TrimSpace(c.Query("r"))
sessionID := strings.TrimSpace(c.Query("sid"))
if redirectURI == "" || sessionID == "" {
c.String(http.StatusBadRequest, "invalid callback")
return
}
targetURL, err := url.Parse(redirectURI)
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
c.String(http.StatusBadRequest, "invalid redirect_uri")
return
}
codePayload, _ := common.Marshal(map[string]string{"token": sessionID})
code := "pkce." + base64.RawURLEncoding.EncodeToString(codePayload)
query := targetURL.Query()
query.Set("code", code)
query.Set("state", c.Query("s"))
targetURL.RawQuery = query.Encode()
c.Redirect(http.StatusFound, targetURL.String())
}
// HHHLToken exchanges a MiAuth session id for a Misskey access token.
func HHHLToken(c *gin.Context) {
code := strings.TrimSpace(c.Query("code"))
if code == "" {
if err := c.Request.ParseForm(); err == nil {
code = strings.TrimSpace(c.Request.Form.Get("code"))
}
}
if code == "" {
var payload struct {
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&payload); err == nil {
code = strings.TrimSpace(payload.Code)
}
}
if code == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "Missing code"})
return
}
sessionID := code
if strings.HasPrefix(code, "pkce.") {
decoded, err := base64.RawURLEncoding.DecodeString(code[5:])
if err == nil {
var pkceData struct {
Token string `json:"token"`
}
if jsonErr := common.Unmarshal(decoded, &pkceData); jsonErr == nil && pkceData.Token != "" {
sessionID = pkceData.Token
}
}
}
body, err := common.Marshal(gin.H{})
if err != nil {
common.ApiError(c, err)
return
}
req, err := http.NewRequestWithContext(
c.Request.Context(),
http.MethodPost,
fmt.Sprintf("%s/api/miauth/%s/check", hhhlMisskeyHost, url.PathEscape(sessionID)),
bytes.NewReader(body),
)
if err != nil {
common.ApiError(c, err)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
client := http.Client{Timeout: 20 * time.Second}
resp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "MiAuth check request failed"})
return
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
common.ApiError(c, err)
return
}
var tokenData struct {
OK bool `json:"ok"`
Token string `json:"token"`
}
if err := common.Unmarshal(respBody, &tokenData); err != nil || !tokenData.OK || tokenData.Token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "Failed to validate MiAuth session"})
return
}
c.JSON(http.StatusOK, gin.H{"access_token": tokenData.Token, "token_type": "Bearer"})
}
// HHHLUserInfo adapts Misskey /api/i to an OIDC-like userinfo response.
func HHHLUserInfo(c *gin.Context) {
token := strings.TrimSpace(c.Query("access_token"))
if token == "" {
token = strings.TrimSpace(strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer "))
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_request", "error_description": "Missing token"})
return
}
body, err := common.Marshal(gin.H{"i": token})
if err != nil {
common.ApiError(c, err)
return
}
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, hhhlMisskeyHost+"/api/i", bytes.NewReader(body))
if err != nil {
common.ApiError(c, err)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
client := http.Client{Timeout: 20 * time.Second}
resp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token", "error_description": "Failed to fetch user info"})
return
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
common.ApiError(c, err)
return
}
var userData struct {
Id string `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
}
if err := common.Unmarshal(respBody, &userData); err != nil || userData.Id == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
return
}
if userData.Name == "" {
userData.Name = userData.Username
}
c.JSON(http.StatusOK, gin.H{
"sub": userData.Id,
"preferred_username": userData.Username,
"name": userData.Name,
})
}
+48
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
@@ -18,6 +19,8 @@ import (
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
)
func TestStatus(c *gin.Context) {
@@ -144,6 +147,7 @@ func GetStatus(c *gin.Context) {
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
Scopes string `json:"scopes"`
PKCEEnabled bool `json:"pkce_enabled"`
}
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
for _, p := range customProviders {
@@ -156,6 +160,7 @@ func GetStatus(c *gin.Context) {
ClientId: config.ClientId,
AuthorizationEndpoint: config.AuthorizationEndpoint,
Scopes: config.Scopes,
PKCEEnabled: config.PKCEEnabled,
})
}
data["custom_oauth_providers"] = providersInfo
@@ -231,6 +236,49 @@ func GetHomePageContent(c *gin.Context) {
return
}
func GetHomeStats(c *gin.Context) {
var cpuUsage float64
if percents, err := cpu.Percent(150*time.Millisecond, false); err == nil && len(percents) > 0 {
cpuUsage = percents[0]
} else {
cpuUsage = common.GetSystemStatus().CPUUsage
}
var memoryTotal uint64
var memoryUsed uint64
var memoryUsage float64
if memInfo, err := mem.VirtualMemory(); err == nil {
memoryTotal = memInfo.Total
memoryUsed = memInfo.Used
memoryUsage = memInfo.UsedPercent
} else {
memoryUsage = common.GetSystemStatus().MemoryUsage
}
totalTokens, err := model.SumTotalConsumeTokens()
if err != nil {
logger.LogError(c.Request.Context(), "failed to query home stats token usage: "+err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "查询首页统计失败",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"cpu_usage": cpuUsage,
"memory_usage": memoryUsage,
"memory_total": memoryTotal,
"memory_used": memoryUsed,
"total_tokens": totalTokens,
},
})
return
}
func SendEmailVerification(c *gin.Context) {
email := c.Query("email")
if err := common.Validate.Var(email, "required,email"); err != nil {
+8
View File
@@ -90,6 +90,10 @@ func HandleOAuth(c *gin.Context) {
// 5. Exchange code for token
code := c.Query("code")
// Pass PKCE code_verifier to context if present
if codeVerifier := c.Query("code_verifier"); codeVerifier != "" {
c.Set("pkce_code_verifier", codeVerifier)
}
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
if err != nil {
handleOAuthError(c, err)
@@ -136,6 +140,10 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
// Exchange code for token
code := c.Query("code")
// Pass PKCE code_verifier to context if present
if codeVerifier := c.Query("code_verifier"); codeVerifier != "" {
c.Set("pkce_code_verifier", codeVerifier)
}
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
if err != nil {
handleOAuthError(c, err)