feat: multi-feature update
This commit is contained in:
@@ -34,6 +34,7 @@ type CustomOAuthProviderResponse struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
PKCEEnabled bool `json:"pkce_enabled"`
|
||||
AccessPolicy string `json:"access_policy"`
|
||||
AccessDeniedMessage string `json:"access_denied_message"`
|
||||
}
|
||||
@@ -64,6 +65,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
|
||||
EmailField: p.EmailField,
|
||||
WellKnown: p.WellKnown,
|
||||
AuthStyle: p.AuthStyle,
|
||||
PKCEEnabled: p.PKCEEnabled,
|
||||
AccessPolicy: p.AccessPolicy,
|
||||
AccessDeniedMessage: p.AccessDeniedMessage,
|
||||
}
|
||||
@@ -129,6 +131,7 @@ type CreateCustomOAuthProviderRequest struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
PKCEEnabled bool `json:"pkce_enabled"`
|
||||
AccessPolicy string `json:"access_policy"`
|
||||
AccessDeniedMessage string `json:"access_denied_message"`
|
||||
}
|
||||
@@ -247,6 +250,7 @@ func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
EmailField: req.EmailField,
|
||||
WellKnown: req.WellKnown,
|
||||
AuthStyle: req.AuthStyle,
|
||||
PKCEEnabled: req.PKCEEnabled,
|
||||
AccessPolicy: req.AccessPolicy,
|
||||
AccessDeniedMessage: req.AccessDeniedMessage,
|
||||
}
|
||||
@@ -284,6 +288,7 @@ type UpdateCustomOAuthProviderRequest struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
|
||||
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
|
||||
PKCEEnabled *bool `json:"pkce_enabled"` // Optional: if nil, keep existing
|
||||
AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
|
||||
AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
|
||||
}
|
||||
@@ -374,6 +379,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
if req.AuthStyle != nil {
|
||||
provider.AuthStyle = *req.AuthStyle
|
||||
}
|
||||
if req.PKCEEnabled != nil {
|
||||
provider.PKCEEnabled = *req.PKCEEnabled
|
||||
}
|
||||
if req.AccessPolicy != nil {
|
||||
provider.AccessPolicy = *req.AccessPolicy
|
||||
}
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const hhhlMisskeyHost = "https://dc.hhhl.cc"
|
||||
|
||||
// HHHLAuthorize adapts Misskey MiAuth to the OAuth authorization endpoint shape.
|
||||
func HHHLAuthorize(c *gin.Context) {
|
||||
redirectURI := strings.TrimSpace(c.Query("redirect_uri"))
|
||||
if redirectURI == "" {
|
||||
c.String(http.StatusBadRequest, "missing redirect_uri")
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
sessionID := uuid.NewString()
|
||||
callbackURL := fmt.Sprintf(
|
||||
"%s/api/hhhl/callback?r=%s&s=%s&sid=%s",
|
||||
strings.TrimRight(system_setting.ServerAddress, "/"),
|
||||
url.QueryEscape(redirectURI),
|
||||
url.QueryEscape(state),
|
||||
url.QueryEscape(sessionID),
|
||||
)
|
||||
|
||||
miAuthURL := fmt.Sprintf(
|
||||
"%s/miauth/%s?name=NewAPI%%E7%%99%%BB%%E5%%BD%%95&callback=%s&permission=read:account",
|
||||
hhhlMisskeyHost,
|
||||
url.PathEscape(sessionID),
|
||||
url.QueryEscape(callbackURL),
|
||||
)
|
||||
c.Redirect(http.StatusFound, miAuthURL)
|
||||
}
|
||||
|
||||
// HHHLCallback returns the MiAuth session id as an OAuth authorization code.
|
||||
// Wrap in pkce.{base64json} format so the generic OAuth provider forwards it correctly.
|
||||
func HHHLCallback(c *gin.Context) {
|
||||
redirectURI := strings.TrimSpace(c.Query("r"))
|
||||
sessionID := strings.TrimSpace(c.Query("sid"))
|
||||
if redirectURI == "" || sessionID == "" {
|
||||
c.String(http.StatusBadRequest, "invalid callback")
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(redirectURI)
|
||||
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
|
||||
c.String(http.StatusBadRequest, "invalid redirect_uri")
|
||||
return
|
||||
}
|
||||
codePayload, _ := common.Marshal(map[string]string{"token": sessionID})
|
||||
code := "pkce." + base64.RawURLEncoding.EncodeToString(codePayload)
|
||||
query := targetURL.Query()
|
||||
query.Set("code", code)
|
||||
query.Set("state", c.Query("s"))
|
||||
targetURL.RawQuery = query.Encode()
|
||||
c.Redirect(http.StatusFound, targetURL.String())
|
||||
}
|
||||
|
||||
// HHHLToken exchanges a MiAuth session id for a Misskey access token.
|
||||
func HHHLToken(c *gin.Context) {
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
if code == "" {
|
||||
if err := c.Request.ParseForm(); err == nil {
|
||||
code = strings.TrimSpace(c.Request.Form.Get("code"))
|
||||
}
|
||||
}
|
||||
if code == "" {
|
||||
var payload struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err == nil {
|
||||
code = strings.TrimSpace(payload.Code)
|
||||
}
|
||||
}
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "Missing code"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := code
|
||||
if strings.HasPrefix(code, "pkce.") {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(code[5:])
|
||||
if err == nil {
|
||||
var pkceData struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if jsonErr := common.Unmarshal(decoded, &pkceData); jsonErr == nil && pkceData.Token != "" {
|
||||
sessionID = pkceData.Token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := common.Marshal(gin.H{})
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("%s/api/miauth/%s/check", hhhlMisskeyHost, url.PathEscape(sessionID)),
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
|
||||
client := http.Client{Timeout: 20 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "MiAuth check request failed"})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var tokenData struct {
|
||||
OK bool `json:"ok"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := common.Unmarshal(respBody, &tokenData); err != nil || !tokenData.OK || tokenData.Token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "Failed to validate MiAuth session"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"access_token": tokenData.Token, "token_type": "Bearer"})
|
||||
}
|
||||
|
||||
// HHHLUserInfo adapts Misskey /api/i to an OIDC-like userinfo response.
|
||||
func HHHLUserInfo(c *gin.Context) {
|
||||
token := strings.TrimSpace(c.Query("access_token"))
|
||||
if token == "" {
|
||||
token = strings.TrimSpace(strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer "))
|
||||
}
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_request", "error_description": "Missing token"})
|
||||
return
|
||||
}
|
||||
|
||||
body, err := common.Marshal(gin.H{"i": token})
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, hhhlMisskeyHost+"/api/i", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
|
||||
client := http.Client{Timeout: 20 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token", "error_description": "Failed to fetch user info"})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var userData struct {
|
||||
Id string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := common.Unmarshal(respBody, &userData); err != nil || userData.Id == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
|
||||
return
|
||||
}
|
||||
if userData.Name == "" {
|
||||
userData.Name = userData.Username
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"sub": userData.Id,
|
||||
"preferred_username": userData.Username,
|
||||
"name": userData.Name,
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -18,6 +19,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"github.com/shirou/gopsutil/mem"
|
||||
)
|
||||
|
||||
func TestStatus(c *gin.Context) {
|
||||
@@ -144,6 +147,7 @@ func GetStatus(c *gin.Context) {
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
PKCEEnabled bool `json:"pkce_enabled"`
|
||||
}
|
||||
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
|
||||
for _, p := range customProviders {
|
||||
@@ -156,6 +160,7 @@ func GetStatus(c *gin.Context) {
|
||||
ClientId: config.ClientId,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint,
|
||||
Scopes: config.Scopes,
|
||||
PKCEEnabled: config.PKCEEnabled,
|
||||
})
|
||||
}
|
||||
data["custom_oauth_providers"] = providersInfo
|
||||
@@ -231,6 +236,49 @@ func GetHomePageContent(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomeStats(c *gin.Context) {
|
||||
var cpuUsage float64
|
||||
if percents, err := cpu.Percent(150*time.Millisecond, false); err == nil && len(percents) > 0 {
|
||||
cpuUsage = percents[0]
|
||||
} else {
|
||||
cpuUsage = common.GetSystemStatus().CPUUsage
|
||||
}
|
||||
|
||||
var memoryTotal uint64
|
||||
var memoryUsed uint64
|
||||
var memoryUsage float64
|
||||
if memInfo, err := mem.VirtualMemory(); err == nil {
|
||||
memoryTotal = memInfo.Total
|
||||
memoryUsed = memInfo.Used
|
||||
memoryUsage = memInfo.UsedPercent
|
||||
} else {
|
||||
memoryUsage = common.GetSystemStatus().MemoryUsage
|
||||
}
|
||||
|
||||
totalTokens, err := model.SumTotalConsumeTokens()
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), "failed to query home stats token usage: "+err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "查询首页统计失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"cpu_usage": cpuUsage,
|
||||
"memory_usage": memoryUsage,
|
||||
"memory_total": memoryTotal,
|
||||
"memory_used": memoryUsed,
|
||||
"total_tokens": totalTokens,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendEmailVerification(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
if err := common.Validate.Var(email, "required,email"); err != nil {
|
||||
|
||||
@@ -90,6 +90,10 @@ func HandleOAuth(c *gin.Context) {
|
||||
|
||||
// 5. Exchange code for token
|
||||
code := c.Query("code")
|
||||
// Pass PKCE code_verifier to context if present
|
||||
if codeVerifier := c.Query("code_verifier"); codeVerifier != "" {
|
||||
c.Set("pkce_code_verifier", codeVerifier)
|
||||
}
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
@@ -136,6 +140,10 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
|
||||
// Exchange code for token
|
||||
code := c.Query("code")
|
||||
// Pass PKCE code_verifier to context if present
|
||||
if codeVerifier := c.Query("code_verifier"); codeVerifier != "" {
|
||||
c.Set("pkce_code_verifier", codeVerifier)
|
||||
}
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
|
||||
Reference in New Issue
Block a user