Compare commits
19 Commits
1e6f31b235
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 346cf0e4a6 | |||
| 04d30f9dd1 | |||
| 6f415428d3 | |||
| 59a93cf5c7 | |||
| 867d8acfc3 | |||
| 30d3a3a5f7 | |||
| d2576ddcd3 | |||
| 4ca47ee236 | |||
| 16dd7237c0 | |||
| 1915344838 | |||
| 15ff8e0268 | |||
| a1c82841b5 | |||
| 2eaa943d9f | |||
| 7a5348caa3 | |||
| f5753a2b31 | |||
| e8c36762fd | |||
| e2dbd02cbb | |||
| c8d3768087 | |||
| 979aeceb5c |
@@ -10,8 +10,10 @@ build
|
||||
logs
|
||||
web/default/dist
|
||||
web/classic/dist
|
||||
web/image-gen/dist
|
||||
web/node_modules
|
||||
web/dist
|
||||
electron/dist
|
||||
.env
|
||||
one-api
|
||||
new-api
|
||||
|
||||
@@ -2,136 +2,186 @@
|
||||
|
||||
## Overview
|
||||
|
||||
This is an AI API gateway/proxy built with Go. It aggregates 40+ upstream AI providers (OpenAI, Claude, Gemini, Azure, AWS Bedrock, etc.) behind a unified API, with user management, billing, rate limiting, and an admin dashboard.
|
||||
AI API gateway/proxy (Go) aggregating 40+ upstream AI providers behind a unified API, with user management, billing, rate limiting, and a React admin dashboard.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Backend**: Go 1.22+, Gin web framework, GORM v2 ORM
|
||||
- **Frontend**: React 19, TypeScript, Rsbuild, Base UI, Tailwind CSS
|
||||
- **Databases**: SQLite, MySQL, PostgreSQL (all three must be supported)
|
||||
- **Cache**: Redis (go-redis) + in-memory cache
|
||||
- **Auth**: JWT, WebAuthn/Passkeys, OAuth (GitHub, Discord, OIDC, etc.)
|
||||
- **Frontend package manager**: Bun (preferred over npm/yarn/pnpm)
|
||||
- **Backend**: Go 1.25+, Gin, GORM v2, testify
|
||||
- **Frontend**: Two themes — `web/default/` (React 19, Rsbuild, Base UI, Tailwind CSS 4, TanStack Router) and `web/classic/` (React 18, Vite, Semi Design). Default is the primary.
|
||||
- **Databases**: SQLite, MySQL >= 5.7.8, PostgreSQL >= 9.6 (all three supported simultaneously)
|
||||
- **Cache**: Redis (go-redis) + in-memory
|
||||
- **Auth**: JWT, WebAuthn/Passkeys, OAuth (GitHub, Discord, OIDC)
|
||||
- **Desktop**: Electron app at `electron/`
|
||||
|
||||
## Architecture
|
||||
|
||||
Layered architecture: Router -> Controller -> Service -> Model
|
||||
Layered: `router/` → `controller/` → `service/` → `model/`
|
||||
|
||||
```
|
||||
router/ — HTTP routing (API, relay, dashboard, web)
|
||||
controller/ — Request handlers
|
||||
service/ — Business logic
|
||||
model/ — Data models and DB access (GORM)
|
||||
relay/ — AI API relay/proxy with provider adapters
|
||||
relay/channel/ — Provider-specific adapters (openai/, claude/, gemini/, aws/, etc.)
|
||||
middleware/ — Auth, rate limiting, CORS, logging, distribution
|
||||
setting/ — Configuration management (ratio, model, operation, system, performance)
|
||||
common/ — Shared utilities (JSON, crypto, Redis, env, rate-limit, etc.)
|
||||
dto/ — Data transfer objects (request/response structs)
|
||||
constant/ — Constants (API types, channel types, context keys)
|
||||
types/ — Type definitions (relay formats, file sources, errors)
|
||||
i18n/ — Backend internationalization (go-i18n, en/zh)
|
||||
oauth/ — OAuth provider implementations
|
||||
pkg/ — Internal packages (cachex, ionet)
|
||||
web/ — Frontend themes container
|
||||
web/default/ — Default frontend (React 19, Rsbuild, Base UI, Tailwind)
|
||||
web/classic/ — Classic frontend (React 18, Vite, Semi Design)
|
||||
web/default/src/i18n/ — Frontend internationalization (i18next, zh/en/fr/ru/ja/vi)
|
||||
router/ — HTTP routing (api, relay, dashboard, web)
|
||||
controller/ — Request handlers
|
||||
service/ — Business logic
|
||||
model/ — Data models and DB access (GORM), auto-migrations
|
||||
relay/ — AI relay/proxy with 40+ provider adapters in relay/channel/
|
||||
middleware/ — Auth, rate limiting, CORS, logging, distribution
|
||||
setting/ — Config management (ratio, model, operation, system, performance)
|
||||
common/ — Shared utilities (JSON, crypto, Redis, env, rate-limit, etc.)
|
||||
dto/ — Request/response DTOs
|
||||
constant/ — API types, channel types, context keys
|
||||
types/ — Relay format types, file sources, errors
|
||||
i18n/ — Backend i18n (go-i18n, 3 locales: en, zh-CN, zh-TW)
|
||||
oauth/ — OAuth provider implementations
|
||||
pkg/ — Internal packages: cachex, ionet, billingexpr, perf_metrics
|
||||
web/ — Frontend themes: web/default/, web/classic/
|
||||
```
|
||||
|
||||
## Internationalization (i18n)
|
||||
## Key Conventions
|
||||
|
||||
### Backend (`i18n/`)
|
||||
- Library: `nicksnyder/go-i18n/v2`
|
||||
- Languages: en, zh
|
||||
### 1. JSON — Use `common/json.go` wrappers
|
||||
|
||||
### Frontend (`web/default/src/i18n/`)
|
||||
- Library: `i18next` + `react-i18next` + `i18next-browser-languagedetector`
|
||||
- Languages: en (base), zh (fallback), fr, ru, ja, vi
|
||||
- Translation files: `web/default/src/i18n/locales/{lang}.json` — flat JSON, keys are English source strings
|
||||
- Usage: `useTranslation()` hook, call `t('English key')` in components
|
||||
- CLI tools: `bun run i18n:sync` (from `web/default/`)
|
||||
All marshal/unmarshal MUST use `common.Marshal`, `common.Unmarshal`, etc. Do NOT call `encoding/json` directly in business code. Type definitions from `encoding/json` (e.g. `json.RawMessage`) are still fine to reference.
|
||||
|
||||
## Rules
|
||||
### 2. Cross-DB Compatibility (SQLite, MySQL, PostgreSQL)
|
||||
|
||||
### Rule 1: JSON Package — Use `common/json.go`
|
||||
- Prefer GORM methods over raw SQL.
|
||||
- Use `commonGroupCol`, `commonKeyCol`, `commonTrueVal`, `commonFalseVal` from `model/main.go` for reserved words and boolean literals.
|
||||
- Branch DB-specific logic with `common.UsingPostgreSQL`, `common.UsingSQLite`, `common.UsingMySQL`.
|
||||
- Forbidden without cross-DB fallback: MySQL-only `GROUP_CONCAT`, PostgreSQL `@>`/`JSONB` operators, `ALTER COLUMN` on SQLite, DB-specific column types (use `TEXT` for JSON).
|
||||
- Migrations must pass on all three DBs.
|
||||
|
||||
All JSON marshal/unmarshal operations MUST use the wrapper functions in `common/json.go`:
|
||||
### 3. Frontend — Bun required
|
||||
|
||||
- `common.Marshal(v any) ([]byte, error)`
|
||||
- `common.Unmarshal(data []byte, v any) error`
|
||||
- `common.UnmarshalJsonStr(data string, v any) error`
|
||||
- `common.DecodeJson(reader io.Reader, v any) error`
|
||||
- `common.GetJsonType(data json.RawMessage) string`
|
||||
Use `bun` (not npm/yarn/pnpm) for `web/default/`:
|
||||
- `bun install` / `bun run dev` / `bun run build`
|
||||
- See `web/default/AGENTS.md` for detailed frontend conventions (i18n, components, forms, routing, etc.)
|
||||
|
||||
Do NOT directly import or call `encoding/json` in business code. These wrappers exist for consistency and future extensibility (e.g., swapping to a faster JSON library).
|
||||
### 4. New Channel — StreamOptions
|
||||
|
||||
Note: `json.RawMessage`, `json.Number`, and other type definitions from `encoding/json` may still be referenced as types, but actual marshal/unmarshal calls must go through `common.*`.
|
||||
When adding a channel, check if the provider supports `StreamOptions`. If so, add the channel type to `streamSupportedChannels` in `relay/common/relay_info.go:320`.
|
||||
|
||||
### Rule 2: Database Compatibility — SQLite, MySQL >= 5.7.8, PostgreSQL >= 9.6
|
||||
### 5. Protected Identity — DO NOT Modify
|
||||
|
||||
All database code MUST be fully compatible with all three databases simultaneously.
|
||||
Do NOT remove, rename, or replace any reference to **new-api** (project) or **QuantumNous** (organization). This includes README, HTML titles, Go module paths, Docker images, package metadata, comments, and deployment configs.
|
||||
|
||||
**Use GORM abstractions:**
|
||||
- Prefer GORM methods (`Create`, `Find`, `Where`, `Updates`, etc.) over raw SQL.
|
||||
- Let GORM handle primary key generation — do not use `AUTO_INCREMENT` or `SERIAL` directly.
|
||||
### 6. Upstream DTOs — Pointer Types for Zero Values
|
||||
|
||||
**When raw SQL is unavoidable:**
|
||||
- Column quoting differs: PostgreSQL uses `"column"`, MySQL/SQLite uses `` `column` ``.
|
||||
- Use `commonGroupCol`, `commonKeyCol` variables from `model/main.go` for reserved-word columns like `group` and `key`.
|
||||
- Boolean values differ: PostgreSQL uses `true`/`false`, MySQL/SQLite uses `1`/`0`. Use `commonTrueVal`/`commonFalseVal`.
|
||||
- Use `common.UsingPostgreSQL`, `common.UsingSQLite`, `common.UsingMySQL` flags to branch DB-specific logic.
|
||||
Optional scalar fields in request structs that are parsed from client JSON and re-marshaled to upstream providers MUST use pointer types (`*int`, `*float64`, `*bool`, etc.) with `omitempty`. Non-`nil` pointer = preserve zero value. Non-pointer scalars with `omitempty` silently drop zeros.
|
||||
|
||||
**Forbidden without cross-DB fallback:**
|
||||
- MySQL-only functions (e.g., `GROUP_CONCAT` without PostgreSQL `STRING_AGG` equivalent)
|
||||
- PostgreSQL-only operators (e.g., `@>`, `?`, `JSONB` operators)
|
||||
- `ALTER COLUMN` in SQLite (unsupported — use column-add workaround)
|
||||
- Database-specific column types without fallback — use `TEXT` instead of `JSONB` for JSON storage
|
||||
### 7. Billing Expressions — Read `pkg/billingexpr/expr.md`
|
||||
|
||||
**Migrations:**
|
||||
- Ensure all migrations work on all three databases.
|
||||
- For SQLite, use `ALTER TABLE ... ADD COLUMN` instead of `ALTER COLUMN` (see `model/main.go` for patterns).
|
||||
When working on tiered/dynamic billing, read `pkg/billingexpr/expr.md` first. It documents the expression language, system architecture, token normalization, and settlement rules.
|
||||
|
||||
### Rule 3: Frontend — Prefer Bun
|
||||
## Development
|
||||
|
||||
Use `bun` as the preferred package manager and script runner for the frontend (`web/default/` directory):
|
||||
- `bun install` for dependency installation
|
||||
- `bun run dev` for development server
|
||||
- `bun run build` for production build
|
||||
- `bun run i18n:*` for i18n tooling
|
||||
### Run Backend
|
||||
```sh
|
||||
cp .env.example .env # edit as needed
|
||||
go run main.go # starts on :3000
|
||||
```
|
||||
|
||||
### Rule 4: New Channel StreamOptions Support
|
||||
### Build (matches CI/Docker)
|
||||
The VERSION file is created by CI (from git tag). For a manual build:
|
||||
```sh
|
||||
echo "dev" > VERSION
|
||||
(cd web/default && bun install && DISABLE_ESLINT_PLUGIN='true' bun run build)
|
||||
go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
```
|
||||
|
||||
When implementing a new channel:
|
||||
- Confirm whether the provider supports `StreamOptions`.
|
||||
- If supported, add the channel to `streamSupportedChannels`.
|
||||
### Run Backend Tests
|
||||
```sh
|
||||
go test ./... # all packages
|
||||
go test ./pkg/billingexpr/... # single package
|
||||
```
|
||||
|
||||
### Rule 5: Protected Project Information — DO NOT Modify or Delete
|
||||
### Run Frontend Dev
|
||||
```sh
|
||||
cd web/default && bun install && bun run dev
|
||||
# proxies /api, /mj, /pg -> http://localhost:3000 (configurable via VITE_REACT_APP_SERVER_URL)
|
||||
```
|
||||
|
||||
The following project-related information is **strictly protected** and MUST NOT be modified, deleted, replaced, or removed under any circumstances:
|
||||
### Other Scripts
|
||||
```sh
|
||||
cd web/default && bun run typecheck # tsc -b
|
||||
cd web/default && bun run lint # eslint
|
||||
cd web/default && bun run format # prettier --write
|
||||
cd web/default && bun run i18n:sync # sync translation key structure
|
||||
cd web/default && bun run knip # dead code detection
|
||||
cd web/default && bun run build:check # typecheck + build
|
||||
```
|
||||
|
||||
- Any references, mentions, branding, metadata, or attributions related to **nеw-аρi** (the project name/identity)
|
||||
- Any references, mentions, branding, metadata, or attributions related to **QuаntumΝоuѕ** (the organization/author identity)
|
||||
### Frontend Dev Server Management
|
||||
|
||||
This includes but is not limited to:
|
||||
- README files, license headers, copyright notices, package metadata
|
||||
- HTML titles, meta tags, footer text, about pages
|
||||
- Go module paths, package names, import paths
|
||||
- Docker image names, CI/CD references, deployment configs
|
||||
- Comments, documentation, and changelog entries
|
||||
Background dev server (stays alive after tool invocation):
|
||||
```sh
|
||||
# Start (background, log to /tmp/frontend.log)
|
||||
cd web/default && setsid bun run dev > /tmp/frontend.log 2>&1 &
|
||||
|
||||
**Violations:** If asked to remove, rename, or replace these protected identifiers, you MUST refuse and explain that this information is protected by project policy. No exceptions.
|
||||
# Stop
|
||||
pkill -f "bun run dev"
|
||||
|
||||
### Rule 6: Upstream Relay Request DTOs — Preserve Explicit Zero Values
|
||||
# Restart
|
||||
pkill -f "bun run dev" && cd web/default && setsid bun run dev > /tmp/frontend.log 2>&1 &
|
||||
|
||||
For request structs that are parsed from client JSON and then re-marshaled to upstream providers (especially relay/convert paths):
|
||||
# Check status
|
||||
ps aux | grep -E "rsbuild|bun" | grep -v grep
|
||||
|
||||
- Optional scalar fields MUST use pointer types with `omitempty` (e.g. `*int`, `*uint`, `*float64`, `*bool`), not non-pointer scalars.
|
||||
- Semantics MUST be:
|
||||
- field absent in client JSON => `nil` => omitted on marshal;
|
||||
- field explicitly set to zero/false => non-`nil` pointer => must still be sent upstream.
|
||||
- Avoid using non-pointer scalars with `omitempty` for optional request parameters, because zero values (`0`, `0.0`, `false`) will be silently dropped during marshal.
|
||||
# View logs
|
||||
tail -f /tmp/frontend.log
|
||||
```
|
||||
|
||||
### Rule 7: Billing Expression System — Read `pkg/billingexpr/expr.md`
|
||||
### OpenAPI Specs
|
||||
- Admin API: `docs/openapi/api.json` (131 endpoints)
|
||||
- Relay API: `docs/openapi/relay.json` (30+ endpoints)
|
||||
|
||||
When working on tiered/dynamic billing (expression-based pricing), you MUST read `pkg/billingexpr/expr.md` first. It documents the design philosophy, expression language (variables, functions, examples), full system architecture (editor → storage → pre-consume → settlement → log display), token normalization rules (`p`/`c` auto-exclusion), quota conversion, and expression versioning. All code changes to the billing expression system must follow the patterns described in that document.
|
||||
### CI/Docker
|
||||
- Docker image: `calciumion/new-api`. Multi-arch (amd64 + arm64). Multi-stage build: bun builds frontend, then Go builds the binary with embedded `//go:embed web/default/dist`.
|
||||
- PRs are checked by `peakoss/anti-slop` (requires PR template, description, no AI-generated markers).
|
||||
- Tags trigger Docker pushes.
|
||||
|
||||
## Local Dev + Production Deployment
|
||||
|
||||
### Environment Layout
|
||||
|
||||
| Item | Dev (chaos user) | Prod (www user) |
|
||||
|------|------------------|-----------------|
|
||||
| Directory | `/home/chaos/new-api/` | `/home/www/new-api-prod/` |
|
||||
| Port | `localhost:3000` (API) / `localhost:5173` (frontend hot-reload) | `localhost:3001` |
|
||||
| Config | `docker-compose.dev.yml` | `/home/www/new-api-prod/docker-compose.prod.yml` |
|
||||
| Data | Docker volume `dev_data` | `/home/www/new-api-prod/data/` |
|
||||
|
||||
### Git Remotes
|
||||
|
||||
- `origin` → `https://git.nomsg.cn/chaos/chaos-api.git` (fork, push here)
|
||||
- `upstream` → `https://github.com/QuantumNous/new-api.git` (official, pull from here)
|
||||
|
||||
### Daily Workflow
|
||||
|
||||
```sh
|
||||
# Sync upstream (auto-triggers deploy via post-merge hook)
|
||||
git fetch upstream && git merge upstream/main
|
||||
|
||||
# Manual deploy
|
||||
./deploy.sh
|
||||
```
|
||||
|
||||
### Deploy Script (`deploy.sh`)
|
||||
|
||||
Located at project root. Does three things:
|
||||
1. Builds frontend (`web/default/`) with bun
|
||||
2. Builds Docker image `my-new-api:latest`
|
||||
3. Restarts production containers as `www` user via `sudo -u www docker compose ...`
|
||||
|
||||
### Git Hook
|
||||
|
||||
`.git/hooks/post-merge` automatically runs `./deploy.sh` after `git merge`.
|
||||
|
||||
### Permission Isolation
|
||||
|
||||
- **chaos**: owns code, builds images, triggers deploy via sudoers whitelist
|
||||
- **www**: owns production data dir, runs production containers
|
||||
- sudoers rule: `/etc/sudoers.d/chaos-deploy` — chaos can only run `docker compose -f /home/www/new-api-prod/docker-compose.prod.yml *` as www
|
||||
|
||||
### Production Secrets
|
||||
|
||||
Stored in `/home/www/new-api-prod/.env` (owned by www, mode 600). Contains:
|
||||
- `DB_PASS` — PostgreSQL password
|
||||
- `REDIS_PASS` — Redis password
|
||||
- `SESSION_SECRET` — multi-node session secret
|
||||
|
||||
+12
-1
@@ -20,8 +20,18 @@ COPY ./web/classic ./classic
|
||||
COPY ./VERSION /build/VERSION
|
||||
RUN cd classic && VITE_REACT_APP_VERSION=$(cat /build/VERSION) bun run build
|
||||
|
||||
# image-gen: a small Vue 3 + Vite SPA that lives in web/image-gen/.
|
||||
# It uses npm (its own package-lock.json), so we use node:20 instead of bun.
|
||||
FROM node:20-alpine@sha256:49f3aca83b15186f1b7b8b21b06789a73ed1a4f9c4f1a0e3ce4a1ae9e5c8e3f5b AS builder-image-gen
|
||||
|
||||
WORKDIR /build/web/image-gen
|
||||
COPY web/image-gen/package.json web/image-gen/package-lock.json ./
|
||||
RUN npm ci --no-audit --no-fund
|
||||
COPY web/image-gen ./
|
||||
RUN npm run build
|
||||
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
ENV GO111MODULE=on CGO_ENABLED=0 GOPROXY=https://goproxy.cn,direct
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@@ -36,6 +46,7 @@ RUN go mod download
|
||||
COPY . .
|
||||
COPY --from=builder /build/web/default/dist ./web/default/dist
|
||||
COPY --from=builder-classic /build/web/classic/dist ./web/classic/dist
|
||||
COPY --from=builder-image-gen /build/web/image-gen/dist ./web/image-gen/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
+3
-2
@@ -16,9 +16,10 @@ RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p web/default/dist web/classic/dist && \
|
||||
RUN mkdir -p web/default/dist web/classic/dist web/image-gen/dist && \
|
||||
echo '<!doctype html><html><head><title>dev</title></head><body>use frontend dev server</body></html>' > web/default/dist/index.html && \
|
||||
echo '<!doctype html><html><head><title>dev</title></head><body>use frontend dev server</body></html>' > web/classic/dist/index.html
|
||||
echo '<!doctype html><html><head><title>dev</title></head><body>use frontend dev server</body></html>' > web/classic/dist/index.html && \
|
||||
echo '<!doctype html><html><head><title>dev</title></head><body>use frontend dev server</body></html>' > web/image-gen/dist/index.html
|
||||
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
|
||||
@@ -173,6 +173,11 @@ var RelayTimeout int // unit is second
|
||||
var RelayIdleConnTimeout int // unit is second
|
||||
var RelayMaxIdleConns int
|
||||
var RelayMaxIdleConnsPerHost int
|
||||
var RelayResponseHeaderTimeout int // unit is second
|
||||
var RelayTLSHandshakeTimeout int // unit is second
|
||||
var RelayExpectContinueTimeout int // unit is second
|
||||
var RelayForceIPv4 bool
|
||||
var RelayDisableHTTP2 bool
|
||||
|
||||
var GeminiSafetySetting string
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/static"
|
||||
)
|
||||
@@ -16,7 +17,17 @@ type embedFileSystem struct {
|
||||
}
|
||||
|
||||
func (e *embedFileSystem) Exists(prefix string, path string) bool {
|
||||
_, err := e.Open(path)
|
||||
// gin-contrib/static passes the raw URL path (e.g. "/image-gen/assets/x.js")
|
||||
// together with the URL prefix we registered (e.g. "/image-gen"). The
|
||||
// underlying fs.Sub FS only knows about the sub-tree (no prefix), so we
|
||||
// must strip the prefix before asking it whether the file exists. An
|
||||
// empty prefix means "served at /" — nothing to strip.
|
||||
p := strings.TrimPrefix(path, prefix)
|
||||
if p == path {
|
||||
// prefix didn't match — definitely not in this FS
|
||||
return false
|
||||
}
|
||||
_, err := e.Open(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
+7
-2
@@ -105,6 +105,11 @@ func InitEnv() {
|
||||
RelayIdleConnTimeout = GetEnvOrDefault("RELAY_IDLE_CONN_TIMEOUT", 90)
|
||||
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
|
||||
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
|
||||
RelayResponseHeaderTimeout = GetEnvOrDefault("RELAY_RESPONSE_HEADER_TIMEOUT", 60)
|
||||
RelayTLSHandshakeTimeout = GetEnvOrDefault("RELAY_TLS_HANDSHAKE_TIMEOUT", 10)
|
||||
RelayExpectContinueTimeout = GetEnvOrDefault("RELAY_EXPECT_CONTINUE_TIMEOUT", 1)
|
||||
RelayForceIPv4 = GetEnvOrDefaultBool("RELAY_FORCE_IPV4", false)
|
||||
RelayDisableHTTP2 = GetEnvOrDefaultBool("RELAY_DISABLE_HTTP2", false)
|
||||
|
||||
// Initialize string variables with GetEnvOrDefaultString
|
||||
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
@@ -112,11 +117,11 @@ func InitEnv() {
|
||||
|
||||
// Initialize rate limit variables
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 360)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 120)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
|
||||
|
||||
@@ -34,6 +34,7 @@ type CustomOAuthProviderResponse struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
PKCEEnabled bool `json:"pkce_enabled"`
|
||||
AccessPolicy string `json:"access_policy"`
|
||||
AccessDeniedMessage string `json:"access_denied_message"`
|
||||
}
|
||||
@@ -64,6 +65,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
|
||||
EmailField: p.EmailField,
|
||||
WellKnown: p.WellKnown,
|
||||
AuthStyle: p.AuthStyle,
|
||||
PKCEEnabled: p.PKCEEnabled,
|
||||
AccessPolicy: p.AccessPolicy,
|
||||
AccessDeniedMessage: p.AccessDeniedMessage,
|
||||
}
|
||||
@@ -129,6 +131,7 @@ type CreateCustomOAuthProviderRequest struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
PKCEEnabled bool `json:"pkce_enabled"`
|
||||
AccessPolicy string `json:"access_policy"`
|
||||
AccessDeniedMessage string `json:"access_denied_message"`
|
||||
}
|
||||
@@ -247,6 +250,7 @@ func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
EmailField: req.EmailField,
|
||||
WellKnown: req.WellKnown,
|
||||
AuthStyle: req.AuthStyle,
|
||||
PKCEEnabled: req.PKCEEnabled,
|
||||
AccessPolicy: req.AccessPolicy,
|
||||
AccessDeniedMessage: req.AccessDeniedMessage,
|
||||
}
|
||||
@@ -284,6 +288,7 @@ type UpdateCustomOAuthProviderRequest struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
|
||||
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
|
||||
PKCEEnabled *bool `json:"pkce_enabled"` // Optional: if nil, keep existing
|
||||
AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
|
||||
AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
|
||||
}
|
||||
@@ -374,6 +379,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
if req.AuthStyle != nil {
|
||||
provider.AuthStyle = *req.AuthStyle
|
||||
}
|
||||
if req.PKCEEnabled != nil {
|
||||
provider.PKCEEnabled = *req.PKCEEnabled
|
||||
}
|
||||
if req.AccessPolicy != nil {
|
||||
provider.AccessPolicy = *req.AccessPolicy
|
||||
}
|
||||
|
||||
@@ -139,6 +139,11 @@ func DiscordOAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.RestoreIfDeleted("discord", c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to restore user %d: %v", user.Id, err))
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
|
||||
@@ -122,12 +122,9 @@ func GitHubOAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
if err := user.RestoreIfDeleted("github", c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to restore user %d: %v", user.Id, err))
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const hhhlMisskeyHost = "https://dc.hhhl.cc"
|
||||
|
||||
// HHHLAuthorize adapts Misskey MiAuth to the OAuth authorization endpoint shape.
|
||||
func HHHLAuthorize(c *gin.Context) {
|
||||
redirectURI := strings.TrimSpace(c.Query("redirect_uri"))
|
||||
if redirectURI == "" {
|
||||
c.String(http.StatusBadRequest, "missing redirect_uri")
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
sessionID := uuid.NewString()
|
||||
callbackURL := fmt.Sprintf(
|
||||
"%s/api/hhhl/callback?r=%s&s=%s&sid=%s",
|
||||
strings.TrimRight(system_setting.ServerAddress, "/"),
|
||||
url.QueryEscape(redirectURI),
|
||||
url.QueryEscape(state),
|
||||
url.QueryEscape(sessionID),
|
||||
)
|
||||
|
||||
miAuthURL := fmt.Sprintf(
|
||||
"%s/miauth/%s?name=NewAPI%%E7%%99%%BB%%E5%%BD%%95&callback=%s&permission=read:account",
|
||||
hhhlMisskeyHost,
|
||||
url.PathEscape(sessionID),
|
||||
url.QueryEscape(callbackURL),
|
||||
)
|
||||
c.Redirect(http.StatusFound, miAuthURL)
|
||||
}
|
||||
|
||||
// HHHLCallback returns the MiAuth session id as an OAuth authorization code.
|
||||
// Wrap in pkce.{base64json} format so the generic OAuth provider forwards it correctly.
|
||||
func HHHLCallback(c *gin.Context) {
|
||||
redirectURI := strings.TrimSpace(c.Query("r"))
|
||||
sessionID := strings.TrimSpace(c.Query("sid"))
|
||||
if redirectURI == "" || sessionID == "" {
|
||||
c.String(http.StatusBadRequest, "invalid callback")
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(redirectURI)
|
||||
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
|
||||
c.String(http.StatusBadRequest, "invalid redirect_uri")
|
||||
return
|
||||
}
|
||||
codePayload, _ := common.Marshal(map[string]string{"token": sessionID})
|
||||
code := "pkce." + base64.RawURLEncoding.EncodeToString(codePayload)
|
||||
query := targetURL.Query()
|
||||
query.Set("code", code)
|
||||
query.Set("state", c.Query("s"))
|
||||
targetURL.RawQuery = query.Encode()
|
||||
c.Redirect(http.StatusFound, targetURL.String())
|
||||
}
|
||||
|
||||
// HHHLToken exchanges a MiAuth session id for a Misskey access token.
|
||||
func HHHLToken(c *gin.Context) {
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
if code == "" {
|
||||
if err := c.Request.ParseForm(); err == nil {
|
||||
code = strings.TrimSpace(c.Request.Form.Get("code"))
|
||||
}
|
||||
}
|
||||
if code == "" {
|
||||
var payload struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err == nil {
|
||||
code = strings.TrimSpace(payload.Code)
|
||||
}
|
||||
}
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "Missing code"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := code
|
||||
if strings.HasPrefix(code, "pkce.") {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(code[5:])
|
||||
if err == nil {
|
||||
var pkceData struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if jsonErr := common.Unmarshal(decoded, &pkceData); jsonErr == nil && pkceData.Token != "" {
|
||||
sessionID = pkceData.Token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := common.Marshal(gin.H{})
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("%s/api/miauth/%s/check", hhhlMisskeyHost, url.PathEscape(sessionID)),
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
|
||||
client := http.Client{Timeout: 20 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "MiAuth check request failed"})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var tokenData struct {
|
||||
OK bool `json:"ok"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := common.Unmarshal(respBody, &tokenData); err != nil || !tokenData.OK || tokenData.Token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant", "error_description": "Failed to validate MiAuth session"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"access_token": tokenData.Token, "token_type": "Bearer"})
|
||||
}
|
||||
|
||||
// HHHLUserInfo adapts Misskey /api/i to an OIDC-like userinfo response.
|
||||
func HHHLUserInfo(c *gin.Context) {
|
||||
token := strings.TrimSpace(c.Query("access_token"))
|
||||
if token == "" {
|
||||
token = strings.TrimSpace(strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer "))
|
||||
}
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_request", "error_description": "Missing token"})
|
||||
return
|
||||
}
|
||||
|
||||
body, err := common.Marshal(gin.H{"i": token})
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, hhhlMisskeyHost+"/api/i", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
|
||||
client := http.Client{Timeout: 20 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token", "error_description": "Failed to fetch user info"})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var userData struct {
|
||||
Id string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := common.Unmarshal(respBody, &userData); err != nil || userData.Id == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
|
||||
return
|
||||
}
|
||||
if userData.Name == "" {
|
||||
userData.Name = userData.Username
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"sub": userData.Id,
|
||||
"preferred_username": userData.Username,
|
||||
"name": userData.Name,
|
||||
})
|
||||
}
|
||||
@@ -212,11 +212,9 @@ func LinuxdoOAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
if err := user.RestoreIfDeleted("linuxdo", c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to restore user %d: %v", user.Id, err))
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -18,6 +19,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"github.com/shirou/gopsutil/mem"
|
||||
)
|
||||
|
||||
func TestStatus(c *gin.Context) {
|
||||
@@ -144,6 +147,7 @@ func GetStatus(c *gin.Context) {
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
PKCEEnabled bool `json:"pkce_enabled"`
|
||||
}
|
||||
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
|
||||
for _, p := range customProviders {
|
||||
@@ -156,6 +160,7 @@ func GetStatus(c *gin.Context) {
|
||||
ClientId: config.ClientId,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint,
|
||||
Scopes: config.Scopes,
|
||||
PKCEEnabled: config.PKCEEnabled,
|
||||
})
|
||||
}
|
||||
data["custom_oauth_providers"] = providersInfo
|
||||
@@ -231,6 +236,49 @@ func GetHomePageContent(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomeStats(c *gin.Context) {
|
||||
var cpuUsage float64
|
||||
if percents, err := cpu.Percent(150*time.Millisecond, false); err == nil && len(percents) > 0 {
|
||||
cpuUsage = percents[0]
|
||||
} else {
|
||||
cpuUsage = common.GetSystemStatus().CPUUsage
|
||||
}
|
||||
|
||||
var memoryTotal uint64
|
||||
var memoryUsed uint64
|
||||
var memoryUsage float64
|
||||
if memInfo, err := mem.VirtualMemory(); err == nil {
|
||||
memoryTotal = memInfo.Total
|
||||
memoryUsed = memInfo.Used
|
||||
memoryUsage = memInfo.UsedPercent
|
||||
} else {
|
||||
memoryUsage = common.GetSystemStatus().MemoryUsage
|
||||
}
|
||||
|
||||
totalTokens, err := model.SumTotalConsumeTokens()
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), "failed to query home stats token usage: "+err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "查询首页统计失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"cpu_usage": cpuUsage,
|
||||
"memory_usage": memoryUsage,
|
||||
"memory_total": memoryTotal,
|
||||
"memory_used": memoryUsed,
|
||||
"total_tokens": totalTokens,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendEmailVerification(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
if err := common.Validate.Var(email, "required,email"); err != nil {
|
||||
|
||||
+15
-3
@@ -90,6 +90,10 @@ func HandleOAuth(c *gin.Context) {
|
||||
|
||||
// 5. Exchange code for token
|
||||
code := c.Query("code")
|
||||
// Pass PKCE code_verifier to context if present
|
||||
if codeVerifier := c.Query("code_verifier"); codeVerifier != "" {
|
||||
c.Set("pkce_code_verifier", codeVerifier)
|
||||
}
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
@@ -136,6 +140,10 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
|
||||
// Exchange code for token
|
||||
code := c.Query("code")
|
||||
// Pass PKCE code_verifier to context if present
|
||||
if codeVerifier := c.Query("code_verifier"); codeVerifier != "" {
|
||||
c.Set("pkce_code_verifier", codeVerifier)
|
||||
}
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
@@ -205,9 +213,9 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Check if user has been deleted
|
||||
if user.Id == 0 {
|
||||
return nil, &OAuthUserDeletedError{}
|
||||
if err := user.RestoreIfDeleted(provider.GetName(), c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to restore user %d: %s", user.Id, err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
@@ -219,6 +227,10 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := user.RestoreIfDeleted(provider.GetName(), c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to restore user %d: %s", user.Id, err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
if user.Id != 0 {
|
||||
// Found user with legacy ID, migrate to new ID
|
||||
common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
|
||||
|
||||
@@ -141,6 +141,11 @@ func OidcAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.RestoreIfDeleted("oidc", c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to restore user %d: %v", user.Id, err))
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
|
||||
@@ -95,6 +95,11 @@ func TelegramLogin(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.RestoreIfDeleted("telegram", c.ClientIP()); err != nil {
|
||||
common.SysError("failed to restore user: " + err.Error())
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -51,7 +51,7 @@ func Login(c *gin.Context) {
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
err = user.ValidateAndFill()
|
||||
err = user.ValidateAndFill(c.ClientIP())
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, model.ErrDatabase):
|
||||
|
||||
@@ -82,11 +82,9 @@ func WeChatAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
if err := user.RestoreIfDeleted("wechat", c.ClientIP()); err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to restore user %d: %v", user.Id, err))
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
PROD_DIR="/home/www/new-api-prod"
|
||||
COMPOSE_FILE="$PROD_DIR/docker-compose.prod.yml"
|
||||
IMAGE_NAME="my-new-api:latest"
|
||||
PROJECT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
echo "========================================="
|
||||
echo " New API Auto Deploy"
|
||||
echo " $(date '+%Y-%m-%d %H:%M:%S')"
|
||||
echo "========================================="
|
||||
|
||||
# Step 1: Build frontend
|
||||
echo "[1/4] Building web/default..."
|
||||
export PATH="$HOME/.bun/bin:$PATH"
|
||||
cd "$PROJECT_DIR/web"
|
||||
bun install --frozen-lockfile
|
||||
cd "$PROJECT_DIR/web/default"
|
||||
bun run build
|
||||
echo " web/default built."
|
||||
|
||||
echo "[2/4] Building web/image-gen..."
|
||||
cd "$PROJECT_DIR/web/image-gen"
|
||||
if command -v npm >/dev/null 2>&1; then
|
||||
npm ci --no-audit --no-fund
|
||||
npm run build
|
||||
else
|
||||
echo " WARNING: npm not found, image-gen dist not built."
|
||||
echo " Install Node.js 20+ to build the image-gen sub-app."
|
||||
mkdir -p dist
|
||||
echo '<!doctype html><html><head><title>image-gen placeholder</title></head><body>image-gen dist not built (npm missing)</body></html>' > dist/index.html
|
||||
fi
|
||||
echo " web/image-gen built."
|
||||
|
||||
# Step 3: Build Docker image
|
||||
echo "[3/4] Building Docker image: $IMAGE_NAME"
|
||||
cd "$PROJECT_DIR"
|
||||
docker build -t "$IMAGE_NAME" .
|
||||
echo " Image built."
|
||||
|
||||
# Step 4: Restart production containers
|
||||
echo "[4/4] Restarting production containers..."
|
||||
sudo -u www docker compose -f "$COMPOSE_FILE" up -d
|
||||
echo " Production deployed."
|
||||
|
||||
echo "========================================="
|
||||
echo " Done! Production at http://localhost:3001"
|
||||
echo "========================================="
|
||||
@@ -0,0 +1,132 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
$root = Split-Path -Parent $PSCommandPath
|
||||
$backendPort = 3000
|
||||
$frontendPort = 5173
|
||||
|
||||
function Test-TcpPort {
|
||||
param(
|
||||
[Parameter(Mandatory = $true)]
|
||||
[int]$Port
|
||||
)
|
||||
|
||||
$client = [System.Net.Sockets.TcpClient]::new()
|
||||
try {
|
||||
$task = $client.ConnectAsync("127.0.0.1", $Port)
|
||||
if (-not $task.Wait(500)) {
|
||||
return $false
|
||||
}
|
||||
return $client.Connected
|
||||
}
|
||||
catch {
|
||||
return $false
|
||||
}
|
||||
finally {
|
||||
$client.Dispose()
|
||||
}
|
||||
}
|
||||
|
||||
function Assert-PortFree {
|
||||
param(
|
||||
[Parameter(Mandatory = $true)]
|
||||
[string]$Name,
|
||||
[Parameter(Mandatory = $true)]
|
||||
[int]$Port
|
||||
)
|
||||
|
||||
if (Test-TcpPort -Port $Port) {
|
||||
Write-Host "[error] $Name 端口 $Port 已被占用,请先停止占用该端口的进程" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
function Wait-Port {
|
||||
param(
|
||||
[Parameter(Mandatory = $true)]
|
||||
[string]$Name,
|
||||
[Parameter(Mandatory = $true)]
|
||||
[int]$Port,
|
||||
[Parameter(Mandatory = $true)]
|
||||
[System.Diagnostics.Process]$Process,
|
||||
[int]$TimeoutSeconds = 90
|
||||
)
|
||||
|
||||
$deadline = (Get-Date).AddSeconds($TimeoutSeconds)
|
||||
while ((Get-Date) -lt $deadline) {
|
||||
if ($Process.HasExited) {
|
||||
Write-Host "[error] $Name 进程已退出,退出码 $($Process.ExitCode),请查看对应窗口日志" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
if (Test-TcpPort -Port $Port) {
|
||||
Write-Host "[$Name] 已就绪 (port $Port)" -ForegroundColor Green
|
||||
return
|
||||
}
|
||||
Start-Sleep -Milliseconds 500
|
||||
}
|
||||
|
||||
Write-Host "[error] $Name 未在 $TimeoutSeconds 秒内监听端口 $Port,请查看对应窗口日志" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
|
||||
# 0. 初始化 PATH
|
||||
$env:Path = "C:\Program Files\Go\bin;C:\Users\Chaos\.bun\bin;$env:Path"
|
||||
|
||||
# 1. 检查 .env
|
||||
if (-not (Test-Path "$root\.env")) {
|
||||
Write-Host "[setup] 复制 .env.example -> .env" -ForegroundColor Yellow
|
||||
Copy-Item "$root\.env.example" "$root\.env"
|
||||
}
|
||||
|
||||
# 2. 检查前端依赖
|
||||
if (-not (Test-Path "$root\web\default\node_modules")) {
|
||||
Write-Host "[setup] 安装前端依赖..." -ForegroundColor Yellow
|
||||
Set-Location "$root\web\default"
|
||||
bun install
|
||||
Set-Location $root
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Host "[error] bun install 失败" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 3. 创建 go:embed 所需目录
|
||||
$embedDirs = @("web\default\dist", "web\classic\dist", "web\image-gen\dist")
|
||||
foreach ($dir in $embedDirs) {
|
||||
$fullPath = Join-Path $root $dir
|
||||
if (-not (Test-Path $fullPath)) {
|
||||
New-Item -ItemType Directory -Force -Path $fullPath | Out-Null
|
||||
}
|
||||
$indexFile = Join-Path $fullPath "index.html"
|
||||
if (-not (Test-Path $indexFile)) {
|
||||
Set-Content -Path $indexFile -Value "<!DOCTYPE html><html></html>"
|
||||
}
|
||||
}
|
||||
|
||||
# 4. 初始化 PATH
|
||||
$goPath = "C:\Program Files\Go\bin"
|
||||
$bunPath = "C:\Users\Chaos\.bun\bin"
|
||||
$initPath = "`$env:Path = '$goPath;$bunPath;' + `$env:Path;"
|
||||
|
||||
# 5. 检查端口占用
|
||||
Assert-PortFree -Name "Backend" -Port $backendPort
|
||||
Assert-PortFree -Name "Frontend" -Port $frontendPort
|
||||
|
||||
# 6. 启动后端
|
||||
Write-Host "[backend] 启动 API 服务 (port $backendPort)..." -ForegroundColor Green
|
||||
$backendJob = Start-Process -FilePath "powershell" -ArgumentList "-NoExit", "-Command", "$initPath Set-Location '$root'; Write-Host '=== Backend :$backendPort ===' -ForegroundColor Cyan; go run main.go" -PassThru
|
||||
|
||||
# 7. 启动前端
|
||||
Write-Host "[frontend] 启动前端开发服务 (port $frontendPort)..." -ForegroundColor Green
|
||||
$frontendJob = Start-Process -FilePath "powershell" -ArgumentList "-NoExit", "-Command", "$initPath Set-Location '$root\web\default'; Write-Host '=== Frontend :$frontendPort ===' -ForegroundColor Magenta; bun run dev" -PassThru
|
||||
|
||||
# 8. 等待服务就绪
|
||||
Wait-Port -Name "Backend" -Port $backendPort -Process $backendJob -TimeoutSeconds 90
|
||||
Wait-Port -Name "Frontend" -Port $frontendPort -Process $frontendJob -TimeoutSeconds 60
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "==============================" -ForegroundColor White
|
||||
Write-Host " Backend : http://localhost:$backendPort" -ForegroundColor Cyan
|
||||
Write-Host " Frontend : http://localhost:$frontendPort" -ForegroundColor Magenta
|
||||
Write-Host "==============================" -ForegroundColor White
|
||||
Write-Host ""
|
||||
Write-Host "关闭窗口即可停止服务" -ForegroundColor Gray
|
||||
+12
-3
@@ -1,9 +1,14 @@
|
||||
# Frontend Development - Backend built from local source
|
||||
#
|
||||
# Usage:
|
||||
# Usage (Docker backend):
|
||||
# 1. docker compose -f docker-compose.dev.yml up -d
|
||||
# 2. cd web && bun install && bun run dev
|
||||
# 3. Open http://localhost:3001 (Rsbuild dev server, API auto-proxied to :3000)
|
||||
# 2. cd web/default && bun install && bun run dev
|
||||
# 3. Open http://localhost:3000 (Rsbuild dev server, API auto-proxied to :3000)
|
||||
#
|
||||
# Usage (Local Go backend):
|
||||
# 1. docker compose -f docker-compose.dev.yml up -d postgres redis
|
||||
# 2. PORT=3002 SQL_DSN="postgresql://root:123456@localhost:5432/new-api" REDIS_CONN_STRING="redis://localhost:6379" go run main.go
|
||||
# 3. cd web/default && VITE_REACT_APP_SERVER_URL=http://localhost:3002 bun run dev
|
||||
#
|
||||
# Rebuild backend after Go code changes:
|
||||
# docker compose -f docker-compose.dev.yml up -d --build new-api
|
||||
@@ -43,6 +48,8 @@ services:
|
||||
image: redis:7-alpine
|
||||
container_name: new-api-dev-redis
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "6379:6379"
|
||||
networks:
|
||||
- dev-network
|
||||
|
||||
@@ -54,6 +61,8 @@ services:
|
||||
POSTGRES_USER: root
|
||||
POSTGRES_PASSWORD: 123456
|
||||
POSTGRES_DB: new-api
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- dev_pg_data:/var/lib/postgresql/data
|
||||
networks:
|
||||
|
||||
+6
-6
@@ -26,11 +26,11 @@ type ImageRequest struct {
|
||||
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
Images json.RawMessage `json:"images,omitempty"`
|
||||
Mask json.RawMessage `json:"mask,omitempty"`
|
||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Images json.RawMessage `json:"images,omitempty"`
|
||||
Mask json.RawMessage `json:"mask,omitempty"`
|
||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
// zhipu 4v
|
||||
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
|
||||
UserId json.RawMessage `json:"user_id,omitempty"`
|
||||
@@ -163,7 +163,7 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
|
||||
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||
return false
|
||||
return i.Stream != nil && *i.Stream
|
||||
}
|
||||
|
||||
func (i *ImageRequest) SetModelName(modelName string) {
|
||||
|
||||
@@ -47,6 +47,12 @@ var classicBuildFS embed.FS
|
||||
//go:embed web/classic/dist/index.html
|
||||
var classicIndexPage []byte
|
||||
|
||||
//go:embed web/image-gen/dist
|
||||
var imageGenBuildFS embed.FS
|
||||
|
||||
//go:embed web/image-gen/dist/index.html
|
||||
var imageGenIndexPage []byte
|
||||
|
||||
func main() {
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -195,6 +201,8 @@ func main() {
|
||||
DefaultIndexPage: indexPage,
|
||||
ClassicBuildFS: classicBuildFS,
|
||||
ClassicIndexPage: classicIndexPage,
|
||||
ImageGenBuildFS: imageGenBuildFS,
|
||||
ImageGenIndexPage: imageGenIndexPage,
|
||||
})
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
|
||||
@@ -59,6 +59,7 @@ type CustomOAuthProvider struct {
|
||||
// Advanced options
|
||||
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
|
||||
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
|
||||
PKCEEnabled bool `json:"pkce_enabled" gorm:"default:false"` // Enable PKCE (Proof Key for Code Exchange)
|
||||
AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info
|
||||
AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied
|
||||
|
||||
|
||||
@@ -88,6 +88,33 @@ func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
|
||||
return logs, err
|
||||
}
|
||||
|
||||
// RecordUserRestoreLog writes an audit-log entry whenever a soft-deleted user
|
||||
// is automatically restored (e.g. by logging in again via password or OAuth).
|
||||
// `source` describes the trigger, e.g. "github", "linuxdo", "telegram", "password".
|
||||
// `callerIp` may be empty when the call originates from the model layer.
|
||||
func RecordUserRestoreLog(userId int, source string, callerIp string) {
|
||||
username, _ := GetUsernameById(userId, false)
|
||||
other := map[string]interface{}{}
|
||||
if source != "" {
|
||||
other["restore_source"] = source
|
||||
}
|
||||
if callerIp != "" {
|
||||
other["caller_ip"] = callerIp
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
CreatedAt: common.GetTimestamp(),
|
||||
Type: LogTypeSystem,
|
||||
Content: fmt.Sprintf("软删除用户被自动恢复,来源 %s", source),
|
||||
Ip: callerIp,
|
||||
Other: common.MapToJsonStr(other),
|
||||
}
|
||||
if err := LOG_DB.Create(log).Error; err != nil {
|
||||
common.SysLog("failed to record user restore log: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func RecordLog(userId int, logType int, content string) {
|
||||
if logType == LogTypeConsume && !common.LogConsumeEnabled {
|
||||
return
|
||||
@@ -531,6 +558,24 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
return token
|
||||
}
|
||||
|
||||
func SumTotalConsumeTokens() (int64, error) {
|
||||
type tokenStat struct {
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
}
|
||||
|
||||
var stat tokenStat
|
||||
err := LOG_DB.Model(&Log{}).
|
||||
Select("sum(prompt_tokens) as prompt_tokens, sum(completion_tokens) as completion_tokens").
|
||||
Where("type = ?", LogTypeConsume).
|
||||
Scan(&stat).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return stat.PromptTokens + stat.CompletionTokens, nil
|
||||
}
|
||||
|
||||
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
|
||||
var total int64 = 0
|
||||
|
||||
|
||||
+38
-13
@@ -589,18 +589,40 @@ func (user *User) HardDelete() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) Restore() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Unscoped().Model(user).Update("deleted_at", nil).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := invalidateUserCache(user.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
user.DeletedAt = gorm.DeletedAt{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) RestoreIfDeleted(source string, callerIp string) error {
|
||||
if !user.DeletedAt.Valid {
|
||||
return nil
|
||||
}
|
||||
if err := user.Restore(); err != nil {
|
||||
return err
|
||||
}
|
||||
RecordUserRestoreLog(user.Id, source, callerIp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAndFill check password & user status
|
||||
func (user *User) ValidateAndFill() (err error) {
|
||||
// When querying with struct, GORM will only query with non-zero fields,
|
||||
// that means if your field's value is 0, '', false or other zero values,
|
||||
// it won't be used to build query conditions
|
||||
func (user *User) ValidateAndFill(callerIp string) (err error) {
|
||||
password := user.Password
|
||||
username := strings.TrimSpace(user.Username)
|
||||
if username == "" || password == "" {
|
||||
return ErrUserEmptyCredentials
|
||||
}
|
||||
// find by username or email
|
||||
err = DB.Where("username = ? OR email = ?", username, username).First(user).Error
|
||||
err = DB.Unscoped().Where("username = ? OR email = ?", username, username).First(user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrInvalidCredentials
|
||||
@@ -611,6 +633,9 @@ func (user *User) ValidateAndFill() (err error) {
|
||||
if !okay || user.Status != common.UserStatusEnabled {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
if err := user.RestoreIfDeleted("password", callerIp); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDatabase, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -634,7 +659,7 @@ func (user *User) FillUserByGitHubId() error {
|
||||
if user.GitHubId == "" {
|
||||
return errors.New("GitHub id 为空!")
|
||||
}
|
||||
DB.Where(User{GitHubId: user.GitHubId}).First(user)
|
||||
DB.Unscoped().Where(User{GitHubId: user.GitHubId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -650,7 +675,7 @@ func (user *User) FillUserByDiscordId() error {
|
||||
if user.DiscordId == "" {
|
||||
return errors.New("discord id 为空!")
|
||||
}
|
||||
DB.Where(User{DiscordId: user.DiscordId}).First(user)
|
||||
DB.Unscoped().Where(User{DiscordId: user.DiscordId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -658,7 +683,7 @@ func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
}
|
||||
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||
DB.Unscoped().Where(User{OidcId: user.OidcId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -666,7 +691,7 @@ func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
}
|
||||
DB.Where(User{WeChatId: user.WeChatId}).First(user)
|
||||
DB.Unscoped().Where(User{WeChatId: user.WeChatId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -674,7 +699,7 @@ func (user *User) FillUserByTelegramId() error {
|
||||
if user.TelegramId == "" {
|
||||
return errors.New("Telegram id 为空!")
|
||||
}
|
||||
err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
|
||||
err := DB.Unscoped().Where(User{TelegramId: user.TelegramId}).First(user).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("该 Telegram 账户未绑定")
|
||||
}
|
||||
@@ -698,7 +723,7 @@ func IsDiscordIdAlreadyTaken(discordId string) bool {
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
return DB.Unscoped().Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsTelegramIdAlreadyTaken(telegramId string) bool {
|
||||
@@ -1057,7 +1082,7 @@ func (user *User) FillUserByLinuxDOId() error {
|
||||
if user.LinuxDOId == "" {
|
||||
return errors.New("linux do id is empty")
|
||||
}
|
||||
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
|
||||
err := DB.Unscoped().Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
|
||||
type UserOAuthBinding struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
|
||||
ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
|
||||
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
|
||||
UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
|
||||
ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
|
||||
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error)
|
||||
}
|
||||
|
||||
var user User
|
||||
err = DB.First(&user, binding.UserId).Error
|
||||
err = DB.Unscoped().First(&user, binding.UserId).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+42
-1
@@ -94,12 +94,53 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
|
||||
|
||||
// Handle pkce.xxx format from some OAuth providers (e.g., dc.hhhl.cc)
|
||||
// The code is in format: pkce.base64json({token, codeChallenge, codeChallengeMethod})
|
||||
// We need to send the FULL pkce.xxx code to the token endpoint, not just the extracted token
|
||||
var extractedCodeChallenge string
|
||||
if strings.HasPrefix(code, "pkce.") {
|
||||
encodedPart := code[5:] // Remove "pkce." prefix
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(encodedPart)
|
||||
if err == nil {
|
||||
var pkceData struct {
|
||||
Token string `json:"token"`
|
||||
CodeChallenge string `json:"codeChallenge"`
|
||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||
}
|
||||
if jsonErr := common.Unmarshal(decoded, &pkceData); jsonErr == nil && pkceData.Token != "" {
|
||||
extractedCodeChallenge = pkceData.CodeChallenge
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: parsed pkce format, token=%s..., codeChallenge=%s",
|
||||
p.config.Slug, pkceData.Token[:min(len(pkceData.Token), 10)], extractedCodeChallenge)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
|
||||
values := url.Values{}
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("code", code)
|
||||
values.Set("code", code) // Send the full pkce.xxx code
|
||||
values.Set("redirect_uri", redirectUri)
|
||||
|
||||
// Log all parameters being sent for debugging
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: sending to %s with params: grant_type=authorization_code, code=%s, redirect_uri=%s, client_id=%s",
|
||||
p.config.Slug, p.config.TokenEndpoint, code[:min(len(code), 20)], redirectUri, p.config.ClientId)
|
||||
|
||||
// Add PKCE code_verifier if enabled
|
||||
if p.config.PKCEEnabled && c != nil {
|
||||
if codeVerifier, exists := c.Get("pkce_code_verifier"); exists {
|
||||
if verifier, ok := codeVerifier.(string); ok && verifier != "" {
|
||||
values.Set("code_verifier", verifier)
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: PKCE code_verifier added", p.config.Slug)
|
||||
}
|
||||
}
|
||||
// Some OAuth providers expect code_challenge to be sent during token exchange
|
||||
if extractedCodeChallenge != "" {
|
||||
values.Set("code_challenge", extractedCodeChallenge)
|
||||
values.Set("code_challenge_method", "S256")
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: PKCE code_challenge added: %s", p.config.Slug, extractedCodeChallenge)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine auth style
|
||||
authStyle := p.config.AuthStyle
|
||||
if authStyle == AuthStyleAutoDetect {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
channelconstant "github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -79,9 +81,23 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request.Temperature != nil && isTemperatureOneOnlyModel(getUpstreamModelName(info, request.Model)) && *request.Temperature != 1.0 {
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func getUpstreamModelName(info *relaycommon.RelayInfo, fallback string) string {
|
||||
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
|
||||
return info.UpstreamModelName
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func isTemperatureOneOnlyModel(model string) bool {
|
||||
return strings.EqualFold(model, "kimi-k2.6")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package moonshot
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConvertOpenAIRequestKimiK26UsesOnlyAllowedTemperature(t *testing.T) {
|
||||
request := &dto.GeneralOpenAIRequest{
|
||||
Model: "kimi-k2.6",
|
||||
Temperature: common.GetPointer[float64](0.7),
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "kimi-k2.6",
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, convertedRequest.Temperature)
|
||||
require.Equal(t, 1.0, *convertedRequest.Temperature)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestKimiK26KeepsOmittedTemperatureOmitted(t *testing.T) {
|
||||
request := &dto.GeneralOpenAIRequest{
|
||||
Model: "kimi-k2.6",
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "kimi-k2.6",
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
|
||||
require.True(t, ok)
|
||||
require.Nil(t, convertedRequest.Temperature)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestOtherMoonshotModelKeepsTemperature(t *testing.T) {
|
||||
request := &dto.GeneralOpenAIRequest{
|
||||
Model: "kimi-k2.5",
|
||||
Temperature: common.GetPointer[float64](0.7),
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "kimi-k2.5",
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertOpenAIRequest(nil, info, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
convertedRequest, ok := converted.(*dto.GeneralOpenAIRequest)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, convertedRequest.Temperature)
|
||||
require.Equal(t, 0.7, *convertedRequest.Temperature)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -439,10 +440,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
// 使用已解析的 multipart 表单,避免重复解析
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
form, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse multipart form: %w", err)
|
||||
}
|
||||
mf = c.Request.MultipartForm
|
||||
c.Request.MultipartForm = form
|
||||
c.Request.PostForm = url.Values(form.Value)
|
||||
mf = form
|
||||
}
|
||||
|
||||
// 写入所有非文件字段
|
||||
@@ -625,7 +629,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
||||
if info.IsStream {
|
||||
usage, err = OpenaiImageStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = OpenaiImageHandler(c, info, resp)
|
||||
}
|
||||
case relayconstant.RelayModeRerank:
|
||||
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||
case relayconstant.RelayModeResponses:
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConvertImageEditRequestMultipart verifies that ConvertImageRequest
|
||||
// re-serializes multipart image edit requests with all fields (including
|
||||
// stream) and the file intact, both when the form was already parsed and when
|
||||
// it must be re-parsed from the reusable body.
|
||||
func TestConvertImageEditRequestMultipart(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
newMultipartContext := func(t *testing.T, prompt string) *gin.Context {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||
require.NoError(t, writer.WriteField("prompt", prompt))
|
||||
require.NoError(t, writer.WriteField("stream", "true"))
|
||||
require.NoError(t, writer.WriteField("partial_images", "3"))
|
||||
part, err := writer.CreateFormFile("image", "input.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write([]byte("fake image"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return c
|
||||
}
|
||||
|
||||
convertAndReplay := func(t *testing.T, c *gin.Context, prompt string) {
|
||||
info := &relaycommon.RelayInfo{
|
||||
RelayMode: relayconstant.RelayModeImagesEdits,
|
||||
}
|
||||
request := dto.ImageRequest{
|
||||
Model: "gpt-image-1",
|
||||
Prompt: prompt,
|
||||
Stream: common.GetPointer(true),
|
||||
}
|
||||
|
||||
converted, err := (&Adaptor{}).ConvertImageRequest(c, info, request)
|
||||
require.NoError(t, err)
|
||||
convertedBody, ok := converted.(*bytes.Buffer)
|
||||
require.True(t, ok)
|
||||
|
||||
replayedRequest := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(convertedBody.Bytes()))
|
||||
replayedRequest.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
require.NoError(t, replayedRequest.ParseMultipartForm(32<<20))
|
||||
|
||||
require.Equal(t, "gpt-image-1", replayedRequest.PostForm.Get("model"))
|
||||
require.Equal(t, prompt, replayedRequest.PostForm.Get("prompt"))
|
||||
require.Equal(t, "true", replayedRequest.PostForm.Get("stream"))
|
||||
require.Equal(t, "3", replayedRequest.PostForm.Get("partial_images"))
|
||||
require.Len(t, replayedRequest.MultipartForm.File["image"], 1)
|
||||
|
||||
file, err := replayedRequest.MultipartForm.File["image"][0].Open()
|
||||
require.NoError(t, err)
|
||||
defer file.Close()
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("fake image"), fileBytes)
|
||||
}
|
||||
|
||||
t.Run("with pre-parsed form", func(t *testing.T) {
|
||||
prompt := "edit this image"
|
||||
c := newMultipartContext(t, prompt)
|
||||
require.NoError(t, c.Request.ParseMultipartForm(32<<20))
|
||||
|
||||
convertAndReplay(t, c, prompt)
|
||||
})
|
||||
|
||||
t.Run("re-parses reusable body when form is missing", func(t *testing.T) {
|
||||
prompt := "edit without pre-parsed form"
|
||||
c := newMultipartContext(t, prompt)
|
||||
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
require.NoError(t, err)
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
c.Request.MultipartForm = nil
|
||||
c.Request.PostForm = nil
|
||||
|
||||
convertAndReplay(t, c, prompt)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newImageTestContext(t *testing.T, body, contentType string, isStream bool) (*gin.Context, *httptest.ResponseRecorder, *http.Response, *relaycommon.RelayInfo) {
|
||||
t.Helper()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: http.Header{"Content-Type": []string{contentType}},
|
||||
}
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
IsStream: isStream,
|
||||
}
|
||||
return c, recorder, resp, info
|
||||
}
|
||||
|
||||
// TestOpenaiImageStreamHandlerForwardsSSEAndUsage covers the core SSE path:
|
||||
// chunks are forwarded with rebuilt event lines, usage is extracted and
|
||||
// normalized (input_tokens -> prompt_tokens with details), and [DONE] is
|
||||
// re-emitted to the client.
|
||||
func TestOpenaiImageStreamHandlerForwardsSSEAndUsage(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
body := strings.Join([]string{
|
||||
`event: image_generation.partial_image`,
|
||||
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||
``,
|
||||
`data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, usage.PromptTokens)
|
||||
require.Equal(t, 4, usage.CompletionTokens)
|
||||
require.Equal(t, 7, usage.TotalTokens)
|
||||
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
||||
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
||||
require.Contains(t, recorder.Body.String(), `event: image_generation.partial_image`)
|
||||
require.Contains(t, recorder.Body.String(), `data: {"type":"image_generation.partial_image","b64_json":"partial"}`)
|
||||
require.Contains(t, recorder.Body.String(), `data: {"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`)
|
||||
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
// TestOpenaiImageStreamHandlerWrapsJSONResponse covers the non-SSE fallback:
|
||||
// a JSON upstream response is wrapped into pseudo-SSE completed events.
|
||||
func TestOpenaiImageStreamHandlerWrapsJSONResponse(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
body := `{"created":1710000000,"data":[{"b64_json":"final","revised_prompt":"draw a cat"}],"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"input_tokens_details":{"image_tokens":2,"text_tokens":1}}}`
|
||||
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, usage.PromptTokens)
|
||||
require.Equal(t, 4, usage.CompletionTokens)
|
||||
require.Equal(t, 7, usage.TotalTokens)
|
||||
require.Equal(t, 2, usage.PromptTokensDetails.ImageTokens)
|
||||
require.Equal(t, 1, usage.PromptTokensDetails.TextTokens)
|
||||
require.Equal(t, "text/event-stream", recorder.Header().Get("Content-Type"))
|
||||
require.Empty(t, recorder.Header().Get("Content-Length"))
|
||||
require.Contains(t, recorder.Body.String(), `event: image_generation.completed`)
|
||||
require.Contains(t, recorder.Body.String(), `"type":"image_generation.completed"`)
|
||||
require.Contains(t, recorder.Body.String(), `"b64_json":"final"`)
|
||||
require.Contains(t, recorder.Body.String(), `"revised_prompt":"draw a cat"`)
|
||||
require.Contains(t, recorder.Body.String(), `data: [DONE]`)
|
||||
}
|
||||
|
||||
// TestOpenaiImageHandlersReturnJSONError covers JSON error responses for both
|
||||
// entry points: the non-streaming handler and the stream handler's non-SSE
|
||||
// fallback. Neither must leak the error body to the client.
|
||||
func TestOpenaiImageHandlersReturnJSONError(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
body := `{"error":{"message":"content moderation failed","type":"upstream_error","code":"content_moderation_failed","status":502}}`
|
||||
|
||||
t.Run("non-streaming handler", func(t *testing.T) {
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "application/json", false)
|
||||
|
||||
usage, err := OpenaiImageHandler(c, info, resp)
|
||||
require.Nil(t, usage)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||
oaiError := err.ToOpenAIError()
|
||||
require.Equal(t, "content moderation failed", oaiError.Message)
|
||||
require.Equal(t, "upstream_error", oaiError.Type)
|
||||
require.Equal(t, "content_moderation_failed", oaiError.Code)
|
||||
require.Empty(t, recorder.Body.String())
|
||||
})
|
||||
|
||||
t.Run("stream handler JSON fallback", func(t *testing.T) {
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "application/json", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, usage)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, http.StatusOK, err.StatusCode)
|
||||
require.Equal(t, "content moderation failed", err.ToOpenAIError().Message)
|
||||
require.Empty(t, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
// TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent verifies that an error
|
||||
// event inside the SSE stream is recorded as a soft error while the payload is
|
||||
// still forwarded to the client.
|
||||
func TestOpenaiImageStreamHandlerRecordsUpstreamErrorEvent(t *testing.T) {
|
||||
oldMode := gin.Mode()
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Cleanup(func() { gin.SetMode(oldMode) })
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
body := strings.Join([]string{
|
||||
`event: image_generation.partial_image`,
|
||||
`data: {"type":"image_generation.partial_image","b64_json":"partial"}`,
|
||||
``,
|
||||
`event: error`,
|
||||
`data: {"type":"upstream_error","error":{"message":"stream error: stream ID 77; INTERNAL_ERROR; received from peer"}}`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
c, recorder, resp, info := newImageTestContext(t, body, "text/event-stream", true)
|
||||
|
||||
usage, err := OpenaiImageStreamHandler(c, info, resp)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.NotNil(t, info.StreamStatus)
|
||||
require.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
||||
require.True(t, info.StreamStatus.HasErrors())
|
||||
require.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||
require.Contains(t, info.StreamStatus.Errors[0].Message, "INTERNAL_ERROR")
|
||||
// The scanner strips the upstream "event: error" line; the event name is
|
||||
// rebuilt from the JSON "type" field (upstream_error). The error message
|
||||
// is still forwarded in the data: payload (stream ID 77).
|
||||
require.Contains(t, recorder.Body.String(), `event: upstream_error`)
|
||||
require.Contains(t, recorder.Body.String(), `stream ID 77`)
|
||||
}
|
||||
@@ -14,12 +14,9 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
@@ -293,421 +290,3 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
|
||||
return &simpleResponse.Usage, nil
|
||||
}
|
||||
|
||||
func streamTTSResponse(c *gin.Context, resp *http.Response) {
|
||||
c.Writer.WriteHeaderNow()
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.LogWarn(c, "streaming not supported")
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
logger.LogWarn(c, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := resp.Body.Read(buffer)
|
||||
//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
|
||||
if n > 0 {
|
||||
if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
|
||||
logger.LogError(c, writeErr.Error())
|
||||
break
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
info.IsStream = true
|
||||
clientConn := info.ClientWs
|
||||
targetConn := info.TargetWs
|
||||
|
||||
clientClosed := make(chan struct{})
|
||||
targetClosed := make(chan struct{})
|
||||
sendChan := make(chan []byte, 100)
|
||||
receiveChan := make(chan []byte, 100)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
usage := &dto.RealtimeUsage{}
|
||||
localUsage := &dto.RealtimeUsage{}
|
||||
sumUsage := &dto.RealtimeUsage{}
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||
}
|
||||
close(clientClosed)
|
||||
return
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||
if realtimeEvent.Session != nil {
|
||||
if realtimeEvent.Session.Tools != nil {
|
||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = helper.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sendChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||
}
|
||||
close(targetClosed)
|
||||
return
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||
realtimeUsage := realtimeEvent.Response.Usage
|
||||
if realtimeUsage != nil {
|
||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||
usage.InputTokens += realtimeUsage.InputTokens
|
||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
usage = &dto.RealtimeUsage{}
|
||||
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
if realtimeSession != nil {
|
||||
// update audio format
|
||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||
}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = helper.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case receiveChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
case <-clientClosed:
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
logger.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||
}
|
||||
|
||||
if localUsage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
}
|
||||
|
||||
// check usage total tokens, if 0, use local usage
|
||||
|
||||
return nil, sumUsage
|
||||
}
|
||||
|
||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||
if usage == nil || totalUsage == nil {
|
||||
return fmt.Errorf("invalid usage pointer")
|
||||
}
|
||||
|
||||
totalUsage.TotalTokens += usage.TotalTokens
|
||||
totalUsage.InputTokens += usage.InputTokens
|
||||
totalUsage.OutputTokens += usage.OutputTokens
|
||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||
// clear usage
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// Once we've written to the client, we should not return errors anymore
|
||||
// because the upstream has already consumed resources and returned content
|
||||
// We should still perform billing even if parsing fails
|
||||
// format
|
||||
if usageResp.InputTokens > 0 {
|
||||
usageResp.PromptTokens += usageResp.InputTokens
|
||||
}
|
||||
if usageResp.OutputTokens > 0 {
|
||||
usageResp.CompletionTokens += usageResp.OutputTokens
|
||||
}
|
||||
if usageResp.InputTokensDetails != nil {
|
||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||
}
|
||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||
if info == nil || usage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch info.ChannelType {
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeMoonshot:
|
||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Usage struct {
|
||||
PromptTokensDetails struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details"`
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
|
||||
return *payload.Usage.PromptTokensDetails.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.CachedTokens != nil {
|
||||
return *payload.Usage.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.PromptCacheHitTokens != nil {
|
||||
return *payload.Usage.PromptCacheHitTokens, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Usage struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 遍历choices查找cached_tokens
|
||||
for _, choice := range payload.Choices {
|
||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||
return *choice.Usage.CachedTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Timings struct {
|
||||
CachedTokens *int `json:"cache_n"`
|
||||
} `json:"timings"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Timings.CachedTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *payload.Timings.CachedTokens, true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenaiImageHandler handles non-streaming OpenAI image responses
|
||||
// (generations/edits), returning the parsed usage for billing.
|
||||
func OpenaiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
// normalizeOpenAIUsage maps the OpenAI Images usage shape (input_tokens /
|
||||
// output_tokens / input_tokens_details) onto the canonical prompt/completion
|
||||
// fields. It is used only on the OpenAI image relay paths (generations/edits,
|
||||
// streaming and non-streaming): the image API never returns prompt_tokens /
|
||||
// completion_tokens, so the overwrite (=) semantics here are equivalent to the
|
||||
// previous additive (+=) behavior while avoiding any future double-counting if
|
||||
// both field sets are ever populated. Do not reuse this on chat/embedding paths
|
||||
// without revisiting the overwrite semantics.
|
||||
func normalizeOpenAIUsage(usage *dto.Usage) {
|
||||
if usage == nil {
|
||||
return
|
||||
}
|
||||
if usage.InputTokens != 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
if usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
usage.PromptTokensDetails.CachedCreationTokens = usage.InputTokensDetails.CachedCreationTokens
|
||||
usage.PromptTokensDetails.ImageTokens = usage.InputTokensDetails.ImageTokens
|
||||
usage.PromptTokensDetails.TextTokens = usage.InputTokensDetails.TextTokens
|
||||
usage.PromptTokensDetails.AudioTokens = usage.InputTokensDetails.AudioTokens
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
func OpenaiImageStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
logger.LogError(c, "invalid image stream response")
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return OpenaiImageHandler(c, info, resp)
|
||||
}
|
||||
if !strings.Contains(contentType, "text/event-stream") {
|
||||
return OpenaiImageJSONAsStreamHandler(c, info, resp)
|
||||
}
|
||||
// Reuse the shared streaming engine (helper.StreamScannerHandler) so the
|
||||
// image streaming path gets the same ping keepalive, streaming-timeout
|
||||
// watchdog, client-disconnect detection, panic recovery and goroutine
|
||||
// cleanup as every other relay stream. The scanner delivers only the
|
||||
// "data:" payload, so the SSE "event:" line is rebuilt from the JSON "type"
|
||||
// field (real OpenAI image events keep event == type).
|
||||
usage := &dto.Usage{}
|
||||
var lastStreamData []byte
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
|
||||
raw := common.StringToByteSlice(data)
|
||||
lastStreamData = raw
|
||||
if isOpenAIImageStreamErrorEvent(raw) {
|
||||
// Record the error as a soft error; the scanner drives the final
|
||||
// EndReason. HasErrors() flags the failure for logging/handling.
|
||||
sr.Error(fmt.Errorf("%s", extractOpenAIImageStreamErrorMessage(raw)))
|
||||
}
|
||||
var usageResp dto.SimpleResponse
|
||||
if err := common.Unmarshal(raw, &usageResp); err == nil {
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
if service.ValidUsage(&usageResp.Usage) {
|
||||
usage = &usageResp.Usage
|
||||
}
|
||||
}
|
||||
writeOpenaiImageStreamChunk(c, raw)
|
||||
})
|
||||
|
||||
// StreamScannerHandler consumes the upstream [DONE]; re-emit it so the
|
||||
// client still receives a terminal data: [DONE].
|
||||
if info != nil && info.StreamStatus != nil && info.StreamStatus.EndReason == relaycommon.StreamEndReasonDone {
|
||||
helper.Done(c)
|
||||
}
|
||||
|
||||
applyUsagePostProcessing(info, usage, lastStreamData)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// writeOpenaiImageStreamChunk rebuilds the SSE frame for an image stream chunk:
|
||||
// it emits an "event:" line derived from the JSON "type" field (when present)
|
||||
// followed by the verbatim "data:" payload, mirroring helper.ResponseChunkData.
|
||||
func writeOpenaiImageStreamChunk(c *gin.Context, data []byte) {
|
||||
var payload struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
_ = common.Unmarshal(data, &payload)
|
||||
if eventName := strings.TrimSpace(payload.Type); eventName != "" {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", eventName)})
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(data)})
|
||||
_ = helper.FlushWriter(c)
|
||||
}
|
||||
|
||||
// isOpenAIImageStreamErrorEvent detects upstream error chunks by JSON content
|
||||
// only ("type" of error/upstream_error, or a non-empty "error" field). The SSE
|
||||
// "event:" line is not available here: StreamScannerHandler delivers only the
|
||||
// "data:" payload. A payload carrying just a "message" key is deliberately NOT
|
||||
// treated as an error to avoid false positives.
|
||||
func isOpenAIImageStreamErrorEvent(data []byte) bool {
|
||||
if !json.Valid(data) {
|
||||
return false
|
||||
}
|
||||
var payload struct {
|
||||
Type string `json:"type"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if err := common.Unmarshal(data, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
payloadType := strings.ToLower(strings.TrimSpace(payload.Type))
|
||||
return payloadType == "error" || payloadType == "upstream_error" || len(payload.Error) > 0
|
||||
}
|
||||
|
||||
func extractOpenAIImageStreamErrorMessage(data []byte) string {
|
||||
if len(data) == 0 || !json.Valid(data) {
|
||||
return "upstream image stream returned error event"
|
||||
}
|
||||
var payload struct {
|
||||
Message string `json:"message"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if err := common.Unmarshal(data, &payload); err != nil {
|
||||
return "upstream image stream returned error event"
|
||||
}
|
||||
if msg := strings.TrimSpace(payload.Message); msg != "" {
|
||||
return msg
|
||||
}
|
||||
if len(payload.Error) > 0 {
|
||||
var nested struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := common.Unmarshal(payload.Error, &nested); err == nil {
|
||||
if msg := strings.TrimSpace(nested.Message); msg != "" {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
if msg := strings.TrimSpace(common.JsonRawMessageToString(payload.Error)); msg != "" {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
return "upstream image stream returned error event"
|
||||
}
|
||||
|
||||
func OpenaiImageJSONAsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var imageResp dto.ImageResponse
|
||||
if err := common.Unmarshal(responseBody, &imageResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
_ = common.Unmarshal(responseBody, &usageResp)
|
||||
if oaiError := usageResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
normalizeOpenAIUsage(&usageResp.Usage)
|
||||
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
created := imageResp.Created
|
||||
if created == 0 {
|
||||
created = time.Now().Unix()
|
||||
}
|
||||
if info != nil {
|
||||
info.SetFirstResponseTime()
|
||||
}
|
||||
for _, image := range imageResp.Data {
|
||||
payload := map[string]any{
|
||||
"type": "image_generation.completed",
|
||||
"created_at": created,
|
||||
}
|
||||
if image.Url != "" {
|
||||
payload["url"] = image.Url
|
||||
}
|
||||
if image.B64Json != "" {
|
||||
payload["b64_json"] = image.B64Json
|
||||
}
|
||||
if image.RevisedPrompt != "" {
|
||||
payload["revised_prompt"] = image.RevisedPrompt
|
||||
}
|
||||
if service.ValidUsage(&usageResp.Usage) {
|
||||
payload["usage"] = usageResp.Usage
|
||||
}
|
||||
if err := writeOpenaiImageStreamPayload(c, "image_generation.completed", payload); err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||
}
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
}
|
||||
if err := writeOpenaiImageStreamDone(c); err != nil {
|
||||
if info != nil && info.StreamStatus != nil {
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, err)
|
||||
}
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
if info != nil {
|
||||
info.ReceivedResponseCount += len(imageResp.Data)
|
||||
if info.StreamStatus == nil {
|
||||
info.StreamStatus = relaycommon.NewStreamStatus()
|
||||
}
|
||||
info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
|
||||
}
|
||||
return &usageResp.Usage, nil
|
||||
}
|
||||
|
||||
func writeOpenaiImageStreamPayload(c *gin.Context, eventName string, payload any) error {
|
||||
data, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if eventName != "" {
|
||||
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
|
||||
return err
|
||||
}
|
||||
return helper.FlushWriter(c)
|
||||
}
|
||||
|
||||
func writeOpenaiImageStreamDone(c *gin.Context) error {
|
||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return helper.FlushWriter(c)
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
info.IsStream = true
|
||||
clientConn := info.ClientWs
|
||||
targetConn := info.TargetWs
|
||||
|
||||
clientClosed := make(chan struct{})
|
||||
targetClosed := make(chan struct{})
|
||||
sendChan := make(chan []byte, 100)
|
||||
receiveChan := make(chan []byte, 100)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
usage := &dto.RealtimeUsage{}
|
||||
localUsage := &dto.RealtimeUsage{}
|
||||
sumUsage := &dto.RealtimeUsage{}
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||
}
|
||||
close(clientClosed)
|
||||
return
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||
if realtimeEvent.Session != nil {
|
||||
if realtimeEvent.Session.Tools != nil {
|
||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = helper.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sendChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||
}
|
||||
close(targetClosed)
|
||||
return
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = common.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||
realtimeUsage := realtimeEvent.Response.Usage
|
||||
if realtimeUsage != nil {
|
||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||
usage.InputTokens += realtimeUsage.InputTokens
|
||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
usage = &dto.RealtimeUsage{}
|
||||
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
if realtimeSession != nil {
|
||||
// update audio format
|
||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||
}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = helper.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case receiveChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
case <-clientClosed:
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
logger.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||
}
|
||||
|
||||
if localUsage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
}
|
||||
|
||||
// check usage total tokens, if 0, use local usage
|
||||
|
||||
return nil, sumUsage
|
||||
}
|
||||
|
||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||
if usage == nil || totalUsage == nil {
|
||||
return fmt.Errorf("invalid usage pointer")
|
||||
}
|
||||
|
||||
totalUsage.TotalTokens += usage.TotalTokens
|
||||
totalUsage.InputTokens += usage.InputTokens
|
||||
totalUsage.OutputTokens += usage.OutputTokens
|
||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||
// clear usage
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
)
|
||||
|
||||
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||
if info == nil || usage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch info.ChannelType {
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeMoonshot:
|
||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Usage struct {
|
||||
PromptTokensDetails struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details"`
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Usage.PromptTokensDetails.CachedTokens != nil {
|
||||
return *payload.Usage.PromptTokensDetails.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.CachedTokens != nil {
|
||||
return *payload.Usage.CachedTokens, true
|
||||
}
|
||||
if payload.Usage.PromptCacheHitTokens != nil {
|
||||
return *payload.Usage.PromptCacheHitTokens, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Usage struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 遍历choices查找cached_tokens
|
||||
for _, choice := range payload.Choices {
|
||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||
return *choice.Usage.CachedTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
|
||||
func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Timings struct {
|
||||
CachedTokens *int `json:"cache_n"`
|
||||
} `json:"timings"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if payload.Timings.CachedTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *payload.Timings.CachedTokens, true
|
||||
}
|
||||
@@ -114,7 +114,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||
usage, err = openai.OpenaiImageHandler(c, info, resp)
|
||||
case constant.RelayModeResponses:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGetAndValidOpenAIImageRequestMultipartStream verifies multipart image
|
||||
// edit parsing: the stream field is parsed and validated, and the request body
|
||||
// stays replayable for the upstream request.
|
||||
func TestGetAndValidOpenAIImageRequestMultipartStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
newContext := func(t *testing.T, streamValue string, withImage bool) (*gin.Context, string) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("model", "gpt-image-1"))
|
||||
require.NoError(t, writer.WriteField("prompt", "edit this image"))
|
||||
require.NoError(t, writer.WriteField("stream", streamValue))
|
||||
if withImage {
|
||||
part, err := writer.CreateFormFile("image", "input.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write([]byte("fake image"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.NoError(t, writer.Close())
|
||||
originalBody := body.String()
|
||||
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/edits", &body)
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return c, originalBody
|
||||
}
|
||||
|
||||
t.Run("valid stream value keeps body replayable", func(t *testing.T) {
|
||||
c, originalBody := newContext(t, "true", true)
|
||||
|
||||
req, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.Stream)
|
||||
require.True(t, *req.Stream)
|
||||
require.True(t, req.IsStream(c))
|
||||
|
||||
bodyAfterValidation, err := io.ReadAll(c.Request.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalBody, string(bodyAfterValidation))
|
||||
|
||||
form, err := common.ParseMultipartFormReusable(c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "true", url.Values(form.Value).Get("stream"))
|
||||
require.Len(t, form.File["image"], 1)
|
||||
})
|
||||
|
||||
t.Run("invalid stream value is rejected", func(t *testing.T) {
|
||||
c, _ := newContext(t, "notabool", false)
|
||||
|
||||
_, err := GetAndValidOpenAIImageRequest(c, relayconstant.RelayModeImagesEdits)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid stream value")
|
||||
})
|
||||
}
|
||||
@@ -22,8 +22,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||
DefaultMaxScannerBufferSize = 64 << 20 // 64MB (64*1024*1024) default SSE buffer size
|
||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||
DefaultMaxScannerBufferSize = 128 << 20 // 64MB (64*1024*1024) default SSE buffer size
|
||||
DefaultPingInterval = 10 * time.Second
|
||||
)
|
||||
|
||||
|
||||
@@ -631,7 +631,7 @@ func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
|
||||
assert.NotNil(t, info.StreamStatus)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
||||
func TestStreamScannerHandler_StreamStatus_ReplacesPreInitialized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(5)
|
||||
@@ -643,7 +643,7 @@ func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
||||
StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
||||
|
||||
assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
||||
assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
||||
assert.Equal(t, 0, info.StreamStatus.TotalErrorCount())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -144,16 +146,25 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
_, err := c.MultipartForm()
|
||||
form, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
formData := url.Values(form.Value)
|
||||
c.Request.MultipartForm = form
|
||||
c.Request.PostForm = formData
|
||||
imageRequest.Prompt = formData.Get("prompt")
|
||||
imageRequest.Model = formData.Get("model")
|
||||
imageRequest.N = common.GetPointer(uint(common.String2Int(formData.Get("n"))))
|
||||
imageRequest.Quality = formData.Get("quality")
|
||||
imageRequest.Size = formData.Get("size")
|
||||
if streamValue := strings.TrimSpace(formData.Get("stream")); streamValue != "" {
|
||||
stream, err := strconv.ParseBool(streamValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid stream value: %w", err)
|
||||
}
|
||||
imageRequest.Stream = common.GetPointer(stream)
|
||||
}
|
||||
if imageValue := formData.Get("image"); imageValue != "" {
|
||||
imageRequest.Image, _ = common.Marshal(imageValue)
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
|
||||
apiRouter.GET("/home_stats", controller.GetHomeStats)
|
||||
apiRouter.GET("/pricing", middleware.HeaderNavModuleAuth("pricing"), controller.GetPricing)
|
||||
perfMetricsRoute := apiRouter.Group("/perf-metrics")
|
||||
perfMetricsRoute.Use(middleware.HeaderNavModulePublicOrUserAuth("pricing"))
|
||||
@@ -50,6 +51,11 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.POST("/oauth/wechat/bind", middleware.CriticalRateLimit(), anonymousRequestBodyLimit, controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
apiRouter.GET("/hhhl/authorize", middleware.CriticalRateLimit(), controller.HHHLAuthorize)
|
||||
apiRouter.GET("/hhhl/callback", middleware.CriticalRateLimit(), controller.HHHLCallback)
|
||||
apiRouter.POST("/hhhl/token", controller.HHHLToken)
|
||||
apiRouter.GET("/hhhl/token", controller.HHHLToken)
|
||||
apiRouter.GET("/hhhl/userinfo", controller.HHHLUserInfo)
|
||||
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
|
||||
apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth)
|
||||
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||
|
||||
@@ -18,7 +18,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.RouteTag("relay"))
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
modelsRouter.Use(middleware.TokenOrUserAuth())
|
||||
{
|
||||
modelsRouter.GET("", func(c *gin.Context) {
|
||||
switch {
|
||||
@@ -69,7 +69,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.RouteTag("relay"))
|
||||
relayV1Router.Use(middleware.SystemPerformanceCheck())
|
||||
relayV1Router.Use(middleware.TokenAuth())
|
||||
relayV1Router.Use(middleware.TokenOrUserAuth())
|
||||
relayV1Router.Use(middleware.ModelRequestRateLimit())
|
||||
{
|
||||
// WebSocket 路由(统一到 Relay)
|
||||
|
||||
+22
-4
@@ -13,12 +13,18 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ThemeAssets holds the embedded frontend assets for both themes.
|
||||
// ThemeAssets holds the embedded frontend assets for both themes and
|
||||
// the image-gen sub-app.
|
||||
type ThemeAssets struct {
|
||||
DefaultBuildFS embed.FS
|
||||
DefaultIndexPage []byte
|
||||
ClassicBuildFS embed.FS
|
||||
ClassicIndexPage []byte
|
||||
// ImageGen is the image-generation sub-app, served at /image-gen/.
|
||||
// It shares the same origin as the rest of new-api so /api/* and /v1/*
|
||||
// are reachable via the new-api session cookie (no CORS, no sk-key).
|
||||
ImageGenBuildFS embed.FS
|
||||
ImageGenIndexPage []byte
|
||||
}
|
||||
|
||||
func SetWebRouter(router *gin.Engine, assets ThemeAssets) {
|
||||
@@ -26,20 +32,32 @@ func SetWebRouter(router *gin.Engine, assets ThemeAssets) {
|
||||
classicFS := common.EmbedFolder(assets.ClassicBuildFS, "web/classic/dist")
|
||||
themeFS := common.NewThemeAwareFS(defaultFS, classicFS)
|
||||
|
||||
// image-gen sub-app: serve static files under /image-gen, fall back to
|
||||
// its index.html for unknown sub-paths (SPA).
|
||||
imageGenFS := common.EmbedFolder(assets.ImageGenBuildFS, "web/image-gen/dist")
|
||||
|
||||
router.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
router.Use(middleware.GlobalWebRateLimit())
|
||||
router.Use(middleware.Cache())
|
||||
router.Use(static.Serve("/image-gen", imageGenFS))
|
||||
router.Use(static.Serve("/", themeFS))
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
c.Set(middleware.RouteTagKey, "web")
|
||||
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") {
|
||||
uri := c.Request.RequestURI
|
||||
// API/relay/static paths are handled by their own routers — 404 cleanly.
|
||||
if strings.HasPrefix(uri, "/v1") || strings.HasPrefix(uri, "/api") || strings.HasPrefix(uri, "/assets") {
|
||||
controller.RelayNotFound(c)
|
||||
return
|
||||
}
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
if common.GetTheme() == "classic" {
|
||||
switch {
|
||||
case strings.HasPrefix(uri, "/image-gen"):
|
||||
// SPA fallback for the image-gen sub-app: any sub-path that didn't
|
||||
// hit a static file gets the image-gen index.html.
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", assets.ImageGenIndexPage)
|
||||
case common.GetTheme() == "classic":
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", assets.ClassicIndexPage)
|
||||
} else {
|
||||
default:
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", assets.DefaultIndexPage)
|
||||
}
|
||||
})
|
||||
|
||||
+31
-22
@@ -34,13 +34,8 @@ func checkRedirect(req *http.Request, via []*http.Request) error {
|
||||
}
|
||||
|
||||
func InitHttpClient() {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
|
||||
}
|
||||
transport := newRelayTransport()
|
||||
transport.Proxy = http.ProxyFromEnvironment // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
|
||||
if common.TLSInsecureSkipVerify {
|
||||
transport.TLSClientConfig = common.InsecureTLSConfig
|
||||
}
|
||||
@@ -59,6 +54,30 @@ func InitHttpClient() {
|
||||
}
|
||||
}
|
||||
|
||||
func newRelayTransport() *http.Transport {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: !common.RelayDisableHTTP2,
|
||||
TLSHandshakeTimeout: time.Duration(common.RelayTLSHandshakeTimeout) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(common.RelayExpectContinueTimeout) * time.Second,
|
||||
}
|
||||
if common.RelayResponseHeaderTimeout > 0 {
|
||||
transport.ResponseHeaderTimeout = time.Duration(common.RelayResponseHeaderTimeout) * time.Second
|
||||
}
|
||||
if common.RelayForceIPv4 {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp4", addr)
|
||||
}
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func GetHttpClient() *http.Client {
|
||||
return httpClient
|
||||
}
|
||||
@@ -106,13 +125,8 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
}
|
||||
transport := newRelayTransport()
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
if common.TLSInsecureSkipVerify {
|
||||
transport.TLSClientConfig = common.InsecureTLSConfig
|
||||
}
|
||||
@@ -146,14 +160,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: time.Duration(common.RelayIdleConnTimeout) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
transport := newRelayTransport()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
if common.TLSInsecureSkipVerify {
|
||||
transport.TLSClientConfig = common.InsecureTLSConfig
|
||||
|
||||
Vendored
+2
-2
@@ -6,8 +6,8 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<!-- Primary Meta Tags -->
|
||||
<title>New API</title>
|
||||
<meta name="title" content="New API" />
|
||||
<title>BBLBB</title>
|
||||
<meta name="title" content="BBLBB" />
|
||||
<meta
|
||||
name="description"
|
||||
content="Unified AI API gateway and admin dashboard."
|
||||
|
||||
Vendored
+1
@@ -65,6 +65,7 @@ export default defineConfig(({ envMode }) => {
|
||||
},
|
||||
server: {
|
||||
host: '0.0.0.0',
|
||||
port: 5173,
|
||||
strictPort: true,
|
||||
proxy: devProxy,
|
||||
},
|
||||
|
||||
+1
-2
@@ -27,7 +27,6 @@ import {
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react'
|
||||
import type { Element } from 'hast'
|
||||
import { CheckIcon, CopyIcon } from 'lucide-react'
|
||||
import {
|
||||
type BundledLanguage,
|
||||
@@ -53,7 +52,7 @@ const CodeBlockContext = createContext<CodeBlockContextType>({
|
||||
|
||||
const lineNumberTransformer: ShikiTransformer = {
|
||||
name: 'line-numbers',
|
||||
line(node: Element, line: number) {
|
||||
line(node, line) {
|
||||
node.children.unshift({
|
||||
type: 'element',
|
||||
tagName: 'span',
|
||||
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
# Data Table Components
|
||||
|
||||
This package keeps a stable public API through `index.ts`; feature code should
|
||||
continue importing from `@/components/data-table`.
|
||||
|
||||
- `core/`: TanStack table rendering primitives, headers, rows, pagination,
|
||||
loading, empty states, and pinned-column behavior.
|
||||
- `layout/`: responsive page-level composition that combines toolbar, desktop
|
||||
table, mobile list, bulk actions, and pagination placement.
|
||||
- `toolbar/`: filter/search/view-option controls and selection action toolbar.
|
||||
- `static/`: lightweight table rendering for local/static arrays that do not
|
||||
need TanStack state.
|
||||
- `hooks/`: table state and filter hooks.
|
||||
|
||||
Keep feature-specific columns, actions, and dialogs inside their feature
|
||||
folders. Shared table code belongs here only when it is reusable across more
|
||||
than one feature.
|
||||
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import { cn } from '@/lib/utils'
|
||||
import type { DataTableColumnClassName, DataTablePinnedColumn } from './types'
|
||||
|
||||
export function getResolvedColumnClassName(
|
||||
getColumnClassName?: DataTableColumnClassName,
|
||||
pinnedColumns?: DataTablePinnedColumn[]
|
||||
): DataTableColumnClassName {
|
||||
return getResolvedColumnClassNameFromMap(
|
||||
getColumnClassName,
|
||||
getPinnedColumnMap(pinnedColumns)
|
||||
)
|
||||
}
|
||||
|
||||
export function getResolvedColumnClassNameFromMap(
|
||||
getColumnClassName?: DataTableColumnClassName,
|
||||
pinnedColumnById?: Map<string, DataTablePinnedColumn>
|
||||
): DataTableColumnClassName {
|
||||
return (columnId, kind) => {
|
||||
const customClassName = getColumnClassName?.(columnId, kind)
|
||||
const pinnedColumn = pinnedColumnById?.get(columnId)
|
||||
|
||||
if (!pinnedColumn) return customClassName
|
||||
|
||||
return cn(customClassName, getPinnedColumnClassName(pinnedColumn, kind))
|
||||
}
|
||||
}
|
||||
|
||||
export function getPinnedColumnMap(pinnedColumns?: DataTablePinnedColumn[]) {
|
||||
if (!pinnedColumns?.length) return undefined
|
||||
|
||||
return new Map(pinnedColumns.map((column) => [column.columnId, column]))
|
||||
}
|
||||
|
||||
function getPinnedColumnClassName(
|
||||
pinnedColumn: DataTablePinnedColumn,
|
||||
kind: 'header' | 'cell'
|
||||
) {
|
||||
const edgeClassName =
|
||||
pinnedColumn.side === 'left'
|
||||
? 'shadow-[8px_0_10px_-10px_hsl(var(--foreground))]'
|
||||
: 'shadow-[-8px_0_10px_-10px_hsl(var(--foreground))]'
|
||||
|
||||
return cn(
|
||||
'sticky whitespace-nowrap',
|
||||
pinnedColumn.side === 'left' ? 'left-0' : 'right-0',
|
||||
edgeClassName,
|
||||
kind === 'header'
|
||||
? 'bg-background z-30'
|
||||
: 'bg-background z-10 group-hover:bg-muted group-data-[state=selected]:bg-muted',
|
||||
pinnedColumn.className,
|
||||
kind === 'header'
|
||||
? pinnedColumn.headerClassName
|
||||
: pinnedColumn.cellClassName
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import type { Table as TanstackTable } from '@tanstack/react-table'
|
||||
|
||||
export function DataTableColgroup<TData>({
|
||||
table,
|
||||
}: {
|
||||
table: TanstackTable<TData>
|
||||
}) {
|
||||
return (
|
||||
<colgroup>
|
||||
{table.getVisibleLeafColumns().map((column) => (
|
||||
<col key={column.id} style={{ width: column.getSize() }} />
|
||||
))}
|
||||
</colgroup>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import { flexRender, type Table as TanstackTable } from '@tanstack/react-table'
|
||||
import { TableHead, TableHeader, TableRow } from '@/components/ui/table'
|
||||
import type { DataTableColumnClassName } from './types'
|
||||
|
||||
type DataTableHeaderProps<TData> = {
|
||||
table: TanstackTable<TData>
|
||||
applyHeaderSize?: boolean
|
||||
className?: string
|
||||
rowClassName?: string
|
||||
getColumnClassName?: DataTableColumnClassName
|
||||
}
|
||||
|
||||
export function DataTableHeader<TData>({
|
||||
table,
|
||||
applyHeaderSize,
|
||||
className,
|
||||
rowClassName,
|
||||
getColumnClassName,
|
||||
}: DataTableHeaderProps<TData>) {
|
||||
return (
|
||||
<TableHeader className={className}>
|
||||
{table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id} className={rowClassName}>
|
||||
{headerGroup.headers.map((header) => (
|
||||
<TableHead
|
||||
key={header.id}
|
||||
colSpan={header.colSpan}
|
||||
className={getColumnClassName?.(header.column.id, 'header')}
|
||||
style={applyHeaderSize ? { width: header.getSize() } : undefined}
|
||||
>
|
||||
{header.isPlaceholder
|
||||
? null
|
||||
: flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext()
|
||||
)}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHeader>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import type * as React from 'react'
|
||||
import { flexRender, type Row } from '@tanstack/react-table'
|
||||
import { TableCell, TableRow } from '@/components/ui/table'
|
||||
import type { DataTableColumnClassName } from './types'
|
||||
|
||||
type DataTableRowProps<TData> = {
|
||||
row: Row<TData>
|
||||
className?: string
|
||||
getColumnClassName?: DataTableColumnClassName
|
||||
} & Omit<React.ComponentProps<typeof TableRow>, 'children'>
|
||||
|
||||
export function DataTableRow<TData>({
|
||||
row,
|
||||
className,
|
||||
getColumnClassName,
|
||||
...rowProps
|
||||
}: DataTableRowProps<TData>) {
|
||||
return (
|
||||
<TableRow
|
||||
data-state={row.getIsSelected() ? 'selected' : undefined}
|
||||
className={className}
|
||||
{...rowProps}
|
||||
>
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<TableCell
|
||||
key={cell.id}
|
||||
className={getColumnClassName?.(cell.column.id, 'cell')}
|
||||
>
|
||||
{flexRender(cell.column.columnDef.cell, cell.getContext())}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,310 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import { type Row } from '@tanstack/react-table'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { Table, TableBody, TableCell, TableRow } from '@/components/ui/table'
|
||||
import {
|
||||
getPinnedColumnMap,
|
||||
getResolvedColumnClassNameFromMap,
|
||||
} from './column-pinning'
|
||||
import { DataTableColgroup } from './data-table-colgroup'
|
||||
import { DataTableHeader } from './data-table-header'
|
||||
import { DataTableRow } from './data-table-row'
|
||||
import { TableEmpty } from './table-empty'
|
||||
import { getTableSizeStyle } from './table-sizing'
|
||||
import { TableSkeleton } from './table-skeleton'
|
||||
import type {
|
||||
DataTableColumnClassName,
|
||||
DataTablePinnedColumn,
|
||||
DataTableViewProps,
|
||||
} from './types'
|
||||
|
||||
export type {
|
||||
DataTableColumnClassName,
|
||||
DataTablePinnedColumn,
|
||||
DataTableRenderRowHelpers,
|
||||
DataTableViewProps,
|
||||
} from './types'
|
||||
export { DataTableRow } from './data-table-row'
|
||||
|
||||
export function DataTableView<TData>(props: DataTableViewProps<TData>) {
|
||||
const rows = props.rows ?? props.table.getRowModel().rows
|
||||
const colSpan = props.table.getVisibleLeafColumns().length
|
||||
const columnClassName = useResolvedColumnClassName(
|
||||
props.getColumnClassName,
|
||||
props.pinnedColumns
|
||||
)
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'overflow-hidden rounded-lg border',
|
||||
props.containerClassName
|
||||
)}
|
||||
{...props.containerProps}
|
||||
>
|
||||
{props.splitHeader ? (
|
||||
<SplitHeaderTableView
|
||||
props={props}
|
||||
rows={rows}
|
||||
colSpan={colSpan}
|
||||
getColumnClassName={columnClassName}
|
||||
/>
|
||||
) : (
|
||||
<UnifiedTableView
|
||||
props={props}
|
||||
rows={rows}
|
||||
colSpan={colSpan}
|
||||
getColumnClassName={columnClassName}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function UnifiedTableView<TData>({
|
||||
props,
|
||||
rows,
|
||||
colSpan,
|
||||
getColumnClassName,
|
||||
}: {
|
||||
props: DataTableViewProps<TData>
|
||||
rows: Row<TData>[]
|
||||
colSpan: number
|
||||
getColumnClassName: DataTableColumnClassName
|
||||
}) {
|
||||
const tableSizing = getTableSizing(props)
|
||||
|
||||
return (
|
||||
<div className={props.tableContainerClassName}>
|
||||
<Table className={props.tableClassName} style={tableSizing.style}>
|
||||
{tableSizing.colgroup}
|
||||
<DataTableHeader
|
||||
table={props.table}
|
||||
applyHeaderSize={props.applyHeaderSize}
|
||||
className={props.tableHeaderClassName}
|
||||
rowClassName={props.tableHeaderRowClassName}
|
||||
getColumnClassName={getColumnClassName}
|
||||
/>
|
||||
{renderTableBody(props, rows, colSpan, getColumnClassName)}
|
||||
</Table>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function SplitHeaderTableView<TData>({
|
||||
props,
|
||||
rows,
|
||||
colSpan,
|
||||
getColumnClassName,
|
||||
}: {
|
||||
props: DataTableViewProps<TData>
|
||||
rows: Row<TData>[]
|
||||
colSpan: number
|
||||
getColumnClassName: DataTableColumnClassName
|
||||
}) {
|
||||
const headerHostRef = React.useRef<HTMLDivElement>(null)
|
||||
const bodyHostRef = React.useRef<HTMLDivElement>(null)
|
||||
const tableSizing = getTableSizing(props)
|
||||
|
||||
React.useEffect(() => {
|
||||
const headerScroller = headerHostRef.current?.querySelector<HTMLElement>(
|
||||
'[data-slot=table-container]'
|
||||
)
|
||||
const bodyScroller = bodyHostRef.current?.querySelector<HTMLElement>(
|
||||
'[data-slot=table-container]'
|
||||
)
|
||||
|
||||
if (!headerScroller || !bodyScroller) return
|
||||
|
||||
const syncHeaderScroll = () => {
|
||||
headerScroller.scrollLeft = bodyScroller.scrollLeft
|
||||
}
|
||||
|
||||
syncHeaderScroll()
|
||||
bodyScroller.addEventListener('scroll', syncHeaderScroll, { passive: true })
|
||||
|
||||
return () => {
|
||||
bodyScroller.removeEventListener('scroll', syncHeaderScroll)
|
||||
}
|
||||
}, [rows.length, props.tableClassName, props.colgroup])
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex h-full min-h-0 flex-col',
|
||||
props.tableContainerClassName
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'flex min-h-0 flex-1 flex-col overflow-hidden',
|
||||
props.splitHeaderScrollClassName
|
||||
)}
|
||||
>
|
||||
<div
|
||||
ref={headerHostRef}
|
||||
className='[scrollbar-gutter:stable] overflow-hidden [&_[data-slot=table-container]]:overflow-x-hidden'
|
||||
>
|
||||
<Table className={props.tableClassName} style={tableSizing.style}>
|
||||
{tableSizing.colgroup}
|
||||
<DataTableHeader
|
||||
table={props.table}
|
||||
applyHeaderSize={props.applyHeaderSize}
|
||||
className={props.tableHeaderClassName}
|
||||
rowClassName={props.tableHeaderRowClassName}
|
||||
getColumnClassName={getColumnClassName}
|
||||
/>
|
||||
</Table>
|
||||
</div>
|
||||
<div
|
||||
ref={bodyHostRef}
|
||||
className={cn(
|
||||
'min-h-0 flex-1 [scrollbar-gutter:stable] overflow-y-auto',
|
||||
props.bodyContainerClassName
|
||||
)}
|
||||
>
|
||||
<Table className={props.tableClassName} style={tableSizing.style}>
|
||||
{tableSizing.colgroup}
|
||||
{renderTableBody(props, rows, colSpan, getColumnClassName)}
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function useResolvedColumnClassName(
|
||||
getColumnClassName?: DataTableColumnClassName,
|
||||
pinnedColumns?: DataTablePinnedColumn[]
|
||||
) {
|
||||
const pinnedColumnById = React.useMemo(
|
||||
() => getPinnedColumnMap(pinnedColumns),
|
||||
[pinnedColumns]
|
||||
)
|
||||
|
||||
return React.useMemo(
|
||||
() =>
|
||||
getResolvedColumnClassNameFromMap(getColumnClassName, pinnedColumnById),
|
||||
[getColumnClassName, pinnedColumnById]
|
||||
)
|
||||
}
|
||||
|
||||
function getTableSizing<TData>(props: DataTableViewProps<TData>): {
|
||||
colgroup?: React.ReactNode
|
||||
style?: React.CSSProperties
|
||||
} {
|
||||
if (props.colgroup) {
|
||||
return { colgroup: props.colgroup }
|
||||
}
|
||||
|
||||
if (!props.splitHeader && !props.applyHeaderSize) {
|
||||
return {}
|
||||
}
|
||||
|
||||
return {
|
||||
colgroup: <DataTableColgroup table={props.table} />,
|
||||
style: getTableSizeStyle(props.table),
|
||||
}
|
||||
}
|
||||
|
||||
function renderTableBody<TData>(
|
||||
props: DataTableViewProps<TData>,
|
||||
rows: Row<TData>[],
|
||||
colSpan: number,
|
||||
getColumnClassName: DataTableColumnClassName
|
||||
) {
|
||||
return (
|
||||
<TableBody className={props.tableBodyClassName}>
|
||||
{renderTableBodyContent(props, rows, colSpan, getColumnClassName)}
|
||||
</TableBody>
|
||||
)
|
||||
}
|
||||
|
||||
function renderTableBodyContent<TData>(
|
||||
props: DataTableViewProps<TData>,
|
||||
rows: Row<TData>[],
|
||||
colSpan: number,
|
||||
getColumnClassName: DataTableColumnClassName
|
||||
) {
|
||||
if (props.isLoading) {
|
||||
return (
|
||||
<TableSkeleton
|
||||
table={props.table}
|
||||
keyPrefix={props.skeletonKeyPrefix}
|
||||
rowHeight={props.skeletonRowHeight}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (rows.length === 0) {
|
||||
return renderEmptyState(props, colSpan)
|
||||
}
|
||||
|
||||
return rows.map((row) =>
|
||||
props.renderRow
|
||||
? props.renderRow(row, {
|
||||
getCellClassName: (columnId, className) =>
|
||||
cn(getColumnClassName(columnId, 'cell'), className),
|
||||
})
|
||||
: renderDefaultRow(props, row, getColumnClassName)
|
||||
)
|
||||
}
|
||||
|
||||
function renderEmptyState<TData>(
|
||||
props: DataTableViewProps<TData>,
|
||||
colSpan: number
|
||||
) {
|
||||
if (props.emptyContent) {
|
||||
return (
|
||||
<TableRow>
|
||||
<TableCell colSpan={colSpan} className={props.emptyCellClassName}>
|
||||
{props.emptyContent}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<TableEmpty
|
||||
colSpan={colSpan}
|
||||
title={props.emptyTitle}
|
||||
description={props.emptyDescription}
|
||||
icon={props.emptyIcon}
|
||||
>
|
||||
{props.emptyAction}
|
||||
</TableEmpty>
|
||||
)
|
||||
}
|
||||
|
||||
function renderDefaultRow<TData>(
|
||||
props: DataTableViewProps<TData>,
|
||||
row: Row<TData>,
|
||||
getColumnClassName: DataTableColumnClassName
|
||||
) {
|
||||
return (
|
||||
<DataTableRow
|
||||
key={row.id}
|
||||
row={row}
|
||||
className={cn(props.tableBodyRowClassName, props.getRowClassName?.(row))}
|
||||
getColumnClassName={getColumnClassName}
|
||||
/>
|
||||
)
|
||||
}
|
||||
+44
-40
@@ -39,48 +39,55 @@ type DataTablePaginationProps<TData> = {
|
||||
table: Table<TData>
|
||||
}
|
||||
|
||||
const PAGE_SIZE_OPTIONS = [10, 20, 30, 40, 50, 100] as const
|
||||
const PAGE_SIZE_SELECT_ITEMS = PAGE_SIZE_OPTIONS.map((pageSize) => ({
|
||||
value: `${pageSize}`,
|
||||
label: pageSize,
|
||||
}))
|
||||
|
||||
export function DataTablePagination<TData>({
|
||||
table,
|
||||
}: DataTablePaginationProps<TData>) {
|
||||
const { t } = useTranslation()
|
||||
const currentPage = table.getState().pagination.pageIndex + 1
|
||||
const pagination = table.getState().pagination
|
||||
const currentPage = pagination.pageIndex + 1
|
||||
const pageSize = pagination.pageSize
|
||||
const totalPages = table.getPageCount()
|
||||
const totalRows = table.getRowCount()
|
||||
const pageNumbers = getPageNumbers(currentPage, totalPages)
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center justify-between overflow-clip',
|
||||
'@max-2xl/content:flex-col-reverse @max-2xl/content:gap-2 sm:@max-2xl/content:gap-4'
|
||||
'@container/pagination flex min-w-0 items-center justify-end overflow-clip'
|
||||
)}
|
||||
style={{ overflowClipMargin: 1 }}
|
||||
>
|
||||
<div className='flex w-full items-center justify-between gap-2'>
|
||||
<div className='flex min-w-0 items-center text-xs font-medium whitespace-nowrap sm:min-w-[130px] sm:text-sm @2xl/content:hidden'>
|
||||
{t('Page {{current}} of {{total}}', {
|
||||
current: currentPage,
|
||||
total: totalPages,
|
||||
})}
|
||||
<div className='flex min-w-0 shrink-0 items-center gap-2 @xl/pagination:gap-3'>
|
||||
<div className='flex shrink-0 items-baseline gap-1.5 text-xs font-medium whitespace-nowrap sm:text-sm'>
|
||||
<span className='text-muted-foreground/80'>{t('Total:')}</span>
|
||||
<span className='text-foreground tabular-nums'>
|
||||
{totalRows.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
<div className='flex items-center gap-2 @max-2xl/content:flex-row-reverse'>
|
||||
|
||||
<div className='flex shrink-0 items-center gap-1.5 @lg/pagination:gap-2'>
|
||||
<p className='text-muted-foreground/80 hidden text-sm font-medium whitespace-nowrap @2xl/pagination:block'>
|
||||
{t('Rows per page')}
|
||||
</p>
|
||||
<Select
|
||||
items={[
|
||||
...[10, 20, 30, 40, 50, 100].map((pageSize) => ({
|
||||
value: `${pageSize}`,
|
||||
label: pageSize,
|
||||
})),
|
||||
]}
|
||||
value={`${table.getState().pagination.pageSize}`}
|
||||
items={PAGE_SIZE_SELECT_ITEMS}
|
||||
value={`${pageSize}`}
|
||||
onValueChange={(value) => {
|
||||
table.setPageSize(Number(value))
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className='h-8 w-[64px] sm:w-[70px]'>
|
||||
<SelectValue placeholder={table.getState().pagination.pageSize} />
|
||||
<SelectTrigger className='text-foreground h-8 w-[64px] font-medium tabular-nums sm:w-[70px]'>
|
||||
<SelectValue placeholder={pageSize} />
|
||||
</SelectTrigger>
|
||||
<SelectContent side='top' alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
{[10, 20, 30, 40, 50, 100].map((pageSize) => (
|
||||
{PAGE_SIZE_OPTIONS.map((pageSize) => (
|
||||
<SelectItem key={pageSize} value={`${pageSize}`}>
|
||||
{pageSize}
|
||||
</SelectItem>
|
||||
@@ -88,23 +95,12 @@ export function DataTablePagination<TData>({
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p className='hidden text-sm font-medium sm:block'>
|
||||
{t('Rows per page')}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='flex items-center sm:space-x-6 lg:space-x-8'>
|
||||
<div className='flex min-w-[130px] items-center text-sm font-medium whitespace-nowrap @max-3xl/content:hidden'>
|
||||
{t('Page {{current}} of {{total}}', {
|
||||
current: currentPage,
|
||||
total: totalPages,
|
||||
})}
|
||||
</div>
|
||||
<div className='flex items-center space-x-1.5 sm:space-x-2'>
|
||||
<div className='flex min-w-0 shrink-0 items-center gap-1 @lg/pagination:gap-1.5 @xl/pagination:gap-2'>
|
||||
<Button
|
||||
variant='outline'
|
||||
className='size-8 p-0 @max-md/content:hidden'
|
||||
className='text-muted-foreground hover:text-foreground disabled:text-muted-foreground/50 size-8 p-0 @max-lg/pagination:hidden'
|
||||
onClick={() => table.setPageIndex(0)}
|
||||
disabled={!table.getCanPreviousPage()}
|
||||
>
|
||||
@@ -113,7 +109,7 @@ export function DataTablePagination<TData>({
|
||||
</Button>
|
||||
<Button
|
||||
variant='outline'
|
||||
className='size-8 p-0'
|
||||
className='text-muted-foreground hover:text-foreground disabled:text-muted-foreground/50 size-8 p-0'
|
||||
onClick={() => table.previousPage()}
|
||||
disabled={!table.getCanPreviousPage()}
|
||||
>
|
||||
@@ -121,18 +117,26 @@ export function DataTablePagination<TData>({
|
||||
<ChevronLeftIcon className='h-4 w-4' />
|
||||
</Button>
|
||||
|
||||
{/* Page number buttons */}
|
||||
{pageNumbers.map((pageNumber, index) => (
|
||||
<div key={`${pageNumber}-${index}`} className='flex items-center'>
|
||||
{pageNumber === '...' ? (
|
||||
<span className='text-muted-foreground px-1 text-sm'>...</span>
|
||||
<span className='text-muted-foreground/60 px-0.5 text-sm @lg/pagination:px-1'>
|
||||
...
|
||||
</span>
|
||||
) : (
|
||||
<Button
|
||||
variant={currentPage === pageNumber ? 'default' : 'outline'}
|
||||
className='h-8 min-w-8 px-2'
|
||||
className={cn(
|
||||
'h-8 min-w-8 px-2 tabular-nums',
|
||||
currentPage === pageNumber
|
||||
? 'font-semibold'
|
||||
: 'text-muted-foreground hover:text-foreground'
|
||||
)}
|
||||
onClick={() => table.setPageIndex((pageNumber as number) - 1)}
|
||||
>
|
||||
<span className='sr-only'>Go to page {pageNumber}</span>
|
||||
<span className='sr-only'>
|
||||
{t('Go to page {{page}}', { page: pageNumber })}
|
||||
</span>
|
||||
{pageNumber}
|
||||
</Button>
|
||||
)}
|
||||
@@ -141,7 +145,7 @@ export function DataTablePagination<TData>({
|
||||
|
||||
<Button
|
||||
variant='outline'
|
||||
className='size-8 p-0'
|
||||
className='text-muted-foreground hover:text-foreground disabled:text-muted-foreground/50 size-8 p-0'
|
||||
onClick={() => table.nextPage()}
|
||||
disabled={!table.getCanNextPage()}
|
||||
>
|
||||
@@ -150,7 +154,7 @@ export function DataTablePagination<TData>({
|
||||
</Button>
|
||||
<Button
|
||||
variant='outline'
|
||||
className='size-8 p-0 @max-md/content:hidden'
|
||||
className='text-muted-foreground hover:text-foreground disabled:text-muted-foreground/50 size-8 p-0 @max-lg/pagination:hidden'
|
||||
onClick={() => table.setPageIndex(table.getPageCount() - 1)}
|
||||
disabled={!table.getCanNextPage()}
|
||||
>
|
||||
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import type * as React from 'react'
|
||||
import type { Table as TanstackTable } from '@tanstack/react-table'
|
||||
|
||||
export function getTableSizeStyle<TData>(
|
||||
table: TanstackTable<TData>
|
||||
): React.CSSProperties {
|
||||
const width = table
|
||||
.getVisibleLeafColumns()
|
||||
.reduce((total, column) => total + column.getSize(), 0)
|
||||
|
||||
return { minWidth: width, tableLayout: 'fixed', width: '100%' }
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import type * as React from 'react'
|
||||
import type { Row, Table as TanstackTable } from '@tanstack/react-table'
|
||||
|
||||
export type DataTableColumnClassName = (
|
||||
columnId: string,
|
||||
kind: 'header' | 'cell'
|
||||
) => string | undefined
|
||||
|
||||
export type DataTablePinnedColumn = {
|
||||
columnId: string
|
||||
side: 'left' | 'right'
|
||||
className?: string
|
||||
headerClassName?: string
|
||||
cellClassName?: string
|
||||
}
|
||||
|
||||
export type DataTableRenderRowHelpers = {
|
||||
getCellClassName: (columnId: string, className?: string) => string | undefined
|
||||
}
|
||||
|
||||
export type DataTableViewProps<TData> = {
|
||||
table: TanstackTable<TData>
|
||||
isLoading?: boolean
|
||||
rows?: Row<TData>[]
|
||||
emptyTitle?: string
|
||||
emptyDescription?: string
|
||||
emptyIcon?: React.ReactNode
|
||||
emptyAction?: React.ReactNode
|
||||
emptyContent?: React.ReactNode
|
||||
emptyCellClassName?: string
|
||||
skeletonKeyPrefix?: string
|
||||
skeletonRowHeight?: string
|
||||
renderRow?: (
|
||||
row: Row<TData>,
|
||||
helpers: DataTableRenderRowHelpers
|
||||
) => React.ReactNode
|
||||
getRowClassName?: (row: Row<TData>) => string | undefined
|
||||
getColumnClassName?: DataTableColumnClassName
|
||||
pinnedColumns?: DataTablePinnedColumn[]
|
||||
applyHeaderSize?: boolean
|
||||
tableClassName?: string
|
||||
tableHeaderClassName?: string
|
||||
tableHeaderRowClassName?: string
|
||||
tableBodyClassName?: string
|
||||
tableBodyRowClassName?: string
|
||||
splitHeader?: boolean
|
||||
splitHeaderScrollClassName?: string
|
||||
bodyContainerClassName?: string
|
||||
containerClassName?: string
|
||||
containerProps?: Omit<React.ComponentProps<'div'>, 'className' | 'children'>
|
||||
tableContainerClassName?: string
|
||||
colgroup?: React.ReactNode
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import {
|
||||
type ColumnDef,
|
||||
type ColumnFiltersState,
|
||||
type ExpandedState,
|
||||
type OnChangeFn,
|
||||
type PaginationState,
|
||||
type RowSelectionState,
|
||||
type SortingState,
|
||||
type TableOptions,
|
||||
type Updater,
|
||||
type VisibilityState,
|
||||
getCoreRowModel,
|
||||
getExpandedRowModel,
|
||||
getFacetedRowModel,
|
||||
getFacetedUniqueValues,
|
||||
getFilteredRowModel,
|
||||
getPaginationRowModel,
|
||||
getSortedRowModel,
|
||||
useReactTable,
|
||||
} from '@tanstack/react-table'
|
||||
|
||||
type DataTableFeatureOptions<TData> = Pick<
|
||||
TableOptions<TData>,
|
||||
| 'enableRowSelection'
|
||||
| 'getRowId'
|
||||
| 'getSubRows'
|
||||
| 'globalFilterFn'
|
||||
| 'autoResetPageIndex'
|
||||
| 'manualFiltering'
|
||||
| 'manualPagination'
|
||||
| 'manualSorting'
|
||||
>
|
||||
|
||||
type DataTableStateOptions = {
|
||||
initialSorting?: SortingState
|
||||
sorting?: SortingState
|
||||
onSortingChange?: OnChangeFn<SortingState>
|
||||
initialColumnVisibility?: VisibilityState
|
||||
columnVisibility?: VisibilityState
|
||||
onColumnVisibilityChange?: OnChangeFn<VisibilityState>
|
||||
initialRowSelection?: RowSelectionState
|
||||
rowSelection?: RowSelectionState
|
||||
onRowSelectionChange?: OnChangeFn<RowSelectionState>
|
||||
initialExpanded?: ExpandedState
|
||||
expanded?: ExpandedState
|
||||
onExpandedChange?: OnChangeFn<ExpandedState>
|
||||
columnFilters?: ColumnFiltersState
|
||||
onColumnFiltersChange?: OnChangeFn<ColumnFiltersState>
|
||||
globalFilter?: string
|
||||
onGlobalFilterChange?: OnChangeFn<string>
|
||||
initialPagination?: PaginationState
|
||||
pagination?: PaginationState
|
||||
onPaginationChange?: OnChangeFn<PaginationState>
|
||||
}
|
||||
|
||||
type DataTableRowModelOptions = {
|
||||
withFilteredRowModel?: boolean
|
||||
withPaginationRowModel?: boolean
|
||||
withSortedRowModel?: boolean
|
||||
withFacetedRowModel?: boolean
|
||||
withExpandedRowModel?: boolean
|
||||
}
|
||||
|
||||
type UseDataTableOptions<TData> = DataTableFeatureOptions<TData> &
|
||||
DataTableStateOptions &
|
||||
DataTableRowModelOptions & {
|
||||
data: TData[]
|
||||
columns: ColumnDef<TData, unknown>[]
|
||||
totalCount?: number
|
||||
pageCount?: number
|
||||
ensurePageInRange?: (pageCount: number) => void
|
||||
}
|
||||
|
||||
function resolveUpdater<TValue>(
|
||||
updater: Updater<TValue>,
|
||||
previous: TValue
|
||||
): TValue {
|
||||
return typeof updater === 'function'
|
||||
? (updater as (old: TValue) => TValue)(previous)
|
||||
: updater
|
||||
}
|
||||
|
||||
function useControllableTableState<TValue>(
|
||||
controlledValue: TValue | undefined,
|
||||
defaultValue: TValue,
|
||||
onChange: OnChangeFn<TValue> | undefined
|
||||
): [TValue, OnChangeFn<TValue>] {
|
||||
const [uncontrolledValue, setUncontrolledValue] =
|
||||
React.useState<TValue>(defaultValue)
|
||||
|
||||
const value = controlledValue ?? uncontrolledValue
|
||||
|
||||
const setValue = React.useCallback<OnChangeFn<TValue>>(
|
||||
(updater) => {
|
||||
if (controlledValue === undefined) {
|
||||
setUncontrolledValue((previous) => resolveUpdater(updater, previous))
|
||||
}
|
||||
onChange?.(updater)
|
||||
},
|
||||
[controlledValue, onChange]
|
||||
)
|
||||
|
||||
return [value, setValue]
|
||||
}
|
||||
|
||||
export function useDataTable<TData>(options: UseDataTableOptions<TData>) {
|
||||
const {
|
||||
data,
|
||||
columns,
|
||||
totalCount,
|
||||
pageCount: explicitPageCount,
|
||||
ensurePageInRange,
|
||||
manualFiltering,
|
||||
manualPagination,
|
||||
manualSorting,
|
||||
initialSorting = [],
|
||||
initialColumnVisibility = {},
|
||||
initialRowSelection = {},
|
||||
initialExpanded = {},
|
||||
initialPagination = { pageIndex: 0, pageSize: 20 },
|
||||
withFilteredRowModel = !manualFiltering,
|
||||
withPaginationRowModel = !manualPagination,
|
||||
withSortedRowModel = !manualSorting,
|
||||
withFacetedRowModel = !manualFiltering,
|
||||
withExpandedRowModel = false,
|
||||
} = options
|
||||
|
||||
const [sorting, onSortingChange] = useControllableTableState(
|
||||
options.sorting,
|
||||
initialSorting,
|
||||
options.onSortingChange
|
||||
)
|
||||
const [columnVisibility, onColumnVisibilityChange] =
|
||||
useControllableTableState(
|
||||
options.columnVisibility,
|
||||
initialColumnVisibility,
|
||||
options.onColumnVisibilityChange
|
||||
)
|
||||
const [rowSelection, onRowSelectionChange] = useControllableTableState(
|
||||
options.rowSelection,
|
||||
initialRowSelection,
|
||||
options.onRowSelectionChange
|
||||
)
|
||||
const [expanded, onExpandedChange] = useControllableTableState(
|
||||
options.expanded,
|
||||
initialExpanded,
|
||||
options.onExpandedChange
|
||||
)
|
||||
const [pagination, onPaginationChange] = useControllableTableState(
|
||||
options.pagination,
|
||||
initialPagination,
|
||||
options.onPaginationChange
|
||||
)
|
||||
|
||||
const resolvedPageCount =
|
||||
explicitPageCount ??
|
||||
(totalCount !== undefined
|
||||
? Math.ceil(totalCount / pagination.pageSize)
|
||||
: undefined)
|
||||
|
||||
const table = useReactTable({
|
||||
data,
|
||||
columns,
|
||||
rowCount: totalCount,
|
||||
pageCount: resolvedPageCount,
|
||||
state: {
|
||||
sorting,
|
||||
columnVisibility,
|
||||
rowSelection,
|
||||
expanded,
|
||||
columnFilters: options.columnFilters,
|
||||
globalFilter: options.globalFilter,
|
||||
pagination,
|
||||
},
|
||||
enableRowSelection: options.enableRowSelection,
|
||||
getRowId: options.getRowId,
|
||||
getSubRows: options.getSubRows,
|
||||
globalFilterFn: options.globalFilterFn,
|
||||
autoResetPageIndex: options.autoResetPageIndex,
|
||||
manualFiltering,
|
||||
manualPagination,
|
||||
manualSorting,
|
||||
onSortingChange,
|
||||
onColumnVisibilityChange,
|
||||
onRowSelectionChange,
|
||||
onExpandedChange,
|
||||
onColumnFiltersChange: options.onColumnFiltersChange,
|
||||
onGlobalFilterChange: options.onGlobalFilterChange,
|
||||
onPaginationChange,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
getFilteredRowModel: withFilteredRowModel
|
||||
? getFilteredRowModel()
|
||||
: undefined,
|
||||
getPaginationRowModel: withPaginationRowModel
|
||||
? getPaginationRowModel()
|
||||
: undefined,
|
||||
getSortedRowModel: withSortedRowModel ? getSortedRowModel() : undefined,
|
||||
getFacetedRowModel: withFacetedRowModel ? getFacetedRowModel() : undefined,
|
||||
getFacetedUniqueValues: withFacetedRowModel
|
||||
? getFacetedUniqueValues()
|
||||
: undefined,
|
||||
getExpandedRowModel: withExpandedRowModel
|
||||
? getExpandedRowModel()
|
||||
: undefined,
|
||||
})
|
||||
|
||||
const actualPageCount = table.getPageCount()
|
||||
React.useEffect(() => {
|
||||
ensurePageInRange?.(actualPageCount)
|
||||
}, [actualPageCount, ensurePageInRange])
|
||||
|
||||
return {
|
||||
table,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import type { ColumnFiltersState, OnChangeFn } from '@tanstack/react-table'
|
||||
import { useDebounce } from '@/hooks/use-debounce'
|
||||
|
||||
type UseDebouncedColumnFilterOptions = {
|
||||
columnFilters: ColumnFiltersState
|
||||
columnId: string
|
||||
onColumnFiltersChange: OnChangeFn<ColumnFiltersState>
|
||||
delay?: number
|
||||
}
|
||||
|
||||
export function useDebouncedColumnFilter({
|
||||
columnFilters,
|
||||
columnId,
|
||||
onColumnFiltersChange,
|
||||
delay = 500,
|
||||
}: UseDebouncedColumnFilterOptions) {
|
||||
const value =
|
||||
(columnFilters.find((filter) => filter.id === columnId)?.value as
|
||||
| string
|
||||
| undefined) ?? ''
|
||||
const [inputValue, setInputValue] = React.useState(value)
|
||||
const [pendingValue, setPendingValue] = React.useState(value)
|
||||
const isComposingRef = React.useRef(false)
|
||||
const debouncedValue = useDebounce(pendingValue, delay)
|
||||
|
||||
React.useEffect(() => {
|
||||
// Keep the input aligned when URL state changes outside the local field.
|
||||
if (!isComposingRef.current) {
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
setInputValue(value)
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
setPendingValue(value)
|
||||
}, [value])
|
||||
|
||||
React.useEffect(() => {
|
||||
if (debouncedValue === value) return
|
||||
|
||||
onColumnFiltersChange((previous) => {
|
||||
const filters = previous.filter((filter) => filter.id !== columnId)
|
||||
return debouncedValue
|
||||
? [...filters, { id: columnId, value: debouncedValue }]
|
||||
: filters
|
||||
})
|
||||
}, [columnId, debouncedValue, onColumnFiltersChange, value])
|
||||
|
||||
const updateInputValue = React.useCallback((nextValue: string) => {
|
||||
setInputValue(nextValue)
|
||||
|
||||
if (!isComposingRef.current) {
|
||||
setPendingValue(nextValue)
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleChange = React.useCallback(
|
||||
(event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
updateInputValue(event.target.value)
|
||||
},
|
||||
[updateInputValue]
|
||||
)
|
||||
|
||||
const handleCompositionStart = React.useCallback(() => {
|
||||
isComposingRef.current = true
|
||||
}, [])
|
||||
|
||||
const handleCompositionEnd = React.useCallback(
|
||||
(event: React.CompositionEvent<HTMLInputElement>) => {
|
||||
isComposingRef.current = false
|
||||
const nextValue = event.currentTarget.value
|
||||
setInputValue(nextValue)
|
||||
setPendingValue(nextValue)
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const resetInput = React.useCallback(() => {
|
||||
isComposingRef.current = false
|
||||
setInputValue('')
|
||||
setPendingValue('')
|
||||
}, [])
|
||||
|
||||
return {
|
||||
value,
|
||||
inputValue,
|
||||
setInputValue: updateInputValue,
|
||||
onChange: handleChange,
|
||||
onCompositionStart: handleCompositionStart,
|
||||
onCompositionEnd: handleCompositionEnd,
|
||||
resetInput,
|
||||
}
|
||||
}
|
||||
+24
-10
@@ -16,16 +16,30 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
export { DataTablePagination } from './pagination'
|
||||
export { DataTableColumnHeader } from './column-header'
|
||||
export { DataTableFacetedFilter } from './faceted-filter'
|
||||
export { DataTableViewOptions } from './view-options'
|
||||
export { DataTableToolbar } from './toolbar'
|
||||
export { DataTableBulkActions } from './bulk-actions'
|
||||
export { TableSkeleton } from './table-skeleton'
|
||||
export { TableEmpty } from './table-empty'
|
||||
export { MobileCardList } from './mobile-card-list'
|
||||
export { DataTablePage, type DataTablePageProps } from './data-table-page'
|
||||
export { DataTablePagination } from './core/pagination'
|
||||
export { DataTableColumnHeader } from './core/column-header'
|
||||
export { DataTableViewOptions } from './toolbar/view-options'
|
||||
export { DataTableToolbar } from './toolbar/toolbar'
|
||||
export { DataTableBulkActions } from './toolbar/bulk-actions'
|
||||
export {
|
||||
StaticDataTable,
|
||||
type StaticDataTableColumn,
|
||||
} from './static/static-data-table'
|
||||
export { staticDataTableClassNames } from './static/static-data-table-classnames'
|
||||
export {
|
||||
DataTableRow,
|
||||
DataTableView,
|
||||
type DataTableColumnClassName,
|
||||
type DataTablePinnedColumn,
|
||||
type DataTableRenderRowHelpers,
|
||||
} from './core/data-table-view'
|
||||
export { MobileCardList } from './layout/mobile-card-list'
|
||||
export {
|
||||
DataTablePage,
|
||||
type DataTablePageProps,
|
||||
} from './layout/data-table-page'
|
||||
export { useDataTable } from './hooks/use-data-table'
|
||||
export { useDebouncedColumnFilter } from './hooks/use-debounced-column-filter'
|
||||
|
||||
export const DISABLED_ROW_DESKTOP =
|
||||
'bg-muted/85 hover:bg-muted [&>td:first-child]:border-l-muted-foreground/35 [&>td:first-child]:border-l-4 [&>td:first-child]:pl-1'
|
||||
|
||||
+85
-112
@@ -18,27 +18,22 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import {
|
||||
flexRender,
|
||||
type ColumnDef,
|
||||
type Row,
|
||||
type Table as TanstackTable,
|
||||
} from '@tanstack/react-table'
|
||||
import { useMediaQuery } from '@/hooks'
|
||||
import { cn } from '@/lib/utils'
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from '@/components/ui/table'
|
||||
import { PageFooterPortal } from '@/components/layout'
|
||||
import {
|
||||
DataTableView,
|
||||
type DataTableColumnClassName,
|
||||
type DataTablePinnedColumn,
|
||||
type DataTableRenderRowHelpers,
|
||||
} from '../core/data-table-view'
|
||||
import { DataTablePagination } from '../core/pagination'
|
||||
import { DataTableToolbar } from '../toolbar/toolbar'
|
||||
import { MobileCardList } from './mobile-card-list'
|
||||
import { DataTablePagination } from './pagination'
|
||||
import { TableEmpty } from './table-empty'
|
||||
import { TableSkeleton } from './table-skeleton'
|
||||
import { DataTableToolbar } from './toolbar'
|
||||
|
||||
/**
|
||||
* Pass-through configuration for the default {@link DataTableToolbar}.
|
||||
@@ -145,7 +140,22 @@ export type DataTablePageProps<TData> = {
|
||||
* Custom desktop row renderer — replaces the default `<TableRow>`/`<TableCell>` mapping.
|
||||
* Use for expanded rows, aggregate rows, click-on-row navigation, etc.
|
||||
*/
|
||||
renderRow?: (row: Row<TData>) => React.ReactNode
|
||||
renderRow?: (
|
||||
row: Row<TData>,
|
||||
helpers: DataTableRenderRowHelpers
|
||||
) => React.ReactNode
|
||||
|
||||
/**
|
||||
* Desktop column className resolver. Use for semantic alignment/spacing only;
|
||||
* fixed-column behavior should be configured with `pinnedColumns`.
|
||||
*/
|
||||
getColumnClassName?: DataTableColumnClassName
|
||||
|
||||
/**
|
||||
* Fixed desktop columns. The shared table component owns sticky position,
|
||||
* layering, shadows, and row-state backgrounds.
|
||||
*/
|
||||
pinnedColumns?: DataTablePinnedColumn[]
|
||||
|
||||
/**
|
||||
* Apply explicit column widths from `header.getSize()` to `<TableHead>`.
|
||||
@@ -182,6 +192,12 @@ export type DataTablePageProps<TData> = {
|
||||
*/
|
||||
className?: string
|
||||
|
||||
/**
|
||||
* Make the desktop table consume the available page height and scroll inside
|
||||
* the table body while keeping the header fixed. Defaults to `true`.
|
||||
*/
|
||||
fixedHeight?: boolean
|
||||
|
||||
/**
|
||||
* Desktop table container className (the bordered scroll wrapper).
|
||||
*/
|
||||
@@ -189,7 +205,8 @@ export type DataTablePageProps<TData> = {
|
||||
|
||||
/**
|
||||
* Desktop `<TableHeader>` className override.
|
||||
* Useful for sticky headers (`'sticky top-0 z-10 bg-muted/30'`) on long lists.
|
||||
* Use for header color/spacing overrides. Fixed-height pages keep the header
|
||||
* outside the scrollable body automatically.
|
||||
*/
|
||||
tableHeaderClassName?: string
|
||||
}
|
||||
@@ -222,10 +239,18 @@ export function DataTablePage<TData>(props: DataTablePageProps<TData>) {
|
||||
const toolbarNode = renderToolbar(props)
|
||||
const mobileNode = renderMobile(props, showMobile)
|
||||
const desktopNode = renderDesktop(props, showMobile)
|
||||
const paginationNode = renderPagination(props)
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className={cn('space-y-2.5 sm:space-y-3', props.className)}>
|
||||
<div
|
||||
className={cn(
|
||||
props.fixedHeight !== false
|
||||
? 'flex h-full min-h-0 flex-col gap-2.5 sm:gap-3'
|
||||
: 'space-y-2.5 sm:space-y-3',
|
||||
props.className
|
||||
)}
|
||||
>
|
||||
{toolbarNode}
|
||||
{mobileNode}
|
||||
{desktopNode}
|
||||
@@ -236,16 +261,7 @@ export function DataTablePage<TData>(props: DataTablePageProps<TData>) {
|
||||
handle its own visibility, we just gate it to non-mobile. */}
|
||||
{!showMobile && props.bulkActions}
|
||||
|
||||
{props.showPagination !== false &&
|
||||
(props.paginationInFooter !== false ? (
|
||||
<PageFooterPortal>
|
||||
<DataTablePagination table={props.table} />
|
||||
</PageFooterPortal>
|
||||
) : (
|
||||
<div className='pt-2'>
|
||||
<DataTablePagination table={props.table} />
|
||||
</div>
|
||||
))}
|
||||
{paginationNode}
|
||||
</>
|
||||
)
|
||||
}
|
||||
@@ -265,12 +281,25 @@ function renderToolbar<TData>(
|
||||
return null
|
||||
}
|
||||
|
||||
function renderPagination<TData>(
|
||||
props: DataTablePageProps<TData>
|
||||
): React.ReactNode {
|
||||
if (props.showPagination === false) return null
|
||||
|
||||
const pagination = <DataTablePagination table={props.table} />
|
||||
|
||||
return props.paginationInFooter !== false ? (
|
||||
<PageFooterPortal>{pagination}</PageFooterPortal>
|
||||
) : (
|
||||
<div className='pt-2'>{pagination}</div>
|
||||
)
|
||||
}
|
||||
|
||||
function renderMobile<TData>(
|
||||
props: DataTablePageProps<TData>,
|
||||
showMobile: boolean
|
||||
): React.ReactNode {
|
||||
if (!showMobile) return null
|
||||
if (props.mobile !== undefined) return props.mobile
|
||||
|
||||
const ownGetRowClassName = props.getRowClassName
|
||||
const mobileGetRowClassName =
|
||||
@@ -278,8 +307,7 @@ function renderMobile<TData>(
|
||||
(ownGetRowClassName
|
||||
? (row: Row<TData>) => ownGetRowClassName(row, { isMobile: true })
|
||||
: undefined)
|
||||
|
||||
return (
|
||||
const mobileContent = props.mobile ?? (
|
||||
<MobileCardList
|
||||
table={props.table}
|
||||
isLoading={props.isLoading}
|
||||
@@ -289,6 +317,8 @@ function renderMobile<TData>(
|
||||
getRowClassName={mobileGetRowClassName}
|
||||
/>
|
||||
)
|
||||
|
||||
return <div className='min-h-0 flex-1 overflow-y-auto'>{mobileContent}</div>
|
||||
}
|
||||
|
||||
function renderDesktop<TData>(
|
||||
@@ -297,94 +327,37 @@ function renderDesktop<TData>(
|
||||
): React.ReactNode {
|
||||
if (showMobile) return null
|
||||
|
||||
const rows = props.table.getRowModel().rows
|
||||
const isFetchingOnly = props.isFetching && !props.isLoading
|
||||
const fixedHeight = props.fixedHeight !== false
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'overflow-hidden rounded-lg border transition-opacity duration-150',
|
||||
<DataTableView
|
||||
table={props.table}
|
||||
isLoading={props.isLoading}
|
||||
emptyTitle={props.emptyTitle}
|
||||
emptyDescription={props.emptyDescription}
|
||||
emptyIcon={props.emptyIcon}
|
||||
emptyAction={props.emptyAction}
|
||||
skeletonKeyPrefix={props.skeletonKeyPrefix}
|
||||
renderRow={props.renderRow}
|
||||
applyHeaderSize={props.applyHeaderSize}
|
||||
splitHeader={fixedHeight}
|
||||
tableContainerClassName={fixedHeight ? 'h-full min-h-0' : undefined}
|
||||
tableHeaderClassName={cn(
|
||||
fixedHeight && 'bg-muted/30',
|
||||
props.tableHeaderClassName
|
||||
)}
|
||||
getColumnClassName={props.getColumnClassName}
|
||||
pinnedColumns={props.pinnedColumns}
|
||||
containerClassName={cn(
|
||||
fixedHeight && 'min-h-0 flex-1',
|
||||
'transition-opacity duration-150',
|
||||
isFetchingOnly && 'pointer-events-none opacity-60',
|
||||
props.tableClassName
|
||||
)}
|
||||
>
|
||||
<Table>
|
||||
<TableHeader className={props.tableHeaderClassName}>
|
||||
{props.table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header) => (
|
||||
<TableHead
|
||||
key={header.id}
|
||||
colSpan={header.colSpan}
|
||||
style={
|
||||
props.applyHeaderSize
|
||||
? { width: header.getSize() }
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{header.isPlaceholder
|
||||
? null
|
||||
: flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext()
|
||||
)}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{props.isLoading ? (
|
||||
<TableSkeleton
|
||||
table={props.table}
|
||||
keyPrefix={props.skeletonKeyPrefix}
|
||||
/>
|
||||
) : rows.length === 0 ? (
|
||||
<TableEmpty
|
||||
colSpan={props.columns.length}
|
||||
title={props.emptyTitle}
|
||||
description={props.emptyDescription}
|
||||
icon={props.emptyIcon}
|
||||
>
|
||||
{props.emptyAction}
|
||||
</TableEmpty>
|
||||
) : (
|
||||
rows.map((row) => {
|
||||
if (props.renderRow) {
|
||||
return props.renderRow(row)
|
||||
}
|
||||
return (
|
||||
<DefaultRow
|
||||
key={row.id}
|
||||
row={row}
|
||||
className={props.getRowClassName?.(row, { isMobile: false })}
|
||||
/>
|
||||
)
|
||||
})
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function DefaultRow<TData>({
|
||||
row,
|
||||
className,
|
||||
}: {
|
||||
row: Row<TData>
|
||||
className?: string
|
||||
}) {
|
||||
return (
|
||||
<TableRow
|
||||
data-state={row.getIsSelected() && 'selected'}
|
||||
className={className}
|
||||
>
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<TableCell key={cell.id}>
|
||||
{flexRender(cell.column.columnDef.cell, cell.getContext())}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
getRowClassName={(row) =>
|
||||
props.getRowClassName?.(row, { isMobile: false })
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
export const staticDataTableClassNames = {
|
||||
container: 'overflow-hidden rounded-md border',
|
||||
sectionContainer: 'border-border/60 rounded-lg',
|
||||
embeddedContainer: 'rounded-none border-0',
|
||||
compactTable: 'text-sm',
|
||||
compactHeaderRow: 'hover:bg-transparent',
|
||||
mutedHeaderRow: 'bg-muted/30 hover:bg-muted/30',
|
||||
compactHeaderCell:
|
||||
'text-muted-foreground py-2 text-[10px] font-medium tracking-wider uppercase',
|
||||
compactHeaderCellRight:
|
||||
'text-muted-foreground py-2 text-right text-[10px] font-medium tracking-wider uppercase',
|
||||
compactCell: 'py-2.5',
|
||||
compactTopCell: 'py-2.5 align-top',
|
||||
compactTopNumericCell: 'py-2.5 text-right align-top font-mono',
|
||||
compactMutedCell: 'text-muted-foreground py-2.5',
|
||||
compactMutedCodeCell: 'text-muted-foreground py-2.5 font-mono',
|
||||
compactNumericCell: 'py-2.5 text-right font-mono',
|
||||
compactMutedNumericCell: 'text-muted-foreground py-2.5 text-right font-mono',
|
||||
topCell: 'py-2 align-top',
|
||||
topMutedCell: 'text-muted-foreground py-2 align-top',
|
||||
codeCell: 'font-mono text-sm',
|
||||
mutedCell: 'text-muted-foreground text-sm',
|
||||
mutedCodeCell: 'text-muted-foreground font-mono text-sm',
|
||||
topNumericCell: 'py-2 text-right font-mono',
|
||||
mediumCell: 'font-medium',
|
||||
actionHeaderCell: 'text-right',
|
||||
actionCell: 'text-right',
|
||||
} as const
|
||||
@@ -0,0 +1,206 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from '@/components/ui/table'
|
||||
import { staticDataTableClassNames } from './static-data-table-classnames'
|
||||
|
||||
type StaticDataTableBaseProps = {
|
||||
className?: string
|
||||
tableClassName?: string
|
||||
containerProps?: Omit<React.ComponentProps<'div'>, 'className' | 'children'>
|
||||
tableProps?: Omit<
|
||||
React.ComponentProps<typeof Table>,
|
||||
'className' | 'children'
|
||||
>
|
||||
}
|
||||
|
||||
type StaticDataTableDataProps<TData = unknown> = StaticDataTableBaseProps & {
|
||||
columns: StaticDataTableColumn<TData>[]
|
||||
data: TData[]
|
||||
getRowKey?: (row: TData, index: number) => React.Key
|
||||
getRowClassName?: (row: TData, index: number) => string | undefined
|
||||
renderRow?: (row: TData, index: number) => React.ReactNode
|
||||
empty?: boolean
|
||||
emptyContent?: React.ReactNode
|
||||
emptyClassName?: string
|
||||
headerRowClassName?: string
|
||||
}
|
||||
|
||||
type StaticDataTableChildrenProps = StaticDataTableBaseProps & {
|
||||
children: React.ReactNode
|
||||
columns?: never
|
||||
data?: never
|
||||
}
|
||||
|
||||
type StaticDataTableProps<TData = unknown> =
|
||||
| StaticDataTableDataProps<TData>
|
||||
| StaticDataTableChildrenProps
|
||||
|
||||
export type StaticDataTableColumn<TData = unknown> = {
|
||||
id: string
|
||||
header: React.ReactNode
|
||||
className?: string
|
||||
cellClassName?: string | ((row: TData, index: number) => string | undefined)
|
||||
cell?: (row: TData, index: number) => React.ReactNode
|
||||
}
|
||||
|
||||
export function StaticDataTable<TData = unknown>(
|
||||
props: StaticDataTableProps<TData>
|
||||
) {
|
||||
const { className, tableClassName, containerProps, tableProps } = props
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(staticDataTableClassNames.container, className)}
|
||||
{...containerProps}
|
||||
>
|
||||
<Table className={tableClassName} {...tableProps}>
|
||||
{props.columns !== undefined ? (
|
||||
<StaticDataTableWithColumns {...props} />
|
||||
) : (
|
||||
props.children
|
||||
)}
|
||||
</Table>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function StaticDataTableWithColumns<TData>({
|
||||
columns,
|
||||
data,
|
||||
getRowKey,
|
||||
getRowClassName,
|
||||
renderRow,
|
||||
empty,
|
||||
emptyContent,
|
||||
emptyClassName,
|
||||
headerRowClassName,
|
||||
}: StaticDataTableDataProps<TData>) {
|
||||
const isEmpty = empty ?? (data !== undefined && data.length === 0)
|
||||
const bodyRows = data.map((row, index) => (
|
||||
<StaticDataTableRow
|
||||
key={getRowKey?.(row, index) ?? index}
|
||||
row={row}
|
||||
index={index}
|
||||
columns={columns}
|
||||
getRowClassName={getRowClassName}
|
||||
renderRow={renderRow}
|
||||
/>
|
||||
))
|
||||
|
||||
return (
|
||||
<>
|
||||
<TableHeader>
|
||||
<TableRow className={headerRowClassName}>
|
||||
{columns.map((column) => (
|
||||
<TableHead key={column.id} className={column.className}>
|
||||
{column.header}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{isEmpty ? (
|
||||
<StaticDataTableEmptyRow
|
||||
colSpan={columns.length}
|
||||
className={emptyClassName}
|
||||
>
|
||||
{emptyContent}
|
||||
</StaticDataTableEmptyRow>
|
||||
) : (
|
||||
bodyRows
|
||||
)}
|
||||
</TableBody>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
type StaticDataTableRowProps<TData> = Required<
|
||||
Pick<StaticDataTableDataProps<TData>, 'columns'>
|
||||
> &
|
||||
Pick<StaticDataTableDataProps<TData>, 'getRowClassName' | 'renderRow'> & {
|
||||
row: TData
|
||||
index: number
|
||||
}
|
||||
|
||||
function StaticDataTableRow<TData>({
|
||||
row,
|
||||
index,
|
||||
columns,
|
||||
getRowClassName,
|
||||
renderRow,
|
||||
}: StaticDataTableRowProps<TData>) {
|
||||
if (renderRow) {
|
||||
return <>{renderRow(row, index)}</>
|
||||
}
|
||||
|
||||
return (
|
||||
<TableRow className={getRowClassName?.(row, index)}>
|
||||
{columns.map((column) => (
|
||||
<TableCell
|
||||
key={column.id}
|
||||
className={getStaticCellClassName(column, row, index)}
|
||||
>
|
||||
{column.cell?.(row, index)}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
)
|
||||
}
|
||||
|
||||
function getStaticCellClassName<TData>(
|
||||
column: StaticDataTableColumn<TData>,
|
||||
row: TData,
|
||||
index: number
|
||||
) {
|
||||
return typeof column.cellClassName === 'function'
|
||||
? column.cellClassName(row, index)
|
||||
: column.cellClassName
|
||||
}
|
||||
|
||||
type StaticDataTableEmptyRowProps = {
|
||||
colSpan: number
|
||||
children: React.ReactNode
|
||||
className?: string
|
||||
}
|
||||
|
||||
function StaticDataTableEmptyRow({
|
||||
colSpan,
|
||||
children,
|
||||
className,
|
||||
}: StaticDataTableEmptyRowProps) {
|
||||
return (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={colSpan}
|
||||
className={cn('h-24 text-center', className)}
|
||||
>
|
||||
{children}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)
|
||||
}
|
||||
+105
-11
@@ -21,6 +21,7 @@ import { useState, type ReactNode } from 'react'
|
||||
import { type Table } from '@tanstack/react-table'
|
||||
import { ChevronDown, Loader2, X as Cross2Icon } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useDebounce } from '@/hooks'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
@@ -46,6 +47,10 @@ export type DataTableToolbarProps<TData> = {
|
||||
* Placeholder for the default search input. Defaults to `t('Filter...')`.
|
||||
*/
|
||||
searchPlaceholder?: string
|
||||
/**
|
||||
* Delay committing the default search input. Defaults to immediate updates.
|
||||
*/
|
||||
searchDebounceMs?: number
|
||||
/**
|
||||
* Column id to filter on. When provided, the search input filters
|
||||
* a specific column. When omitted, the search input updates the
|
||||
@@ -136,6 +141,8 @@ export type DataTableToolbarProps<TData> = {
|
||||
export function DataTableToolbar<TData>(props: DataTableToolbarProps<TData>) {
|
||||
const { t } = useTranslation()
|
||||
const [expanded, setExpanded] = useState(false)
|
||||
const isSearchComposingRef = React.useRef(false)
|
||||
const lastCommittedSearchValueRef = React.useRef('')
|
||||
|
||||
const filters = props.filters ?? []
|
||||
const hasExpandable = props.expandable != null
|
||||
@@ -147,26 +154,109 @@ export function DataTableToolbar<TData>(props: DataTableToolbarProps<TData>) {
|
||||
!!props.hasAdditionalFilters
|
||||
|
||||
const placeholder = props.searchPlaceholder ?? t('Filter...')
|
||||
const currentSearchValue = props.searchKey
|
||||
? ((props.table.getColumn(props.searchKey)?.getFilterValue() as string) ??
|
||||
'')
|
||||
: ((props.table.getState().globalFilter as string | undefined) ?? '')
|
||||
|
||||
const [searchValue, setSearchValue] = useState(currentSearchValue)
|
||||
const [pendingSearchValue, setPendingSearchValue] =
|
||||
useState(currentSearchValue)
|
||||
const searchDebounceMs = Math.max(0, props.searchDebounceMs ?? 0)
|
||||
const debouncedSearchValue = useDebounce(
|
||||
pendingSearchValue,
|
||||
searchDebounceMs
|
||||
)
|
||||
|
||||
React.useEffect(() => {
|
||||
lastCommittedSearchValueRef.current = currentSearchValue
|
||||
if (!isSearchComposingRef.current) {
|
||||
setSearchValue(currentSearchValue)
|
||||
}
|
||||
setPendingSearchValue(currentSearchValue)
|
||||
}, [currentSearchValue])
|
||||
|
||||
const commitSearchValue = React.useCallback(
|
||||
(value: string) => {
|
||||
if (value === lastCommittedSearchValueRef.current) {
|
||||
return
|
||||
}
|
||||
|
||||
lastCommittedSearchValueRef.current = value
|
||||
|
||||
if (props.searchKey) {
|
||||
props.table.getColumn(props.searchKey)?.setFilterValue(value)
|
||||
return
|
||||
}
|
||||
|
||||
props.table.setGlobalFilter(value)
|
||||
},
|
||||
[props.searchKey, props.table]
|
||||
)
|
||||
|
||||
React.useEffect(() => {
|
||||
if (
|
||||
searchDebounceMs <= 0 ||
|
||||
isSearchComposingRef.current ||
|
||||
debouncedSearchValue !== pendingSearchValue
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
commitSearchValue(debouncedSearchValue)
|
||||
}, [
|
||||
commitSearchValue,
|
||||
debouncedSearchValue,
|
||||
pendingSearchValue,
|
||||
searchDebounceMs,
|
||||
])
|
||||
|
||||
const queueSearchValue = (value: string) => {
|
||||
setPendingSearchValue(value)
|
||||
|
||||
if (searchDebounceMs <= 0) {
|
||||
commitSearchValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
const handleSearchChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = event.target.value
|
||||
setSearchValue(value)
|
||||
|
||||
if (!isSearchComposingRef.current) {
|
||||
queueSearchValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
const handleSearchCompositionStart = () => {
|
||||
isSearchComposingRef.current = true
|
||||
}
|
||||
|
||||
const handleSearchCompositionEnd = (
|
||||
event: React.CompositionEvent<HTMLInputElement>
|
||||
) => {
|
||||
isSearchComposingRef.current = false
|
||||
const value = event.currentTarget.value
|
||||
setSearchValue(value)
|
||||
queueSearchValue(value)
|
||||
}
|
||||
|
||||
const searchInput = props.searchKey ? (
|
||||
<Input
|
||||
placeholder={placeholder}
|
||||
value={
|
||||
(props.table.getColumn(props.searchKey)?.getFilterValue() as string) ??
|
||||
''
|
||||
}
|
||||
onChange={(event) =>
|
||||
props.table
|
||||
.getColumn(props.searchKey!)
|
||||
?.setFilterValue(event.target.value)
|
||||
}
|
||||
value={searchValue}
|
||||
onChange={handleSearchChange}
|
||||
onCompositionStart={handleSearchCompositionStart}
|
||||
onCompositionEnd={handleSearchCompositionEnd}
|
||||
className='w-full sm:w-[200px] lg:w-[240px]'
|
||||
/>
|
||||
) : (
|
||||
<Input
|
||||
placeholder={placeholder}
|
||||
value={props.table.getState().globalFilter ?? ''}
|
||||
onChange={(event) => props.table.setGlobalFilter(event.target.value)}
|
||||
value={searchValue}
|
||||
onChange={handleSearchChange}
|
||||
onCompositionStart={handleSearchCompositionStart}
|
||||
onCompositionEnd={handleSearchCompositionEnd}
|
||||
className='w-full sm:w-[200px] lg:w-[240px]'
|
||||
/>
|
||||
)
|
||||
@@ -186,6 +276,10 @@ export function DataTableToolbar<TData>(props: DataTableToolbarProps<TData>) {
|
||||
})
|
||||
|
||||
const handleReset = () => {
|
||||
isSearchComposingRef.current = false
|
||||
setSearchValue('')
|
||||
setPendingSearchValue('')
|
||||
lastCommittedSearchValueRef.current = ''
|
||||
props.table.resetColumnFilters()
|
||||
props.table.setGlobalFilter('')
|
||||
props.onReset?.()
|
||||
+127
@@ -0,0 +1,127 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import {
|
||||
Dialog as DialogRoot,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from '@/components/ui/dialog'
|
||||
|
||||
type DialogProps = React.ComponentProps<typeof DialogRoot> & {
|
||||
title: React.ReactNode
|
||||
description?: React.ReactNode
|
||||
children: React.ReactNode
|
||||
trigger?: React.ReactElement
|
||||
footer?: React.ReactNode
|
||||
contentHeight?: React.CSSProperties['height']
|
||||
contentClassName?: string
|
||||
headerClassName?: string
|
||||
titleClassName?: string
|
||||
descriptionClassName?: string
|
||||
bodyClassName?: string
|
||||
footerClassName?: string
|
||||
initialFocus?: boolean
|
||||
showCloseButton?: boolean
|
||||
}
|
||||
|
||||
const dialogContentMotionClassName =
|
||||
'data-open:animate-in data-open:fade-in-0 data-open:zoom-in-95 data-closed:animate-out data-closed:fade-out-0 data-closed:zoom-out-95 duration-100'
|
||||
|
||||
export function Dialog({
|
||||
title,
|
||||
description,
|
||||
children,
|
||||
trigger,
|
||||
footer,
|
||||
contentHeight = 'auto',
|
||||
contentClassName,
|
||||
headerClassName,
|
||||
titleClassName,
|
||||
descriptionClassName,
|
||||
bodyClassName,
|
||||
footerClassName,
|
||||
initialFocus,
|
||||
showCloseButton,
|
||||
...dialogProps
|
||||
}: DialogProps) {
|
||||
return (
|
||||
<DialogRoot {...dialogProps}>
|
||||
{trigger ? <DialogTrigger render={trigger} /> : null}
|
||||
<DialogContent
|
||||
className={cn(
|
||||
'flex max-h-[calc(100vh-2rem)] w-full flex-col gap-4 overflow-hidden p-4 sm:max-w-2xl sm:p-6',
|
||||
contentClassName,
|
||||
dialogContentMotionClassName
|
||||
)}
|
||||
initialFocus={initialFocus}
|
||||
showCloseButton={showCloseButton}
|
||||
style={
|
||||
{
|
||||
'--dialog-content-height': contentHeight,
|
||||
} as React.CSSProperties
|
||||
}
|
||||
>
|
||||
<DialogHeader
|
||||
className={cn('flex-shrink-0 text-start', headerClassName)}
|
||||
>
|
||||
<DialogTitle className={titleClassName}>{title}</DialogTitle>
|
||||
{description ? (
|
||||
<DialogDescription className={descriptionClassName}>
|
||||
{description}
|
||||
</DialogDescription>
|
||||
) : null}
|
||||
</DialogHeader>
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
'-mx-1 min-h-0 overflow-x-hidden overflow-y-auto overscroll-contain',
|
||||
'h-[var(--dialog-content-height)] max-h-[calc(100vh-14rem)]'
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'min-w-0 px-1 py-1',
|
||||
'[&_form]:overflow-x-visible',
|
||||
'[&_[data-slot=scroll-area-viewport]]:px-1 [&_[data-slot=scroll-area-viewport]]:py-1',
|
||||
bodyClassName
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{footer ? (
|
||||
<DialogFooter
|
||||
className={cn(
|
||||
'flex-shrink-0 gap-2 sm:-mx-6 sm:-mb-6 sm:justify-end sm:p-6',
|
||||
footerClassName
|
||||
)}
|
||||
>
|
||||
{footer}
|
||||
</DialogFooter>
|
||||
) : null}
|
||||
</DialogContent>
|
||||
</DialogRoot>
|
||||
)
|
||||
}
|
||||
+17
-26
@@ -25,15 +25,8 @@ import { useNotifications } from '@/hooks/use-notifications'
|
||||
import { useSystemConfig } from '@/hooks/use-system-config'
|
||||
import { useTopNavLinks } from '@/hooks/use-top-nav-links'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Skeleton } from '@/components/ui/skeleton'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { LanguageSwitcher } from '@/components/language-switcher'
|
||||
import { NotificationPopover } from '@/components/notification-popover'
|
||||
import { ProfileDropdown } from '@/components/profile-dropdown'
|
||||
@@ -427,28 +420,26 @@ export function PublicHeader(props: PublicHeaderProps) {
|
||||
closeAuthPrompt()
|
||||
}
|
||||
}}
|
||||
>
|
||||
<DialogContent className='sm:max-w-md'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Sign in required')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Please sign in to view {{module}}.', {
|
||||
module: authPromptTarget?.title || '',
|
||||
})}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className='bg-muted/40 text-muted-foreground rounded-lg px-3 py-2 text-sm'>
|
||||
{t('Redirecting to sign in in {{seconds}} seconds.', {
|
||||
seconds: authPromptSecondsLeft,
|
||||
})}
|
||||
</div>
|
||||
<DialogFooter>
|
||||
title={t('Sign in required')}
|
||||
description={t('Please sign in to view {{module}}.', {
|
||||
module: authPromptTarget?.title || '',
|
||||
})}
|
||||
contentClassName='sm:max-w-md'
|
||||
contentHeight='auto'
|
||||
footer={
|
||||
<>
|
||||
<Button variant='outline' onClick={closeAuthPrompt}>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button onClick={navigateToSignIn}>{t('Sign in now')}</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='bg-muted/40 text-muted-foreground rounded-lg px-3 py-2 text-sm'>
|
||||
{t('Redirecting to sign in in {{seconds}} seconds.', {
|
||||
seconds: authPromptSecondsLeft,
|
||||
})}
|
||||
</div>
|
||||
</Dialog>
|
||||
</>
|
||||
)
|
||||
|
||||
@@ -50,6 +50,7 @@ SectionPageLayoutBreadcrumb.displayName = 'SectionPageLayout.Breadcrumb'
|
||||
|
||||
export type SectionPageLayoutProps = {
|
||||
children: ReactNode
|
||||
fixedContent?: boolean
|
||||
}
|
||||
|
||||
export function SectionPageLayout(props: SectionPageLayoutProps) {
|
||||
@@ -95,7 +96,13 @@ export function SectionPageLayout(props: SectionPageLayoutProps) {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='min-h-0 flex-1 overflow-auto px-3 pt-1 pb-3 sm:px-4 sm:pt-1.5 sm:pb-4'>
|
||||
<div
|
||||
className={
|
||||
props.fixedContent
|
||||
? 'min-h-0 flex-1 overflow-hidden px-3 pt-1 pb-3 sm:px-4 sm:pt-1.5 sm:pb-4'
|
||||
: 'min-h-0 flex-1 overflow-auto px-3 pt-1 pb-3 sm:px-4 sm:pt-1.5 sm:pb-4'
|
||||
}
|
||||
>
|
||||
{content}
|
||||
</div>
|
||||
|
||||
|
||||
-1
@@ -46,7 +46,6 @@ export function LongText({
|
||||
|
||||
useEffect(() => {
|
||||
if (checkOverflow(ref.current)) {
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
setIsOverflown(true)
|
||||
return
|
||||
}
|
||||
|
||||
+7
-3
@@ -42,14 +42,18 @@ interface MaskedValueDisplayProps {
|
||||
*/
|
||||
export function MaskedValueDisplay(props: MaskedValueDisplayProps) {
|
||||
return (
|
||||
<div className='flex items-center'>
|
||||
<div className='flex max-w-full min-w-0 items-center'>
|
||||
<Popover>
|
||||
<PopoverTrigger
|
||||
render={
|
||||
<Button variant='ghost' size='sm' className='h-7 font-mono' />
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='sm'
|
||||
className='h-7 max-w-full min-w-0 justify-start truncate px-0 font-mono hover:bg-transparent aria-expanded:bg-transparent'
|
||||
/>
|
||||
}
|
||||
>
|
||||
{props.maskedValue}
|
||||
<span className='truncate'>{props.maskedValue}</span>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className='w-auto max-w-[min(90vw,28rem)]'
|
||||
|
||||
+44
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
Copyright (C) 2023-2026 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import { getLobeIcon } from '@/lib/lobe-icon'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { StatusBadge, type StatusBadgeProps } from './status-badge'
|
||||
|
||||
type ProviderBadgeProps = Omit<StatusBadgeProps, 'children' | 'label'> & {
|
||||
iconKey?: string | null
|
||||
iconSize?: number
|
||||
label: string
|
||||
}
|
||||
|
||||
export function ProviderBadge({
|
||||
className,
|
||||
iconKey,
|
||||
iconSize = 14,
|
||||
label,
|
||||
...badgeProps
|
||||
}: ProviderBadgeProps) {
|
||||
const icon = iconKey ? getLobeIcon(iconKey, iconSize) : null
|
||||
|
||||
return (
|
||||
<div className={cn('flex items-center gap-1.5', className)}>
|
||||
{icon}
|
||||
<StatusBadge label={label} autoColor={label} size='sm' {...badgeProps} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
+2
-1
@@ -103,7 +103,7 @@ export function StatusBadge({
|
||||
variant,
|
||||
size = 'sm',
|
||||
pulse = false,
|
||||
showDot = true,
|
||||
showDot = false,
|
||||
copyable = true,
|
||||
copyText,
|
||||
autoColor,
|
||||
@@ -130,6 +130,7 @@ export function StatusBadge({
|
||||
|
||||
return (
|
||||
<span
|
||||
data-slot='status-badge'
|
||||
className={cn(
|
||||
'inline-flex w-fit max-w-full shrink-0 items-center rounded-4xl font-medium tracking-normal whitespace-nowrap transition-colors',
|
||||
sizeMap[size ?? 'sm'],
|
||||
|
||||
+93
-51
@@ -17,6 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import * as React from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
import { Check, ChevronsUpDown } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { cn } from '@/lib/utils'
|
||||
@@ -150,10 +151,101 @@ export function ComboboxInput({
|
||||
item?.scrollIntoView({ block: 'nearest' })
|
||||
}, [highlightedIndex])
|
||||
|
||||
const [dropdownPos, setDropdownPos] = React.useState<{
|
||||
top: number
|
||||
left: number
|
||||
width: number
|
||||
} | null>(null)
|
||||
|
||||
const updateDropdownPos = React.useCallback(() => {
|
||||
if (!containerRef.current) return
|
||||
const rect = containerRef.current.getBoundingClientRect()
|
||||
setDropdownPos({
|
||||
top: rect.bottom + 4,
|
||||
left: rect.left,
|
||||
width: rect.width,
|
||||
})
|
||||
}, [])
|
||||
|
||||
// Update dropdown position when open
|
||||
React.useEffect(() => {
|
||||
if (!open) {
|
||||
setDropdownPos(null)
|
||||
return
|
||||
}
|
||||
updateDropdownPos()
|
||||
const handleScroll = () => updateDropdownPos()
|
||||
window.addEventListener('scroll', handleScroll, true)
|
||||
window.addEventListener('resize', handleScroll)
|
||||
return () => {
|
||||
window.removeEventListener('scroll', handleScroll, true)
|
||||
window.removeEventListener('resize', handleScroll)
|
||||
}
|
||||
}, [open, updateDropdownPos])
|
||||
|
||||
const showDropdown =
|
||||
open &&
|
||||
(filteredOptions.length > 0 || (allowCustomValue && searchValue.trim()))
|
||||
|
||||
const dropdownContent = showDropdown && dropdownPos ? (
|
||||
<div
|
||||
className='bg-popover text-popover-foreground fixed z-[100] rounded-md border shadow-md'
|
||||
style={{
|
||||
top: dropdownPos.top,
|
||||
left: dropdownPos.left,
|
||||
width: dropdownPos.width,
|
||||
}}
|
||||
>
|
||||
{filteredOptions.length > 0 ? (
|
||||
<ul
|
||||
ref={listRef}
|
||||
role='listbox'
|
||||
className='max-h-[200px] overflow-y-auto p-1'
|
||||
>
|
||||
{filteredOptions.map((option, index) => (
|
||||
<li
|
||||
key={option.value}
|
||||
role='option'
|
||||
aria-selected={value === option.value}
|
||||
data-highlighted={index === highlightedIndex}
|
||||
className={cn(
|
||||
'relative flex cursor-pointer items-center gap-2 rounded-sm px-2 py-1.5 text-sm select-none',
|
||||
index === highlightedIndex &&
|
||||
'bg-accent text-accent-foreground',
|
||||
value === option.value && 'font-medium'
|
||||
)}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
onMouseDown={(e) => {
|
||||
e.preventDefault() // Prevent blur
|
||||
handleSelect(option.value)
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
'size-4 shrink-0',
|
||||
value === option.value ? 'opacity-100' : 'opacity-0'
|
||||
)}
|
||||
/>
|
||||
{option.icon && <span>{option.icon}</span>}
|
||||
<span className='truncate'>{option.label}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
) : (
|
||||
<div className='px-2 py-6 text-center text-sm'>
|
||||
{t(emptyText)}
|
||||
{allowCustomValue && searchValue.trim() && (
|
||||
<div className='text-muted-foreground mt-1 text-xs'>
|
||||
{t('Press Enter to use "{{value}}"', {
|
||||
value: searchValue.trim(),
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
) : null
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className='relative'>
|
||||
<Input
|
||||
@@ -184,57 +276,7 @@ export function ComboboxInput({
|
||||
/>
|
||||
<ChevronsUpDown className='pointer-events-none absolute top-1/2 right-3 size-4 shrink-0 -translate-y-1/2 opacity-50' />
|
||||
|
||||
{showDropdown && (
|
||||
<div className='bg-popover text-popover-foreground absolute top-full z-100 mt-1 w-full rounded-md border shadow-md'>
|
||||
{filteredOptions.length > 0 ? (
|
||||
<ul
|
||||
ref={listRef}
|
||||
role='listbox'
|
||||
className='max-h-[200px] overflow-y-auto p-1'
|
||||
>
|
||||
{filteredOptions.map((option, index) => (
|
||||
<li
|
||||
key={option.value}
|
||||
role='option'
|
||||
aria-selected={value === option.value}
|
||||
data-highlighted={index === highlightedIndex}
|
||||
className={cn(
|
||||
'relative flex cursor-pointer items-center gap-2 rounded-sm px-2 py-1.5 text-sm select-none',
|
||||
index === highlightedIndex &&
|
||||
'bg-accent text-accent-foreground',
|
||||
value === option.value && 'font-medium'
|
||||
)}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
onMouseDown={(e) => {
|
||||
e.preventDefault() // Prevent blur
|
||||
handleSelect(option.value)
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
'size-4 shrink-0',
|
||||
value === option.value ? 'opacity-100' : 'opacity-0'
|
||||
)}
|
||||
/>
|
||||
{option.icon && <span>{option.icon}</span>}
|
||||
<span className='truncate'>{option.label}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
) : (
|
||||
<div className='px-2 py-6 text-center text-sm'>
|
||||
{t(emptyText)}
|
||||
{allowCustomValue && searchValue.trim() && (
|
||||
<div className='text-muted-foreground mt-1 text-xs'>
|
||||
{t('Press Enter to use "{{value}}"', {
|
||||
value: searchValue.trim(),
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{dropdownContent && createPortal(dropdownContent, document.body)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
+1
-1
@@ -180,7 +180,7 @@ function ComboboxContent({
|
||||
data-slot='combobox-content'
|
||||
data-chips={!!anchor}
|
||||
className={cn(
|
||||
'dark group/combobox-content bg-popover text-popover-foreground ring-foreground/10 data-[side=bottom]:slide-in-from-top-2 data-[side=inline-end]:slide-in-from-left-2 data-[side=inline-start]:slide-in-from-right-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 *:data-[slot=input-group]:border-input/30 *:data-[slot=input-group]:bg-input/30 data-open:animate-in data-open:fade-in-0 data-open:zoom-in-95 data-closed:animate-out data-closed:fade-out-0 data-closed:zoom-out-95 relative max-h-(--available-height) w-(--anchor-width) max-w-(--available-width) min-w-[calc(var(--anchor-width)+--spacing(7))] origin-(--transform-origin) overflow-hidden rounded-lg shadow-md ring-1 duration-100 data-[chips=true]:min-w-(--anchor-width) *:data-[slot=input-group]:m-1 *:data-[slot=input-group]:mb-0 *:data-[slot=input-group]:h-8 *:data-[slot=input-group]:shadow-none',
|
||||
'group/combobox-content bg-popover text-popover-foreground ring-foreground/10 data-[side=bottom]:slide-in-from-top-2 data-[side=inline-end]:slide-in-from-left-2 data-[side=inline-start]:slide-in-from-right-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 *:data-[slot=input-group]:border-input/30 *:data-[slot=input-group]:bg-input/30 data-open:animate-in data-open:fade-in-0 data-open:zoom-in-95 data-closed:animate-out data-closed:fade-out-0 data-closed:zoom-out-95 relative max-h-(--available-height) w-(--anchor-width) max-w-(--available-width) min-w-[calc(var(--anchor-width)+--spacing(7))] origin-(--transform-origin) overflow-hidden rounded-lg shadow-md ring-1 duration-100 data-[chips=true]:min-w-(--anchor-width) *:data-[slot=input-group]:m-1 *:data-[slot=input-group]:mb-0 *:data-[slot=input-group]:h-8 *:data-[slot=input-group]:shadow-none',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
||||
+1
-1
@@ -103,7 +103,7 @@ function TableCell({ className, ...props }: React.ComponentProps<'td'>) {
|
||||
<td
|
||||
data-slot='table-cell'
|
||||
className={cn(
|
||||
'p-2 align-middle whitespace-nowrap [&:has([role=checkbox])]:pr-0',
|
||||
'p-2 align-middle whitespace-nowrap [&:has([role=checkbox])]:pr-0 [&>*:has(>[data-slot=status-badge]:first-child):first-child]:-ml-1.5 [&>[data-slot=status-badge]:first-child]:-ml-1.5',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
||||
@@ -31,6 +31,31 @@ import {
|
||||
} from '../lib/oauth'
|
||||
import type { SystemStatus, CustomOAuthProviderInfo } from '../types'
|
||||
|
||||
/**
|
||||
* Generate a random code verifier for PKCE
|
||||
*/
|
||||
function generateCodeVerifier(): string {
|
||||
const array = new Uint8Array(32)
|
||||
crypto.getRandomValues(array)
|
||||
return btoa(String.fromCharCode(...array))
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=+$/, '')
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate code challenge from code verifier using SHA-256
|
||||
*/
|
||||
async function generateCodeChallenge(verifier: string): Promise<string> {
|
||||
const encoder = new TextEncoder()
|
||||
const data = encoder.encode(verifier)
|
||||
const digest = await crypto.subtle.digest('SHA-256', data)
|
||||
return btoa(String.fromCharCode(...new Uint8Array(digest)))
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=+$/, '')
|
||||
}
|
||||
|
||||
type LogoutRequestConfig = AxiosRequestConfig & {
|
||||
skipErrorHandler?: boolean
|
||||
}
|
||||
@@ -211,6 +236,16 @@ export function useOAuthLogin(status: SystemStatus | null) {
|
||||
url.searchParams.set('scope', provider.scopes)
|
||||
}
|
||||
|
||||
// Add PKCE support if enabled
|
||||
if (provider.pkce_enabled) {
|
||||
const codeVerifier = generateCodeVerifier()
|
||||
const codeChallenge = await generateCodeChallenge(codeVerifier)
|
||||
// Store code_verifier in sessionStorage keyed by state
|
||||
sessionStorage.setItem(`pkce_verifier_${state}`, codeVerifier)
|
||||
url.searchParams.set('code_challenge', codeChallenge)
|
||||
url.searchParams.set('code_challenge_method', 'S256')
|
||||
}
|
||||
|
||||
window.open(url.toString(), '_self')
|
||||
} catch (_error) {
|
||||
toast.error(
|
||||
|
||||
Vendored
+106
-117
@@ -20,16 +20,9 @@ import { useMemo } from 'react'
|
||||
import { ShieldCheck, KeyRound, Loader2 } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import type {
|
||||
SecureVerificationState,
|
||||
VerificationMethod,
|
||||
@@ -91,122 +84,118 @@ export function SecureVerificationDialog({
|
||||
(activeMethod === '2fa' && (!state.code.trim() || state.code.length < 6))
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent
|
||||
className='top-[8vh] max-w-[calc(100%-1.5rem)] translate-y-0 gap-0 overflow-hidden border-none p-0 shadow-xl sm:top-1/2 sm:max-w-md sm:translate-y-[-50%] sm:rounded-xl'
|
||||
showCloseButton={!state.loading}
|
||||
>
|
||||
<div className='bg-background flex max-h-[calc(100dvh-2rem)] flex-col'>
|
||||
<DialogHeader className='border-b px-6 py-5 text-left'>
|
||||
<DialogTitle className='flex items-center gap-2 text-lg font-semibold'>
|
||||
<ShieldCheck className='text-primary h-5 w-5' />
|
||||
{title}
|
||||
</DialogTitle>
|
||||
<DialogDescription className='text-left'>
|
||||
{description}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
title={
|
||||
<>
|
||||
<ShieldCheck className='text-primary h-5 w-5' />
|
||||
{title}
|
||||
</>
|
||||
}
|
||||
description={description}
|
||||
contentClassName='top-[8vh] max-w-[calc(100%-1.5rem)] translate-y-0 overflow-hidden border-none shadow-xl sm:top-1/2 sm:max-w-md sm:translate-y-[-50%] sm:rounded-xl'
|
||||
headerClassName='border-b pb-4 text-left'
|
||||
titleClassName='flex items-center gap-2 text-lg font-semibold'
|
||||
descriptionClassName='text-left'
|
||||
contentHeight='auto'
|
||||
bodyClassName='px-1 py-1'
|
||||
showCloseButton={!state.loading}
|
||||
footerClassName='bg-muted/30 border-t px-6 py-4 sm:flex-row sm:justify-end'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
disabled={state.loading}
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
type='button'
|
||||
onClick={handleVerify}
|
||||
disabled={availableTabs.length === 0 || verifyDisabled}
|
||||
>
|
||||
{state.loading && <Loader2 className='h-4 w-4 animate-spin' />}
|
||||
{t('Verify')}
|
||||
</Button>
|
||||
</>
|
||||
}
|
||||
>
|
||||
{availableTabs.length === 0 ? (
|
||||
<div className='grid place-items-center gap-4 text-center'>
|
||||
<div className='bg-muted flex h-16 w-16 items-center justify-center rounded-2xl'>
|
||||
<ShieldCheck className='text-muted-foreground h-8 w-8' />
|
||||
</div>
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t(
|
||||
'Enable Two-factor Authentication or Passkey in your profile to unlock sensitive operations.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<Tabs
|
||||
value={activeMethod ?? availableTabs[0]}
|
||||
onValueChange={(value) => onMethodChange(value as VerificationMethod)}
|
||||
className='gap-4'
|
||||
>
|
||||
<TabsList>
|
||||
{methods.has2FA && (
|
||||
<TabsTrigger value='2fa'>{t('Authenticator code')}</TabsTrigger>
|
||||
)}
|
||||
{methods.hasPasskey && methods.passkeySupported && (
|
||||
<TabsTrigger value='passkey'>{t('Passkey')}</TabsTrigger>
|
||||
)}
|
||||
</TabsList>
|
||||
|
||||
<div className='flex-1 overflow-y-auto px-6 py-5'>
|
||||
{availableTabs.length === 0 ? (
|
||||
<div className='grid place-items-center gap-4 text-center'>
|
||||
<div className='bg-muted flex h-16 w-16 items-center justify-center rounded-2xl'>
|
||||
<ShieldCheck className='text-muted-foreground h-8 w-8' />
|
||||
</div>
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t(
|
||||
'Enable Two-factor Authentication or Passkey in your profile to unlock sensitive operations.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<Tabs
|
||||
value={activeMethod ?? availableTabs[0]}
|
||||
onValueChange={(value) =>
|
||||
onMethodChange(value as VerificationMethod)
|
||||
<TabsContent value='2fa' className='space-y-3'>
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t(
|
||||
'Enter the 6-digit Time-based One-Time Password or 8-character backup code from your authenticator app.'
|
||||
)}
|
||||
</p>
|
||||
<Input
|
||||
inputMode='numeric'
|
||||
maxLength={8}
|
||||
value={state.code}
|
||||
onChange={(event) => onCodeChange(event.target.value)}
|
||||
placeholder={t('Enter verification code')}
|
||||
disabled={state.loading}
|
||||
autoFocus={activeMethod === '2fa'}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === 'Enter' && !verifyDisabled) {
|
||||
event.preventDefault()
|
||||
handleVerify()
|
||||
}
|
||||
className='gap-4'
|
||||
>
|
||||
<TabsList>
|
||||
{methods.has2FA && (
|
||||
<TabsTrigger value='2fa'>
|
||||
{t('Authenticator code')}
|
||||
</TabsTrigger>
|
||||
)}
|
||||
{methods.hasPasskey && methods.passkeySupported && (
|
||||
<TabsTrigger value='passkey'>{t('Passkey')}</TabsTrigger>
|
||||
)}
|
||||
</TabsList>
|
||||
}}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value='2fa' className='space-y-3'>
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
<TabsContent value='passkey' className='space-y-4'>
|
||||
<div className='bg-muted/50 flex items-center justify-center rounded-lg p-4'>
|
||||
<div className='text-muted-foreground flex items-center gap-3'>
|
||||
<KeyRound className='text-primary h-6 w-6' />
|
||||
<div className='text-left text-sm'>
|
||||
<p className='text-foreground font-medium'>
|
||||
{t('Use your Passkey')}
|
||||
</p>
|
||||
<p>
|
||||
{t(
|
||||
'Enter the 6-digit Time-based One-Time Password or 8-character backup code from your authenticator app.'
|
||||
'We will prompt your device to confirm using biometrics or your hardware key.'
|
||||
)}
|
||||
</p>
|
||||
<Input
|
||||
inputMode='numeric'
|
||||
maxLength={8}
|
||||
value={state.code}
|
||||
onChange={(event) => onCodeChange(event.target.value)}
|
||||
placeholder={t('Enter verification code')}
|
||||
disabled={state.loading}
|
||||
autoFocus={activeMethod === '2fa'}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === 'Enter' && !verifyDisabled) {
|
||||
event.preventDefault()
|
||||
handleVerify()
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value='passkey' className='space-y-4'>
|
||||
<div className='bg-muted/50 flex items-center justify-center rounded-lg p-4'>
|
||||
<div className='text-muted-foreground flex items-center gap-3'>
|
||||
<KeyRound className='text-primary h-6 w-6' />
|
||||
<div className='text-left text-sm'>
|
||||
<p className='text-foreground font-medium'>
|
||||
{t('Use your Passkey')}
|
||||
</p>
|
||||
<p>
|
||||
{t(
|
||||
'We will prompt your device to confirm using biometrics or your hardware key.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{!methods.passkeySupported && (
|
||||
<p className='text-destructive text-sm'>
|
||||
{t('This device does not support Passkey verification.')}
|
||||
</p>
|
||||
)}
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{!methods.passkeySupported && (
|
||||
<p className='text-destructive text-sm'>
|
||||
{t('This device does not support Passkey verification.')}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<DialogFooter className='bg-muted/30 border-t px-6 py-4 sm:flex-row sm:justify-end'>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
disabled={state.loading}
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
type='button'
|
||||
onClick={handleVerify}
|
||||
disabled={availableTabs.length === 0 || verifyDisabled}
|
||||
>
|
||||
{state.loading && <Loader2 className='h-4 w-4 animate-spin' />}
|
||||
{t('Verify')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
)}
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -32,14 +32,6 @@ import {
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useStatus } from '@/hooks/use-status'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
@@ -50,6 +42,7 @@ import {
|
||||
} from '@/components/ui/form'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { PasswordInput } from '@/components/password-input'
|
||||
import { Turnstile } from '@/components/turnstile'
|
||||
import { login, wechatLoginByCode } from '@/features/auth/api'
|
||||
@@ -414,43 +407,16 @@ export function UserAuthForm({
|
||||
<Dialog
|
||||
open={isWeChatDialogOpen}
|
||||
onOpenChange={handleWeChatDialogChange}
|
||||
>
|
||||
<DialogContent className='max-w-sm'>
|
||||
<DialogHeader className='text-left'>
|
||||
<DialogTitle>{t('WeChat sign in')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t(
|
||||
'Scan the QR code to follow the official account and reply with “验证码” to receive your verification code.'
|
||||
)}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{wechatQrCodeUrl ? (
|
||||
<div className='flex justify-center'>
|
||||
<img
|
||||
src={wechatQrCodeUrl}
|
||||
alt={t('WeChat login QR code')}
|
||||
className='h-40 w-40 rounded-md border object-contain'
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t('QR code is not configured. Please contact support.')}
|
||||
</p>
|
||||
)}
|
||||
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='wechat-code'>{t('Verification code')}</Label>
|
||||
<Input
|
||||
id='wechat-code'
|
||||
placeholder={t('Enter the verification code')}
|
||||
value={wechatCode}
|
||||
onChange={(event) => setWeChatCode(event.target.value)}
|
||||
autoComplete='one-time-code'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
title={t('WeChat sign in')}
|
||||
description={t(
|
||||
'Scan the QR code to follow the official account and reply with “验证码” to receive your verification code.'
|
||||
)}
|
||||
contentClassName='max-w-sm'
|
||||
headerClassName='text-left'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
@@ -474,8 +440,32 @@ export function UserAuthForm({
|
||||
) : null}
|
||||
{t('Confirm')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
{wechatQrCodeUrl ? (
|
||||
<div className='flex justify-center'>
|
||||
<img
|
||||
src={wechatQrCodeUrl}
|
||||
alt={t('WeChat login QR code')}
|
||||
className='h-40 w-40 rounded-md border object-contain'
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t('QR code is not configured. Please contact support.')}
|
||||
</p>
|
||||
)}
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='wechat-code'>{t('Verification code')}</Label>
|
||||
<Input
|
||||
id='wechat-code'
|
||||
placeholder={t('Enter the verification code')}
|
||||
value={wechatCode}
|
||||
onChange={(event) => setWeChatCode(event.target.value)}
|
||||
autoComplete='one-time-code'
|
||||
/>
|
||||
</div>
|
||||
</Dialog>
|
||||
)}
|
||||
</Form>
|
||||
|
||||
@@ -26,14 +26,6 @@ import { toast } from 'sonner'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useStatus } from '@/hooks/use-status'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
@@ -44,6 +36,7 @@ import {
|
||||
} from '@/components/ui/form'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { PasswordInput } from '@/components/password-input'
|
||||
import { Turnstile } from '@/components/turnstile'
|
||||
import { register, wechatLoginByCode } from '@/features/auth/api'
|
||||
@@ -387,43 +380,16 @@ export function SignUpForm({
|
||||
<Dialog
|
||||
open={isWeChatDialogOpen}
|
||||
onOpenChange={handleWeChatDialogChange}
|
||||
>
|
||||
<DialogContent className='max-w-sm'>
|
||||
<DialogHeader className='text-left'>
|
||||
<DialogTitle>{t('WeChat sign in')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t(
|
||||
'Scan the QR code to follow the official account and reply with “验证码” to receive your verification code.'
|
||||
)}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{wechatQrCodeUrl ? (
|
||||
<div className='flex justify-center'>
|
||||
<img
|
||||
src={wechatQrCodeUrl}
|
||||
alt={t('WeChat login QR code')}
|
||||
className='h-40 w-40 rounded-md border object-contain'
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t('QR code is not configured. Please contact support.')}
|
||||
</p>
|
||||
)}
|
||||
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='wechat-code'>{t('Verification code')}</Label>
|
||||
<Input
|
||||
id='wechat-code'
|
||||
placeholder={t('Enter the verification code')}
|
||||
value={wechatCode}
|
||||
onChange={(event) => setWeChatCode(event.target.value)}
|
||||
autoComplete='one-time-code'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
title={t('WeChat sign in')}
|
||||
description={t(
|
||||
'Scan the QR code to follow the official account and reply with “验证码” to receive your verification code.'
|
||||
)}
|
||||
contentClassName='max-w-sm'
|
||||
headerClassName='text-left'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
@@ -447,8 +413,32 @@ export function SignUpForm({
|
||||
) : null}
|
||||
{t('Confirm')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
{wechatQrCodeUrl ? (
|
||||
<div className='flex justify-center'>
|
||||
<img
|
||||
src={wechatQrCodeUrl}
|
||||
alt={t('WeChat login QR code')}
|
||||
className='h-40 w-40 rounded-md border object-contain'
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<p className='text-muted-foreground text-sm'>
|
||||
{t('QR code is not configured. Please contact support.')}
|
||||
</p>
|
||||
)}
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='wechat-code'>{t('Verification code')}</Label>
|
||||
<Input
|
||||
id='wechat-code'
|
||||
placeholder={t('Enter the verification code')}
|
||||
value={wechatCode}
|
||||
onChange={(event) => setWeChatCode(event.target.value)}
|
||||
autoComplete='one-time-code'
|
||||
/>
|
||||
</div>
|
||||
</Dialog>
|
||||
)}
|
||||
</Form>
|
||||
|
||||
+1
@@ -195,6 +195,7 @@ export interface CustomOAuthProviderInfo {
|
||||
client_id: string
|
||||
authorization_endpoint: string
|
||||
scopes: string
|
||||
pkce_enabled: boolean
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -35,7 +35,6 @@ import {
|
||||
formatTimestampToDate,
|
||||
formatQuota as formatQuotaValue,
|
||||
} from '@/lib/format'
|
||||
import { getLobeIcon } from '@/lib/lobe-icon'
|
||||
import { truncateText } from '@/lib/utils'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
@@ -46,8 +45,9 @@ import {
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip'
|
||||
import { ConfirmDialog } from '@/components/confirm-dialog'
|
||||
import { DataTableColumnHeader } from '@/components/data-table/column-header'
|
||||
import { DataTableColumnHeader } from '@/components/data-table'
|
||||
import { GroupBadge } from '@/components/group-badge'
|
||||
import { ProviderBadge } from '@/components/provider-badge'
|
||||
import { StatusBadge, StatusBadgeList } from '@/components/status-badge'
|
||||
import { TableId } from '@/components/table-id'
|
||||
import { TruncatedText } from '@/components/truncated-text'
|
||||
@@ -623,7 +623,6 @@ export function useChannelsColumns(): ColumnDef<Channel>[] {
|
||||
const typeNameKey = getChannelTypeLabel(type)
|
||||
const typeName = t(typeNameKey)
|
||||
const iconName = getChannelTypeIcon(type)
|
||||
const icon = getLobeIcon(`${iconName}.Color`, 14)
|
||||
const channel = row.original as Channel
|
||||
const isMultiKey = isMultiKeyChannel(channel)
|
||||
const multiKeyMode = channel.channel_info?.multi_key_mode ?? 'random'
|
||||
@@ -657,16 +656,12 @@ export function useChannelsColumns(): ColumnDef<Channel>[] {
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<StatusBadge
|
||||
autoColor={typeName}
|
||||
size='sm'
|
||||
<ProviderBadge
|
||||
iconKey={iconName}
|
||||
label={typeName}
|
||||
copyable={false}
|
||||
showDot={false}
|
||||
className='gap-1 pl-1'
|
||||
>
|
||||
{icon}
|
||||
<span className='truncate'>{typeName}</span>
|
||||
</StatusBadge>
|
||||
/>
|
||||
{isIonet && (
|
||||
<TooltipProvider delay={100}>
|
||||
<Tooltip>
|
||||
|
||||
+38
-62
@@ -16,20 +16,15 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import { useState, useMemo, useEffect } from 'react'
|
||||
import { useState, useMemo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { getRouteApi } from '@tanstack/react-router'
|
||||
import {
|
||||
getCoreRowModel,
|
||||
useReactTable,
|
||||
getExpandedRowModel,
|
||||
type OnChangeFn,
|
||||
type SortingState,
|
||||
type VisibilityState,
|
||||
type ExpandedState,
|
||||
type Row,
|
||||
} from '@tanstack/react-table'
|
||||
import { useDebounce, useMediaQuery } from '@/hooks'
|
||||
import { useMediaQuery } from '@/hooks'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { getLobeIcon } from '@/lib/lobe-icon'
|
||||
import { useTableUrlState } from '@/hooks/use-table-url-state'
|
||||
@@ -38,6 +33,8 @@ import {
|
||||
DISABLED_ROW_DESKTOP,
|
||||
DISABLED_ROW_MOBILE,
|
||||
DataTablePage,
|
||||
useDebouncedColumnFilter,
|
||||
useDataTable,
|
||||
} from '@/components/data-table'
|
||||
import { getChannels, searchChannels, getGroups } from '../api'
|
||||
import {
|
||||
@@ -81,12 +78,6 @@ export function ChannelsTable() {
|
||||
|
||||
// Table state
|
||||
const [sorting, setSorting] = useState<SortingState>([])
|
||||
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>({
|
||||
models: false,
|
||||
tag: false,
|
||||
})
|
||||
const [rowSelection, setRowSelection] = useState({})
|
||||
const [expanded, setExpanded] = useState<ExpandedState>({})
|
||||
|
||||
// URL state management
|
||||
const {
|
||||
@@ -116,35 +107,24 @@ export function ChannelsTable() {
|
||||
// Extract filters from column filters
|
||||
const statusFilter =
|
||||
(columnFilters.find((f) => f.id === 'status')?.value as string[]) || []
|
||||
const typeFilter =
|
||||
(columnFilters.find((f) => f.id === 'type')?.value as string[]) || []
|
||||
const typeFilter = useMemo(
|
||||
() => (columnFilters.find((f) => f.id === 'type')?.value as string[]) || [],
|
||||
[columnFilters]
|
||||
)
|
||||
const groupFilter =
|
||||
(columnFilters.find((f) => f.id === 'group')?.value as string[]) || []
|
||||
const modelFilterFromUrl =
|
||||
(columnFilters.find((f) => f.id === 'model')?.value as string) || ''
|
||||
|
||||
// Local state for immediate input feedback
|
||||
const [modelFilterInput, setModelFilterInput] = useState(modelFilterFromUrl)
|
||||
const debouncedModelFilter = useDebounce(modelFilterInput, 500)
|
||||
|
||||
// Sync local input with URL when URL changes (e.g., from back/forward navigation)
|
||||
useEffect(() => {
|
||||
setModelFilterInput(modelFilterFromUrl)
|
||||
}, [modelFilterFromUrl])
|
||||
|
||||
// Update URL when debounced value changes
|
||||
useEffect(() => {
|
||||
if (debouncedModelFilter !== modelFilterFromUrl) {
|
||||
onColumnFiltersChange((prev) => {
|
||||
const filtered = prev.filter((f) => f.id !== 'model')
|
||||
return debouncedModelFilter
|
||||
? [...filtered, { id: 'model', value: debouncedModelFilter }]
|
||||
: filtered
|
||||
})
|
||||
}
|
||||
}, [debouncedModelFilter, modelFilterFromUrl, onColumnFiltersChange])
|
||||
|
||||
const modelFilter = modelFilterFromUrl
|
||||
const {
|
||||
value: modelFilter,
|
||||
inputValue: modelFilterInput,
|
||||
onChange: onModelFilterInputChange,
|
||||
onCompositionStart: onModelFilterCompositionStart,
|
||||
onCompositionEnd: onModelFilterCompositionEnd,
|
||||
resetInput: resetModelFilterInput,
|
||||
} = useDebouncedColumnFilter({
|
||||
columnFilters,
|
||||
columnId: 'model',
|
||||
onColumnFiltersChange,
|
||||
})
|
||||
|
||||
// Determine whether to use search or regular list API
|
||||
const shouldSearch = Boolean(globalFilter?.trim() || modelFilter.trim())
|
||||
@@ -279,41 +259,31 @@ export function ChannelsTable() {
|
||||
const columns = useChannelsColumns()
|
||||
|
||||
// React Table instance
|
||||
const table = useReactTable({
|
||||
const { table } = useDataTable({
|
||||
data: channels,
|
||||
columns,
|
||||
pageCount: Math.ceil(totalCount / pagination.pageSize),
|
||||
state: {
|
||||
sorting,
|
||||
columnFilters,
|
||||
columnVisibility,
|
||||
rowSelection,
|
||||
pagination,
|
||||
expanded,
|
||||
globalFilter,
|
||||
totalCount,
|
||||
sorting,
|
||||
initialColumnVisibility: {
|
||||
models: false,
|
||||
tag: false,
|
||||
},
|
||||
columnFilters,
|
||||
pagination,
|
||||
globalFilter,
|
||||
enableRowSelection: (row: Row<Channel>) => !isTagAggregateRow(row.original),
|
||||
onRowSelectionChange: setRowSelection,
|
||||
onSortingChange: handleSortingChange,
|
||||
onColumnFiltersChange,
|
||||
onColumnVisibilityChange: setColumnVisibility,
|
||||
onPaginationChange,
|
||||
onExpandedChange: setExpanded,
|
||||
onGlobalFilterChange,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
getExpandedRowModel: getExpandedRowModel(),
|
||||
getSubRows: (row: Channel & { children?: Channel[] }) => row.children,
|
||||
manualPagination: true,
|
||||
manualSorting: true,
|
||||
manualFiltering: true,
|
||||
withExpandedRowModel: true,
|
||||
ensurePageInRange,
|
||||
})
|
||||
|
||||
// Ensure page is in range when total count changes
|
||||
const pageCount = table.getPageCount()
|
||||
useEffect(() => {
|
||||
ensurePageInRange(pageCount)
|
||||
}, [pageCount, ensurePageInRange])
|
||||
|
||||
// Prepare filter options from existing channel types only.
|
||||
const typeFilterOptions = useMemo(() => {
|
||||
const counts = typeCounts || {}
|
||||
@@ -385,11 +355,17 @@ export function ChannelsTable() {
|
||||
applyHeaderSize
|
||||
toolbarProps={{
|
||||
searchPlaceholder: t('Filter by name, ID, or key...'),
|
||||
searchDebounceMs: 500,
|
||||
onReset: () => {
|
||||
resetModelFilterInput()
|
||||
},
|
||||
additionalSearch: (
|
||||
<Input
|
||||
placeholder={t('Filter by model...')}
|
||||
value={modelFilterInput}
|
||||
onChange={(e) => setModelFilterInput(e.target.value)}
|
||||
onChange={onModelFilterInputChange}
|
||||
onCompositionStart={onModelFilterCompositionStart}
|
||||
onCompositionEnd={onModelFilterCompositionEnd}
|
||||
className='w-full sm:w-[150px] lg:w-[180px]'
|
||||
/>
|
||||
),
|
||||
|
||||
@@ -22,14 +22,6 @@ import { type Table } from '@tanstack/react-table'
|
||||
import { Power, PowerOff, Tag, Trash2 } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import {
|
||||
@@ -38,6 +30,7 @@ import {
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip'
|
||||
import { DataTableBulkActions as BulkActionsToolbar } from '@/components/data-table'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import {
|
||||
handleBatchDelete,
|
||||
handleBatchDisable,
|
||||
@@ -188,29 +181,21 @@ export function DataTableBulkActions<TData>({
|
||||
</BulkActionsToolbar>
|
||||
|
||||
{/* Set Tag Dialog */}
|
||||
<Dialog open={showTagDialog} onOpenChange={setShowTagDialog}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Set Tag')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Set a tag for')} {selectedIds.length}{' '}
|
||||
{t('selected channel(s). Leave empty to remove tag.')}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className='grid gap-4 py-4'>
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='tag'>{t('Tag')}</Label>
|
||||
<Input
|
||||
id='tag'
|
||||
placeholder={t('Enter tag name (optional)')}
|
||||
value={tagValue}
|
||||
onChange={(e) => setTagValue(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={showTagDialog}
|
||||
onOpenChange={setShowTagDialog}
|
||||
title={t('Set Tag')}
|
||||
description={
|
||||
<>
|
||||
{t('Set a tag for')}
|
||||
{selectedIds.length}{' '}
|
||||
{t('selected channel(s). Leave empty to remove tag.')}
|
||||
</>
|
||||
}
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
variant='outline'
|
||||
onClick={() => {
|
||||
@@ -221,22 +206,37 @@ export function DataTableBulkActions<TData>({
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button onClick={handleSetTag}>{t('Set Tag')}</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='grid gap-4 py-4'>
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='tag'>{t('Tag')}</Label>
|
||||
<Input
|
||||
id='tag'
|
||||
placeholder={t('Enter tag name (optional)')}
|
||||
value={tagValue}
|
||||
onChange={(e) => setTagValue(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
|
||||
{/* Delete Confirmation Dialog */}
|
||||
<Dialog open={showDeleteConfirm} onOpenChange={setShowDeleteConfirm}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Delete Channels?')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Are you sure you want to delete')} {selectedIds.length}{' '}
|
||||
{t('channel(s)? This action cannot be undone.')}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={showDeleteConfirm}
|
||||
onOpenChange={setShowDeleteConfirm}
|
||||
title={t('Delete Channels?')}
|
||||
description={
|
||||
<>
|
||||
{t('Are you sure you want to delete')}
|
||||
{selectedIds.length}{' '}
|
||||
{t('channel(s)? This action cannot be undone.')}
|
||||
</>
|
||||
}
|
||||
contentHeight='auto'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
variant='outline'
|
||||
onClick={() => setShowDeleteConfirm(false)}
|
||||
@@ -246,8 +246,10 @@ export function DataTableBulkActions<TData>({
|
||||
<Button variant='destructive' onClick={handleDeleteAll}>
|
||||
{t('Delete')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
{' '}
|
||||
</Dialog>
|
||||
</>
|
||||
)
|
||||
|
||||
+47
-52
@@ -24,14 +24,7 @@ import { toast } from 'sonner'
|
||||
import { formatCurrencyFromUSD } from '@/lib/currency'
|
||||
import { formatTimestampToDate } from '@/lib/format'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { getCodexUsage, updateChannelBalance } from '../../api'
|
||||
import { channelsQueryKeys } from '../../lib'
|
||||
import { useChannels } from '../channels-provider'
|
||||
@@ -161,53 +154,55 @@ export function BalanceQueryDialog({
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleClose}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Query Balance')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Update balance for:')} <strong>{currentRow.name}</strong>
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className='space-y-4 py-4'>
|
||||
{/* Current Balance Display */}
|
||||
<div className='bg-muted/50 rounded-lg border p-4'>
|
||||
<div className='text-muted-foreground mb-2 flex items-center gap-2 text-sm'>
|
||||
<DollarSign className='h-4 w-4' />
|
||||
<span>{t('Current Balance')}</span>
|
||||
</div>
|
||||
<div className='text-2xl font-bold'>
|
||||
{balance !== null
|
||||
? formatBalance(balance)
|
||||
: formatBalance(currentRow.balance)}
|
||||
</div>
|
||||
<div className='text-muted-foreground mt-2 text-xs'>
|
||||
{t('Last updated:')}{' '}
|
||||
{formatDate(
|
||||
balanceUpdatedTime ?? currentRow.balance_updated_time
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Balance Update Button */}
|
||||
<Button
|
||||
className='w-full'
|
||||
onClick={handleQueryBalance}
|
||||
disabled={isQuerying}
|
||||
>
|
||||
{isQuerying && <Loader2 className='mr-2 h-4 w-4 animate-spin' />}
|
||||
{!isQuerying && <RefreshCw className='mr-2 h-4 w-4' />}
|
||||
{isQuerying ? t('Querying...') : t('Update Balance')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={handleClose}
|
||||
title={t('Query Balance')}
|
||||
description={
|
||||
<>
|
||||
{t('Update balance for:')}
|
||||
<strong>{currentRow.name}</strong>
|
||||
</>
|
||||
}
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button variant='outline' onClick={handleClose} disabled={isQuerying}>
|
||||
{t('Close')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='space-y-4 py-4'>
|
||||
{/* Current Balance Display */}
|
||||
<div className='bg-muted/50 rounded-lg border p-4'>
|
||||
<div className='text-muted-foreground mb-2 flex items-center gap-2 text-sm'>
|
||||
<DollarSign className='h-4 w-4' />
|
||||
<span>{t('Current Balance')}</span>
|
||||
</div>
|
||||
<div className='text-2xl font-bold'>
|
||||
{balance !== null
|
||||
? formatBalance(balance)
|
||||
: formatBalance(currentRow.balance)}
|
||||
</div>
|
||||
<div className='text-muted-foreground mt-2 text-xs'>
|
||||
{t('Last updated:')}{' '}
|
||||
{formatDate(balanceUpdatedTime ?? currentRow.balance_updated_time)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Balance Update Button */}
|
||||
<Button
|
||||
className='w-full'
|
||||
onClick={handleQueryBalance}
|
||||
disabled={isQuerying}
|
||||
>
|
||||
{isQuerying && <Loader2 className='mr-2 h-4 w-4 animate-spin' />}
|
||||
{!isQuerying && <RefreshCw className='mr-2 h-4 w-4' />}
|
||||
{isQuerying ? t('Querying...') : t('Update Balance')}
|
||||
</Button>
|
||||
</div>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
+150
-200
@@ -21,10 +21,6 @@ import {
|
||||
type ColumnDef,
|
||||
type RowSelectionState,
|
||||
type Table as TanStackTable,
|
||||
flexRender,
|
||||
getCoreRowModel,
|
||||
getPaginationRowModel,
|
||||
useReactTable,
|
||||
} from '@tanstack/react-table'
|
||||
import { Check, Copy, Info, Loader2, Settings } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@@ -33,14 +29,6 @@ import { useCopyToClipboard } from '@/hooks/use-copy-to-clipboard'
|
||||
import { useIsMobile } from '@/hooks/use-mobile'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import {
|
||||
@@ -60,21 +48,18 @@ import {
|
||||
SheetTitle,
|
||||
} from '@/components/ui/sheet'
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from '@/components/ui/table'
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip'
|
||||
import { DataTableBulkActions as BulkActionsToolbar } from '@/components/data-table'
|
||||
import { DataTablePagination } from '@/components/data-table/pagination'
|
||||
import {
|
||||
DataTableBulkActions as BulkActionsToolbar,
|
||||
DataTablePagination,
|
||||
DataTableView,
|
||||
useDataTable,
|
||||
} from '@/components/data-table'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import {
|
||||
sideDrawerContentClassName,
|
||||
sideDrawerFooterClassName,
|
||||
@@ -207,7 +192,7 @@ function getTestTableColumnClass(columnId: string) {
|
||||
case 'status':
|
||||
return 'w-70 min-w-70 max-w-70 whitespace-normal'
|
||||
case 'actions':
|
||||
return 'bg-popover sticky right-0 z-20 w-24 min-w-24 border-l shadow-[-8px_0_8px_-8px_rgb(0_0_0_/_0.2)] whitespace-nowrap sm:w-28 sm:min-w-28'
|
||||
return 'bg-popover w-24 min-w-24 whitespace-nowrap sm:w-28 sm:min-w-28'
|
||||
default:
|
||||
return undefined
|
||||
}
|
||||
@@ -234,6 +219,14 @@ export function ChannelTestDialog({
|
||||
pageIndex: 0,
|
||||
pageSize: 10,
|
||||
})
|
||||
const endpointSelectItems = useMemo(
|
||||
() =>
|
||||
endpointTypeOptions.map((option) => ({
|
||||
value: option.value,
|
||||
label: t(option.label),
|
||||
})),
|
||||
[t]
|
||||
)
|
||||
|
||||
const resetState = useCallback(() => {
|
||||
setEndpointType('auto')
|
||||
@@ -509,18 +502,17 @@ export function ChannelTestDialog({
|
||||
]
|
||||
)
|
||||
|
||||
const table = useReactTable({
|
||||
const { table } = useDataTable({
|
||||
data: tableData,
|
||||
columns,
|
||||
state: {
|
||||
rowSelection,
|
||||
pagination,
|
||||
},
|
||||
rowSelection,
|
||||
pagination,
|
||||
enableRowSelection: true,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
getPaginationRowModel: getPaginationRowModel(),
|
||||
onRowSelectionChange: setRowSelection,
|
||||
onPaginationChange: setPagination,
|
||||
withFilteredRowModel: false,
|
||||
withSortedRowModel: false,
|
||||
withFacetedRowModel: false,
|
||||
})
|
||||
|
||||
if (!currentRow) {
|
||||
@@ -529,179 +521,137 @@ export function ChannelTestDialog({
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog open={open} onOpenChange={handleClose}>
|
||||
<DialogContent className='max-h-[90vh] overflow-hidden sm:max-w-3xl'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Test Channel Connection')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Test connectivity for:')} <strong>{currentRow.name}</strong>
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className='max-h-[78vh] space-y-4 overflow-y-auto py-4 pr-1'>
|
||||
<div className='grid gap-4 md:grid-cols-2'>
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='endpoint-type'>{t('Endpoint Type')}</Label>
|
||||
<Select
|
||||
items={[
|
||||
...endpointTypeOptions.map((option) => {
|
||||
const itemValue = option.value
|
||||
return { value: itemValue, label: t(option.label) }
|
||||
}),
|
||||
]}
|
||||
value={endpointType}
|
||||
onValueChange={(v) => v !== null && setEndpointType(v)}
|
||||
>
|
||||
<SelectTrigger id='endpoint-type'>
|
||||
<SelectValue placeholder={t('Auto detect (default)')} />
|
||||
</SelectTrigger>
|
||||
<SelectContent alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
{endpointTypeOptions.map((option) => {
|
||||
const itemValue = option.value
|
||||
return (
|
||||
<SelectItem key={itemValue} value={itemValue}>
|
||||
{t(option.label)}
|
||||
</SelectItem>
|
||||
)
|
||||
})}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'Override the endpoint used for testing. Leave empty to auto detect.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='stream-toggle'>{t('Stream Mode')}</Label>
|
||||
<div className='flex items-center gap-2'>
|
||||
<Switch
|
||||
id='stream-toggle'
|
||||
checked={isStreamTest}
|
||||
onCheckedChange={setIsStreamTest}
|
||||
disabled={streamDisabled}
|
||||
/>
|
||||
<span className='text-sm'>
|
||||
{isStreamTest ? t('Enabled') : t('Disabled')}
|
||||
</span>
|
||||
</div>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('Enable streaming mode for the test request.')}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='space-y-3 max-sm:has-[div[role="toolbar"]]:pb-16'>
|
||||
<div className='flex flex-col gap-2 sm:flex-row sm:items-center sm:justify-between'>
|
||||
<div>
|
||||
<p className='text-sm font-medium'>{t('Channel models')}</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('Select models to run batch tests.')}
|
||||
</p>
|
||||
</div>
|
||||
<Input
|
||||
placeholder={t('Filter models...')}
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className='sm:w-64'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className='space-y-3'>
|
||||
<div
|
||||
className='overflow-hidden rounded-md border'
|
||||
role='region'
|
||||
aria-label={t('Channel models')}
|
||||
>
|
||||
<div className='max-h-90 overflow-auto **:data-[slot=table-container]:overflow-visible'>
|
||||
<Table className='w-max min-w-full table-auto'>
|
||||
<colgroup>
|
||||
<col className='w-10 min-w-10' />
|
||||
<col className='w-auto' />
|
||||
<col className='w-70' />
|
||||
<col className='w-24 sm:w-28' />
|
||||
</colgroup>
|
||||
<TableHeader>
|
||||
{table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header) => (
|
||||
<TableHead
|
||||
key={header.id}
|
||||
className={getTestTableColumnClass(
|
||||
header.column.id
|
||||
)}
|
||||
>
|
||||
{header.isPlaceholder
|
||||
? null
|
||||
: flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext()
|
||||
)}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{table.getRowModel().rows.length ? (
|
||||
table.getRowModel().rows.map((row) => (
|
||||
<TableRow
|
||||
key={row.id}
|
||||
data-state={
|
||||
row.getIsSelected() ? 'selected' : undefined
|
||||
}
|
||||
>
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<TableCell
|
||||
key={cell.id}
|
||||
className={getTestTableColumnClass(
|
||||
cell.column.id
|
||||
)}
|
||||
>
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext()
|
||||
)}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))
|
||||
) : (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={table.getVisibleLeafColumns().length}
|
||||
className='text-muted-foreground h-16 text-center text-sm'
|
||||
>
|
||||
{models.length
|
||||
? 'No models matched your search.'
|
||||
: 'This channel has no configured models.'}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DataTablePagination table={table} />
|
||||
</div>
|
||||
|
||||
<TestModelsBulkActions
|
||||
table={table}
|
||||
disabled={isAnyTesting}
|
||||
onTestSelected={handleBatchTest}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={handleClose}
|
||||
title={t('Test Channel Connection')}
|
||||
description={
|
||||
<>
|
||||
{t('Test connectivity for:')}
|
||||
<strong>{currentRow.name}</strong>
|
||||
</>
|
||||
}
|
||||
contentClassName='max-h-[90vh] overflow-hidden sm:max-w-3xl'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button variant='outline' onClick={handleClose}>
|
||||
{t('Close')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='max-h-[78vh] space-y-4 overflow-y-auto py-4 pr-1'>
|
||||
<div className='grid gap-4 md:grid-cols-2'>
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='endpoint-type'>{t('Endpoint Type')}</Label>
|
||||
<Select
|
||||
items={endpointSelectItems}
|
||||
value={endpointType}
|
||||
onValueChange={(v) => v !== null && setEndpointType(v)}
|
||||
>
|
||||
<SelectTrigger id='endpoint-type'>
|
||||
<SelectValue placeholder={t('Auto detect (default)')} />
|
||||
</SelectTrigger>
|
||||
<SelectContent alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
{endpointSelectItems.map((option) => (
|
||||
<SelectItem key={option.value} value={option.value}>
|
||||
{option.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'Override the endpoint used for testing. Leave empty to auto detect.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<div className='grid gap-2'>
|
||||
<Label htmlFor='stream-toggle'>{t('Stream Mode')}</Label>
|
||||
<div className='flex items-center gap-2'>
|
||||
<Switch
|
||||
id='stream-toggle'
|
||||
checked={isStreamTest}
|
||||
onCheckedChange={setIsStreamTest}
|
||||
disabled={streamDisabled}
|
||||
/>
|
||||
<span className='text-sm'>
|
||||
{isStreamTest ? t('Enabled') : t('Disabled')}
|
||||
</span>
|
||||
</div>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('Enable streaming mode for the test request.')}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='space-y-3 max-sm:has-[div[role="toolbar"]]:pb-16'>
|
||||
<div className='flex flex-col gap-2 sm:flex-row sm:items-center sm:justify-between'>
|
||||
<div>
|
||||
<p className='text-sm font-medium'>{t('Channel models')}</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('Select models to run batch tests.')}
|
||||
</p>
|
||||
</div>
|
||||
<Input
|
||||
placeholder={t('Filter models...')}
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className='sm:w-64'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className='space-y-3'>
|
||||
<DataTableView
|
||||
table={table}
|
||||
containerClassName='rounded-md'
|
||||
containerProps={{
|
||||
role: 'region',
|
||||
'aria-label': t('Channel models'),
|
||||
}}
|
||||
tableContainerClassName='max-h-90 overflow-auto **:data-[slot=table-container]:overflow-visible'
|
||||
tableClassName='w-max min-w-full table-auto'
|
||||
pinnedColumns={[
|
||||
{
|
||||
columnId: 'actions',
|
||||
side: 'right',
|
||||
className: 'w-24 min-w-24 sm:w-28 sm:min-w-28',
|
||||
cellClassName: 'bg-popover',
|
||||
},
|
||||
]}
|
||||
colgroup={
|
||||
<colgroup>
|
||||
<col className='w-10 min-w-10' />
|
||||
<col className='w-auto' />
|
||||
<col className='w-70' />
|
||||
<col className='w-24 sm:w-28' />
|
||||
</colgroup>
|
||||
}
|
||||
getColumnClassName={(columnId) =>
|
||||
getTestTableColumnClass(columnId)
|
||||
}
|
||||
emptyContent={
|
||||
models.length
|
||||
? t('No models matched your search.')
|
||||
: t('This channel has no configured models.')
|
||||
}
|
||||
emptyCellClassName='text-muted-foreground h-16 text-center text-sm'
|
||||
/>
|
||||
|
||||
<DataTablePagination table={table} />
|
||||
</div>
|
||||
|
||||
<TestModelsBulkActions
|
||||
table={table}
|
||||
disabled={isAnyTesting}
|
||||
onTestSelected={handleBatchTest}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
<FailureDetailsSheet
|
||||
details={failureDetails}
|
||||
|
||||
+75
-82
@@ -24,15 +24,8 @@ import { tryPrettyJson } from '@/lib/utils'
|
||||
import { useCopyToClipboard } from '@/hooks/use-copy-to-clipboard'
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { completeCodexOAuth, startCodexOAuth } from '../../api'
|
||||
|
||||
type CodexOAuthDialogProps = {
|
||||
@@ -129,78 +122,18 @@ export function CodexOAuthDialog({
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className='sm:max-w-2xl'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Codex Authorization')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t(
|
||||
'Generate a Codex OAuth credential and paste it into the channel key field.'
|
||||
)}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className='space-y-4'>
|
||||
<Alert>
|
||||
<AlertDescription>
|
||||
{t(
|
||||
'1) Click "Open authorization page" and complete login. 2) Your browser may redirect to localhost (it is OK if the page does not load). 3) Copy the full URL from the address bar and paste it below. 4) Click "Generate credential".'
|
||||
)}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className='flex flex-wrap gap-2'>
|
||||
<Button onClick={handleStart} disabled={state.isStarting}>
|
||||
{state.isStarting ? (
|
||||
<Loader2 className='mr-2 h-4 w-4 animate-spin' />
|
||||
) : (
|
||||
<ExternalLink className='mr-2 h-4 w-4' />
|
||||
)}
|
||||
{t('Open authorization page')}
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
disabled={!canCopyAuthorizeUrl}
|
||||
onClick={async () => {
|
||||
if (!state.authorizeUrl) return
|
||||
await copyToClipboard(state.authorizeUrl)
|
||||
}}
|
||||
aria-label={t('Copy authorization link')}
|
||||
title={t('Copy authorization link')}
|
||||
>
|
||||
{copiedText === state.authorizeUrl ? (
|
||||
<Check className='mr-2 h-4 w-4 text-green-600' />
|
||||
) : (
|
||||
<Copy className='mr-2 h-4 w-4' />
|
||||
)}
|
||||
{t('Copy authorization link')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className='space-y-2'>
|
||||
<div className='text-sm font-medium'>{t('Callback URL')}</div>
|
||||
<Input
|
||||
value={state.callbackUrl}
|
||||
onChange={(e) =>
|
||||
setState((prev) => ({ ...prev, callbackUrl: e.target.value }))
|
||||
}
|
||||
placeholder={t(
|
||||
'Paste the full callback URL (includes code & state)'
|
||||
)}
|
||||
autoComplete='off'
|
||||
spellCheck={false}
|
||||
/>
|
||||
<div className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'Tip: The generated key is a JSON credential including access_token / refresh_token / account_id.'
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
title={t('Codex Authorization')}
|
||||
description={t(
|
||||
'Generate a Codex OAuth credential and paste it into the channel key field.'
|
||||
)}
|
||||
contentClassName='sm:max-w-2xl'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
@@ -215,8 +148,68 @@ export function CodexOAuthDialog({
|
||||
)}
|
||||
{state.isCompleting ? t('Generating...') : t('Generate credential')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='space-y-4'>
|
||||
<Alert>
|
||||
<AlertDescription>
|
||||
{t(
|
||||
'1) Click "Open authorization page" and complete login. 2) Your browser may redirect to localhost (it is OK if the page does not load). 3) Copy the full URL from the address bar and paste it below. 4) Click "Generate credential".'
|
||||
)}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className='flex flex-wrap gap-2'>
|
||||
<Button onClick={handleStart} disabled={state.isStarting}>
|
||||
{state.isStarting ? (
|
||||
<Loader2 className='mr-2 h-4 w-4 animate-spin' />
|
||||
) : (
|
||||
<ExternalLink className='mr-2 h-4 w-4' />
|
||||
)}
|
||||
{t('Open authorization page')}
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
disabled={!canCopyAuthorizeUrl}
|
||||
onClick={async () => {
|
||||
if (!state.authorizeUrl) return
|
||||
await copyToClipboard(state.authorizeUrl)
|
||||
}}
|
||||
aria-label={t('Copy authorization link')}
|
||||
title={t('Copy authorization link')}
|
||||
>
|
||||
{copiedText === state.authorizeUrl ? (
|
||||
<Check className='mr-2 h-4 w-4 text-green-600' />
|
||||
) : (
|
||||
<Copy className='mr-2 h-4 w-4' />
|
||||
)}
|
||||
{t('Copy authorization link')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className='space-y-2'>
|
||||
<div className='text-sm font-medium'>{t('Callback URL')}</div>
|
||||
<Input
|
||||
value={state.callbackUrl}
|
||||
onChange={(e) =>
|
||||
setState((prev) => ({ ...prev, callbackUrl: e.target.value }))
|
||||
}
|
||||
placeholder={t(
|
||||
'Paste the full callback URL (includes code & state)'
|
||||
)}
|
||||
autoComplete='off'
|
||||
spellCheck={false}
|
||||
/>
|
||||
<div className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'Tip: The generated key is a JSON credential including access_token / refresh_token / account_id.'
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
+178
-181
@@ -31,16 +31,9 @@ import { useTranslation } from 'react-i18next'
|
||||
import dayjs from '@/lib/dayjs'
|
||||
import { useCopyToClipboard } from '@/hooks/use-copy-to-clipboard'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { StatusBadge, type StatusBadgeProps } from '@/components/status-badge'
|
||||
|
||||
type CodexRateLimitWindow = {
|
||||
@@ -414,177 +407,23 @@ export function CodexUsageDialog({
|
||||
}, [response])
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className='sm:max-w-3xl'>
|
||||
<DialogHeader>
|
||||
<DialogTitle className='flex items-center gap-2'>
|
||||
{t('Codex Account & Usage')}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Channel:')} <strong>{channelName || '-'}</strong>{' '}
|
||||
{channelId ? `(#${channelId})` : ''}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className='space-y-4'>
|
||||
{errorMessage && (
|
||||
<div className='rounded-lg border border-red-200 bg-red-50 px-4 py-3 text-sm text-red-700 dark:border-red-800 dark:bg-red-950/30 dark:text-red-400'>
|
||||
{errorMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Account summary */}
|
||||
<div className='rounded-lg border p-4'>
|
||||
<div className='flex flex-wrap items-center justify-between gap-2'>
|
||||
<div className='flex flex-wrap items-center gap-2'>
|
||||
<StatusBadge
|
||||
label={accountBadge.label}
|
||||
variant={accountBadge.variant}
|
||||
copyable={false}
|
||||
/>
|
||||
{statusBadge}
|
||||
{typeof response?.upstream_status === 'number' && (
|
||||
<StatusBadge
|
||||
label={`${t('Status:')} ${response.upstream_status}`}
|
||||
variant='neutral'
|
||||
copyable={false}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{onRefresh && (
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={onRefresh}
|
||||
disabled={Boolean(isRefreshing)}
|
||||
>
|
||||
<RefreshCw className='mr-1.5 h-3.5 w-3.5' />
|
||||
{t('Refresh')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Account identity info */}
|
||||
<div className='bg-muted/30 mt-3 rounded-md px-3 py-2'>
|
||||
<CopyableField
|
||||
icon={<User className='h-3.5 w-3.5' />}
|
||||
label='User ID'
|
||||
value={payload?.user_id}
|
||||
mono
|
||||
/>
|
||||
<CopyableField
|
||||
icon={<Mail className='h-3.5 w-3.5' />}
|
||||
label={t('Email')}
|
||||
value={payload?.email}
|
||||
/>
|
||||
<CopyableField
|
||||
icon={<Hash className='h-3.5 w-3.5' />}
|
||||
label='Account ID'
|
||||
value={payload?.account_id}
|
||||
mono
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Rate limit windows */}
|
||||
<div className='space-y-5'>
|
||||
<div>
|
||||
<div className='mb-1 text-sm font-medium'>
|
||||
{t('Rate Limit Windows')}
|
||||
</div>
|
||||
<p className='text-muted-foreground mb-3 text-xs'>
|
||||
{t(
|
||||
'Tracks current account base limits and additional metered usage on Codex upstream.'
|
||||
)}
|
||||
</p>
|
||||
<RateLimitGroupSection
|
||||
title={t('Base Limits')}
|
||||
description={t('Base rate limit windows for this account.')}
|
||||
source={payload}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{additionalRateLimits.length > 0 && (
|
||||
<div className='space-y-4 border-t pt-4'>
|
||||
<div>
|
||||
<div className='text-sm font-medium'>
|
||||
{t('Additional Limits')}
|
||||
</div>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'Per-feature metered windows split by model or capability.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<div className='space-y-4'>
|
||||
{additionalRateLimits.map((item, index) => {
|
||||
const limitName =
|
||||
item.limit_name ||
|
||||
item.metered_feature ||
|
||||
`${t('Additional Limit')} ${index + 1}`
|
||||
return (
|
||||
<div
|
||||
key={`${limitName}-${item.metered_feature ?? ''}-${index}`}
|
||||
className={index > 0 ? 'border-t pt-4' : ''}
|
||||
>
|
||||
<RateLimitGroupSection
|
||||
title={limitName}
|
||||
description={t('Additional metered capability')}
|
||||
source={item}
|
||||
meteredFeature={item.metered_feature}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Raw JSON collapsible */}
|
||||
<div className='rounded-lg border'>
|
||||
<button
|
||||
type='button'
|
||||
className='hover:bg-muted/40 flex w-full items-center justify-between gap-2 p-3 transition-colors'
|
||||
onClick={() => setShowRawJson((v) => !v)}
|
||||
>
|
||||
<div className='text-sm font-medium'>{t('Raw JSON')}</div>
|
||||
{showRawJson ? (
|
||||
<ChevronUp className='text-muted-foreground h-4 w-4' />
|
||||
) : (
|
||||
<ChevronDown className='text-muted-foreground h-4 w-4' />
|
||||
)}
|
||||
</button>
|
||||
{showRawJson && (
|
||||
<>
|
||||
<div className='flex justify-end border-t px-3 py-2'>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => copyToClipboard(rawJsonText)}
|
||||
disabled={!rawJsonText}
|
||||
>
|
||||
{copiedText === rawJsonText ? (
|
||||
<Check className='mr-1.5 h-3.5 w-3.5 text-green-600' />
|
||||
) : (
|
||||
<Copy className='mr-1.5 h-3.5 w-3.5' />
|
||||
)}
|
||||
{t('Copy')}
|
||||
</Button>
|
||||
</div>
|
||||
<ScrollArea className='max-h-[50vh]'>
|
||||
<pre className='bg-muted/30 m-0 p-3 text-xs break-words whitespace-pre-wrap'>
|
||||
{rawJsonText || '-'}
|
||||
</pre>
|
||||
</ScrollArea>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
title={t('Codex Account & Usage')}
|
||||
description={
|
||||
<>
|
||||
{t('Channel:')}
|
||||
<strong>{channelName || '-'}</strong>{' '}
|
||||
{channelId ? `(#${channelId})` : ''}
|
||||
</>
|
||||
}
|
||||
contentClassName='sm:max-w-3xl'
|
||||
titleClassName='flex items-center gap-2'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
@@ -592,8 +431,166 @@ export function CodexUsageDialog({
|
||||
>
|
||||
{t('Close')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='space-y-4'>
|
||||
{errorMessage && (
|
||||
<div className='rounded-lg border border-red-200 bg-red-50 px-4 py-3 text-sm text-red-700 dark:border-red-800 dark:bg-red-950/30 dark:text-red-400'>
|
||||
{errorMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Account summary */}
|
||||
<div className='rounded-lg border p-4'>
|
||||
<div className='flex flex-wrap items-center justify-between gap-2'>
|
||||
<div className='flex flex-wrap items-center gap-2'>
|
||||
<StatusBadge
|
||||
label={accountBadge.label}
|
||||
variant={accountBadge.variant}
|
||||
copyable={false}
|
||||
/>
|
||||
{statusBadge}
|
||||
{typeof response?.upstream_status === 'number' && (
|
||||
<StatusBadge
|
||||
label={`${t('Status:')} ${response.upstream_status}`}
|
||||
variant='neutral'
|
||||
copyable={false}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{onRefresh && (
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={onRefresh}
|
||||
disabled={Boolean(isRefreshing)}
|
||||
>
|
||||
<RefreshCw className='mr-1.5 h-3.5 w-3.5' />
|
||||
{t('Refresh')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Account identity info */}
|
||||
<div className='bg-muted/30 mt-3 rounded-md px-3 py-2'>
|
||||
<CopyableField
|
||||
icon={<User className='h-3.5 w-3.5' />}
|
||||
label='User ID'
|
||||
value={payload?.user_id}
|
||||
mono
|
||||
/>
|
||||
<CopyableField
|
||||
icon={<Mail className='h-3.5 w-3.5' />}
|
||||
label={t('Email')}
|
||||
value={payload?.email}
|
||||
/>
|
||||
<CopyableField
|
||||
icon={<Hash className='h-3.5 w-3.5' />}
|
||||
label='Account ID'
|
||||
value={payload?.account_id}
|
||||
mono
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Rate limit windows */}
|
||||
<div className='space-y-5'>
|
||||
<div>
|
||||
<div className='mb-1 text-sm font-medium'>
|
||||
{t('Rate Limit Windows')}
|
||||
</div>
|
||||
<p className='text-muted-foreground mb-3 text-xs'>
|
||||
{t(
|
||||
'Tracks current account base limits and additional metered usage on Codex upstream.'
|
||||
)}
|
||||
</p>
|
||||
<RateLimitGroupSection
|
||||
title={t('Base Limits')}
|
||||
description={t('Base rate limit windows for this account.')}
|
||||
source={payload}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{additionalRateLimits.length > 0 && (
|
||||
<div className='space-y-4 border-t pt-4'>
|
||||
<div>
|
||||
<div className='text-sm font-medium'>
|
||||
{t('Additional Limits')}
|
||||
</div>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'Per-feature metered windows split by model or capability.'
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<div className='space-y-4'>
|
||||
{additionalRateLimits.map((item, index) => {
|
||||
const limitName =
|
||||
item.limit_name ||
|
||||
item.metered_feature ||
|
||||
`${t('Additional Limit')} ${index + 1}`
|
||||
return (
|
||||
<div
|
||||
key={`${limitName}-${item.metered_feature ?? ''}-${index}`}
|
||||
className={index > 0 ? 'border-t pt-4' : ''}
|
||||
>
|
||||
<RateLimitGroupSection
|
||||
title={limitName}
|
||||
description={t('Additional metered capability')}
|
||||
source={item}
|
||||
meteredFeature={item.metered_feature}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Raw JSON collapsible */}
|
||||
<div className='rounded-lg border'>
|
||||
<button
|
||||
type='button'
|
||||
className='hover:bg-muted/40 flex w-full items-center justify-between gap-2 p-3 transition-colors'
|
||||
onClick={() => setShowRawJson((v) => !v)}
|
||||
>
|
||||
<div className='text-sm font-medium'>{t('Raw JSON')}</div>
|
||||
{showRawJson ? (
|
||||
<ChevronUp className='text-muted-foreground h-4 w-4' />
|
||||
) : (
|
||||
<ChevronDown className='text-muted-foreground h-4 w-4' />
|
||||
)}
|
||||
</button>
|
||||
{showRawJson && (
|
||||
<>
|
||||
<div className='flex justify-end border-t px-3 py-2'>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => copyToClipboard(rawJsonText)}
|
||||
disabled={!rawJsonText}
|
||||
>
|
||||
{copiedText === rawJsonText ? (
|
||||
<Check className='mr-1.5 h-3.5 w-3.5 text-green-600' />
|
||||
) : (
|
||||
<Copy className='mr-1.5 h-3.5 w-3.5' />
|
||||
)}
|
||||
{t('Copy')}
|
||||
</Button>
|
||||
</div>
|
||||
<ScrollArea className='max-h-[50vh]'>
|
||||
<pre className='bg-muted/30 m-0 p-3 text-xs break-words whitespace-pre-wrap'>
|
||||
{rawJsonText || '-'}
|
||||
</pre>
|
||||
</ScrollArea>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
+47
-50
@@ -22,16 +22,9 @@ import { Loader2 } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { handleCopyChannel } from '../../lib'
|
||||
import { useChannels } from '../channels-provider'
|
||||
|
||||
@@ -74,45 +67,20 @@ export function CopyChannelDialog({
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Copy Channel')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Create a copy of:')} <strong>{currentRow.name}</strong>
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className='space-y-4 py-4'>
|
||||
<div className='space-y-2'>
|
||||
<Label htmlFor='suffix'>{t('Name Suffix')}</Label>
|
||||
<Input
|
||||
id='suffix'
|
||||
placeholder={t('_copy')}
|
||||
value={suffix}
|
||||
onChange={(e) => setSuffix(e.target.value)}
|
||||
disabled={isCopying}
|
||||
/>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('New name will be:')} {currentRow.name}
|
||||
{suffix}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className='flex items-center space-x-2'>
|
||||
<Checkbox
|
||||
id='reset-balance'
|
||||
checked={resetBalance}
|
||||
onCheckedChange={(checked) => setResetBalance(!!checked)}
|
||||
disabled={isCopying}
|
||||
/>
|
||||
<Label htmlFor='reset-balance' className='text-sm font-normal'>
|
||||
{t('Reset balance and used quota')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
title={t('Copy Channel')}
|
||||
description={
|
||||
<>
|
||||
{t('Create a copy of:')}
|
||||
<strong>{currentRow.name}</strong>
|
||||
</>
|
||||
}
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button
|
||||
variant='outline'
|
||||
onClick={() => onOpenChange(false)}
|
||||
@@ -122,10 +90,39 @@ export function CopyChannelDialog({
|
||||
</Button>
|
||||
<Button onClick={handleCopy} disabled={isCopying}>
|
||||
{isCopying && <Loader2 className='mr-2 h-4 w-4 animate-spin' />}
|
||||
{isCopying ? 'Copying...' : 'Copy Channel'}
|
||||
{isCopying ? t('Copying...') : t('Copy Channel')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<div className='space-y-4 py-4'>
|
||||
<div className='space-y-2'>
|
||||
<Label htmlFor='suffix'>{t('Name Suffix')}</Label>
|
||||
<Input
|
||||
id='suffix'
|
||||
placeholder={t('_copy')}
|
||||
value={suffix}
|
||||
onChange={(e) => setSuffix(e.target.value)}
|
||||
disabled={isCopying}
|
||||
/>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('New name will be:')} {currentRow.name}
|
||||
{suffix}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className='flex items-center space-x-2'>
|
||||
<Checkbox
|
||||
id='reset-balance'
|
||||
checked={resetBalance}
|
||||
onCheckedChange={(checked) => setResetBalance(!!checked)}
|
||||
disabled={isCopying}
|
||||
/>
|
||||
<Label htmlFor='reset-balance' className='text-sm font-normal'>
|
||||
{t('Reset balance and used quota')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
+216
-220
@@ -22,14 +22,6 @@ import { Loader2 } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { toast } from 'sonner'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
@@ -43,6 +35,7 @@ import {
|
||||
} from '@/components/ui/select'
|
||||
import { Separator } from '@/components/ui/separator'
|
||||
import { Textarea } from '@/components/ui/textarea'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { GroupBadge } from '@/components/group-badge'
|
||||
import { StatusBadge } from '@/components/status-badge'
|
||||
import {
|
||||
@@ -222,216 +215,23 @@ export function EditTagDialog({ open, onOpenChange }: EditTagDialogProps) {
|
||||
if (!currentTag) return null
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleClose}>
|
||||
<DialogContent className='max-h-[90vh] max-w-2xl'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
{t('Edit Tag:')} {currentTag}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t(
|
||||
'Batch edit all channels with this tag. Leave fields empty to keep current values.'
|
||||
)}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<ScrollArea className='max-h-[60vh] pr-4'>
|
||||
<div className='space-y-6'>
|
||||
{/* Tag Name */}
|
||||
<div className='space-y-2'>
|
||||
<Label htmlFor='new-tag'>
|
||||
{t('Tag Name')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t('(Leave empty to dissolve tag)')}
|
||||
</span>
|
||||
</Label>
|
||||
<Input
|
||||
id='new-tag'
|
||||
value={newTag}
|
||||
onChange={(e) => setNewTag(e.target.value)}
|
||||
placeholder={t('Enter new tag name or leave empty')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Models */}
|
||||
<div className='space-y-2'>
|
||||
<Label>
|
||||
{t('Models')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t("(Override all channels' models)")}
|
||||
</span>
|
||||
</Label>
|
||||
|
||||
{isLoadingTagModels ? (
|
||||
<div className='flex items-center gap-2 py-4'>
|
||||
<Loader2 className='h-4 w-4 animate-spin' />
|
||||
<span className='text-muted-foreground text-sm'>
|
||||
{t('Loading current models...')}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className='flex min-h-[60px] flex-wrap gap-2 rounded-md border p-3'>
|
||||
{selectedModels.length > 0 ? (
|
||||
selectedModels.map((model) => (
|
||||
<StatusBadge
|
||||
key={model}
|
||||
variant='neutral'
|
||||
className='cursor-pointer transition-opacity hover:opacity-70'
|
||||
copyable={false}
|
||||
onClick={() => handleRemoveModel(model)}
|
||||
>
|
||||
{model} ×
|
||||
</StatusBadge>
|
||||
))
|
||||
) : (
|
||||
<span className='text-muted-foreground text-sm'>
|
||||
{t('No models selected')}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className='flex gap-2'>
|
||||
<Select<string>
|
||||
items={[
|
||||
...availableModels.map((model) => ({
|
||||
value: model,
|
||||
label: model,
|
||||
})),
|
||||
]}
|
||||
onValueChange={(value) => {
|
||||
if (value === null) return
|
||||
if (!selectedModels.includes(value)) {
|
||||
setSelectedModels([...selectedModels, value])
|
||||
}
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className='flex-1'>
|
||||
<SelectValue
|
||||
placeholder={t('Add from available models...')}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
<ScrollArea className='h-60'>
|
||||
{availableModels.map((model) => (
|
||||
<SelectItem key={model} value={model}>
|
||||
{model}
|
||||
</SelectItem>
|
||||
))}
|
||||
</ScrollArea>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className='flex gap-2'>
|
||||
<Input
|
||||
placeholder={t('Custom model (comma-separated)')}
|
||||
value={customModel}
|
||||
onChange={(e) => setCustomModel(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault()
|
||||
handleAddCustomModel()
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
type='button'
|
||||
variant='secondary'
|
||||
onClick={handleAddCustomModel}
|
||||
>
|
||||
{t('Add')}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Model Mapping */}
|
||||
<div className='space-y-2'>
|
||||
<Label htmlFor='model-mapping'>
|
||||
{t('Model Mapping (JSON)')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t('(Optional: redirect model names)')}
|
||||
</span>
|
||||
</Label>
|
||||
<Textarea
|
||||
id='model-mapping'
|
||||
value={modelMapping}
|
||||
onChange={(e) => setModelMapping(e.target.value)}
|
||||
placeholder={'{\n "gpt-3.5-turbo": "gpt-3.5-turbo-0125"\n}'}
|
||||
rows={4}
|
||||
className='font-mono text-sm'
|
||||
/>
|
||||
<div className='flex gap-2'>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() =>
|
||||
setModelMapping(
|
||||
JSON.stringify(
|
||||
{ 'gpt-3.5-turbo': 'gpt-3.5-turbo-0125' },
|
||||
null,
|
||||
2
|
||||
)
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('Example')}
|
||||
</Button>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => setModelMapping(JSON.stringify({}, null, 2))}
|
||||
>
|
||||
{t('Clear Mapping')}
|
||||
</Button>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => setModelMapping('')}
|
||||
>
|
||||
{t('No Change')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Groups */}
|
||||
<div className='space-y-2'>
|
||||
<Label>
|
||||
{t('Groups')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t("(Override all channels' groups)")}
|
||||
</span>
|
||||
</Label>
|
||||
<div className='flex min-h-[60px] flex-wrap gap-2 rounded-md border p-3'>
|
||||
{availableGroups.map((group) => (
|
||||
<GroupBadge
|
||||
key={group}
|
||||
group={group}
|
||||
className={`cursor-pointer rounded-sm transition-opacity hover:opacity-70 ${
|
||||
selectedGroups.includes(group) ? 'bg-muted/70 px-1' : ''
|
||||
}`}
|
||||
onClick={() => handleToggleGroup(group)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
|
||||
<DialogFooter>
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={handleClose}
|
||||
title={
|
||||
<>
|
||||
{t('Edit Tag:')}
|
||||
{currentTag}
|
||||
</>
|
||||
}
|
||||
description={t(
|
||||
'Batch edit all channels with this tag. Leave fields empty to keep current values.'
|
||||
)}
|
||||
contentClassName='max-h-[90vh] max-w-2xl'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<>
|
||||
<Button variant='outline' onClick={handleClose}>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
@@ -439,8 +239,204 @@ export function EditTagDialog({ open, onOpenChange }: EditTagDialogProps) {
|
||||
{isSubmitting && <Loader2 className='mr-2 h-4 w-4 animate-spin' />}
|
||||
{t('Save Changes')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<ScrollArea className='max-h-[60vh] pr-4'>
|
||||
<div className='space-y-6'>
|
||||
{/* Tag Name */}
|
||||
<div className='space-y-2'>
|
||||
<Label htmlFor='new-tag'>
|
||||
{t('Tag Name')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t('(Leave empty to dissolve tag)')}
|
||||
</span>
|
||||
</Label>
|
||||
<Input
|
||||
id='new-tag'
|
||||
value={newTag}
|
||||
onChange={(e) => setNewTag(e.target.value)}
|
||||
placeholder={t('Enter new tag name or leave empty')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Models */}
|
||||
<div className='space-y-2'>
|
||||
<Label>
|
||||
{t('Models')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t("(Override all channels' models)")}
|
||||
</span>
|
||||
</Label>
|
||||
|
||||
{isLoadingTagModels ? (
|
||||
<div className='flex items-center gap-2 py-4'>
|
||||
<Loader2 className='h-4 w-4 animate-spin' />
|
||||
<span className='text-muted-foreground text-sm'>
|
||||
{t('Loading current models...')}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className='flex min-h-[60px] flex-wrap gap-2 rounded-md border p-3'>
|
||||
{selectedModels.length > 0 ? (
|
||||
selectedModels.map((model) => (
|
||||
<StatusBadge
|
||||
key={model}
|
||||
variant='neutral'
|
||||
className='cursor-pointer transition-opacity hover:opacity-70'
|
||||
copyable={false}
|
||||
onClick={() => handleRemoveModel(model)}
|
||||
>
|
||||
{model} ×
|
||||
</StatusBadge>
|
||||
))
|
||||
) : (
|
||||
<span className='text-muted-foreground text-sm'>
|
||||
{t('No models selected')}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className='flex gap-2'>
|
||||
<Select<string>
|
||||
items={[
|
||||
...availableModels.map((model) => ({
|
||||
value: model,
|
||||
label: model,
|
||||
})),
|
||||
]}
|
||||
onValueChange={(value) => {
|
||||
if (value === null) return
|
||||
if (!selectedModels.includes(value)) {
|
||||
setSelectedModels([...selectedModels, value])
|
||||
}
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className='flex-1'>
|
||||
<SelectValue
|
||||
placeholder={t('Add from available models...')}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
<ScrollArea className='h-60'>
|
||||
{availableModels.map((model) => (
|
||||
<SelectItem key={model} value={model}>
|
||||
{model}
|
||||
</SelectItem>
|
||||
))}
|
||||
</ScrollArea>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className='flex gap-2'>
|
||||
<Input
|
||||
placeholder={t('Custom model (comma-separated)')}
|
||||
value={customModel}
|
||||
onChange={(e) => setCustomModel(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault()
|
||||
handleAddCustomModel()
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
type='button'
|
||||
variant='secondary'
|
||||
onClick={handleAddCustomModel}
|
||||
>
|
||||
{t('Add')}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Model Mapping */}
|
||||
<div className='space-y-2'>
|
||||
<Label htmlFor='model-mapping'>
|
||||
{t('Model Mapping (JSON)')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t('(Optional: redirect model names)')}
|
||||
</span>
|
||||
</Label>
|
||||
<Textarea
|
||||
id='model-mapping'
|
||||
value={modelMapping}
|
||||
onChange={(e) => setModelMapping(e.target.value)}
|
||||
placeholder={'{\n "gpt-3.5-turbo": "gpt-3.5-turbo-0125"\n}'}
|
||||
rows={4}
|
||||
className='font-mono text-sm'
|
||||
/>
|
||||
<div className='flex gap-2'>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() =>
|
||||
setModelMapping(
|
||||
JSON.stringify(
|
||||
{ 'gpt-3.5-turbo': 'gpt-3.5-turbo-0125' },
|
||||
null,
|
||||
2
|
||||
)
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('Example')}
|
||||
</Button>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => setModelMapping(JSON.stringify({}, null, 2))}
|
||||
>
|
||||
{t('Clear Mapping')}
|
||||
</Button>
|
||||
<Button
|
||||
type='button'
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => setModelMapping('')}
|
||||
>
|
||||
{t('No Change')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Groups */}
|
||||
<div className='space-y-2'>
|
||||
<Label>
|
||||
{t('Groups')}
|
||||
<span className='text-muted-foreground ml-2 text-xs'>
|
||||
{t("(Override all channels' groups)")}
|
||||
</span>
|
||||
</Label>
|
||||
<div className='flex min-h-[60px] flex-wrap gap-2 rounded-md border p-3'>
|
||||
{availableGroups.map((group) => (
|
||||
<GroupBadge
|
||||
key={group}
|
||||
group={group}
|
||||
className={`cursor-pointer rounded-sm transition-opacity hover:opacity-70 ${
|
||||
selectedGroups.includes(group) ? 'bg-muted/70 px-1' : ''
|
||||
}`}
|
||||
onClick={() => handleToggleGroup(group)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
+143
-149
@@ -28,14 +28,6 @@ import {
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from '@/components/ui/collapsible'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||
@@ -44,6 +36,7 @@ import {
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { fetchUpstreamModels, updateChannel } from '../../api'
|
||||
import {
|
||||
channelsQueryKeys,
|
||||
@@ -365,152 +358,153 @@ export function FetchModelsDialog({
|
||||
)
|
||||
}
|
||||
|
||||
const showFooterActions =
|
||||
!!(activeChannel || customFetcher) &&
|
||||
!isFetching &&
|
||||
(fetchedModels.length > 0 || removedModels.length > 0)
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleClose}>
|
||||
<DialogContent className='max-w-3xl'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Fetch Models')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{activeChannel ? (
|
||||
<>
|
||||
{t('Fetch available models for:')}{' '}
|
||||
<strong>{activeChannel.name}</strong>
|
||||
</>
|
||||
) : channelName ? (
|
||||
<>
|
||||
{t('Fetch available models for:')}{' '}
|
||||
<strong>{channelName}</strong>
|
||||
</>
|
||||
) : (
|
||||
t('Fetch available models from upstream')
|
||||
)}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{!activeChannel && !customFetcher ? (
|
||||
<div className='text-muted-foreground py-8 text-center'>
|
||||
{t('No channel selected')}
|
||||
</div>
|
||||
) : isFetching ? (
|
||||
<div className='flex items-center justify-center py-12'>
|
||||
<Loader2 className='text-muted-foreground h-8 w-8 animate-spin' />
|
||||
</div>
|
||||
) : fetchedModels.length === 0 && removedModels.length === 0 ? (
|
||||
<div className='text-muted-foreground py-8 text-center'>
|
||||
<p>{t('No models fetched yet.')}</p>
|
||||
<Button
|
||||
className='mt-4'
|
||||
onClick={handleFetchModels}
|
||||
disabled={isFetching}
|
||||
>
|
||||
{t('Fetch Models')}
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={handleClose}
|
||||
title={t('Fetch Models')}
|
||||
description={
|
||||
activeChannel ? (
|
||||
<>
|
||||
<div className='space-y-4'>
|
||||
{/* Search Bar */}
|
||||
<div className='relative'>
|
||||
<Search className='text-muted-foreground absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2' />
|
||||
<Input
|
||||
placeholder={t('Search models...')}
|
||||
value={searchKeyword}
|
||||
onChange={(e) => setSearchKeyword(e.target.value)}
|
||||
className='pl-9'
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Tabs for New vs Existing vs Removed */}
|
||||
<Tabs
|
||||
key={`${activeChannel?.id ?? 'custom'}-${fetchedModels.length}-${removedModels.length}`}
|
||||
defaultValue={
|
||||
newModels.length > 0
|
||||
? 'new'
|
||||
: removedModels.length > 0
|
||||
? 'removed'
|
||||
: 'existing'
|
||||
}
|
||||
>
|
||||
<TabsList
|
||||
className={`grid w-full ${removedModels.length > 0 ? 'grid-cols-3' : 'grid-cols-2'}`}
|
||||
>
|
||||
<TabsTrigger value='new' disabled={newModels.length === 0}>
|
||||
{t('New Models ({{count}})', { count: newModels.length })}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value='existing'
|
||||
disabled={existingFilteredModels.length === 0}
|
||||
>
|
||||
{t('Existing Models ({{count}})', {
|
||||
count: existingFilteredModels.length,
|
||||
})}
|
||||
</TabsTrigger>
|
||||
{removedModels.length > 0 && (
|
||||
<TabsTrigger value='removed'>
|
||||
{t('Removed Models ({{count}})', {
|
||||
count: removedModels.length,
|
||||
})}
|
||||
</TabsTrigger>
|
||||
)}
|
||||
</TabsList>
|
||||
|
||||
<TabsContent
|
||||
value='new'
|
||||
className='max-h-96 space-y-2 overflow-y-auto'
|
||||
>
|
||||
{getSortedCategoryEntries(newModelsByCategory).map(
|
||||
([category, models]) =>
|
||||
renderModelCategory(category, models)
|
||||
)}
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent
|
||||
value='existing'
|
||||
className='max-h-96 space-y-2 overflow-y-auto'
|
||||
>
|
||||
{getSortedCategoryEntries(existingModelsByCategory).map(
|
||||
([category, models]) =>
|
||||
renderModelCategory(category, models)
|
||||
)}
|
||||
</TabsContent>
|
||||
|
||||
{removedModels.length > 0 && (
|
||||
<TabsContent
|
||||
value='removed'
|
||||
className='max-h-96 space-y-2 overflow-y-auto'
|
||||
>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'These models are still in your selection but were not returned by the upstream listing. Entries that are only model_mapping source aliases are omitted. Toggle to adjust before saving.'
|
||||
)}
|
||||
</p>
|
||||
{renderModelCategory(t('Removed'), removedModels)}
|
||||
</TabsContent>
|
||||
)}
|
||||
</Tabs>
|
||||
|
||||
{/* Selection Summary */}
|
||||
<div className='bg-muted/50 rounded-lg border p-3 text-sm'>
|
||||
{t('{{n}} model(s) selected', { n: selectedModels.length })}
|
||||
</div>
|
||||
{t('Channel:')} <strong>{activeChannel.name}</strong>
|
||||
</>
|
||||
) : channelName ? (
|
||||
<>
|
||||
{t('Channel:')} <strong>{channelName}</strong>
|
||||
</>
|
||||
) : (
|
||||
t('Fetch available models from upstream')
|
||||
)
|
||||
}
|
||||
contentClassName='max-w-3xl'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
showFooterActions ? (
|
||||
<>
|
||||
<Button variant='outline' onClick={handleClose} disabled={isSaving}>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button onClick={handleSave} disabled={isSaving}>
|
||||
{isSaving && <Loader2 className='mr-2 h-4 w-4 animate-spin' />}
|
||||
{isSaving ? t('Saving...') : t('Save Models')}
|
||||
</Button>
|
||||
</>
|
||||
) : null
|
||||
}
|
||||
>
|
||||
{!activeChannel && !customFetcher ? (
|
||||
<div className='text-muted-foreground py-8 text-center'>
|
||||
{t('No channel selected')}
|
||||
</div>
|
||||
) : isFetching ? (
|
||||
<div className='flex items-center justify-center py-12'>
|
||||
<Loader2 className='text-muted-foreground h-8 w-8 animate-spin' />
|
||||
</div>
|
||||
) : fetchedModels.length === 0 && removedModels.length === 0 ? (
|
||||
<div className='text-muted-foreground py-8 text-center'>
|
||||
<p>{t('No models fetched yet.')}</p>
|
||||
<Button
|
||||
className='mt-4'
|
||||
onClick={handleFetchModels}
|
||||
disabled={isFetching}
|
||||
>
|
||||
{t('Fetch Models')}
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className='space-y-4'>
|
||||
{/* Search Bar */}
|
||||
<div className='relative'>
|
||||
<Search className='text-muted-foreground absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2' />
|
||||
<Input
|
||||
placeholder={t('Search models...')}
|
||||
value={searchKeyword}
|
||||
onChange={(e) => setSearchKeyword(e.target.value)}
|
||||
className='pl-9'
|
||||
/>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant='outline'
|
||||
onClick={handleClose}
|
||||
disabled={isSaving}
|
||||
{/* Tabs for New vs Existing vs Removed */}
|
||||
<Tabs
|
||||
key={`${activeChannel?.id ?? 'custom'}-${fetchedModels.length}-${removedModels.length}`}
|
||||
defaultValue={
|
||||
newModels.length > 0
|
||||
? 'new'
|
||||
: removedModels.length > 0
|
||||
? 'removed'
|
||||
: 'existing'
|
||||
}
|
||||
>
|
||||
<TabsList
|
||||
className={`grid w-full ${removedModels.length > 0 ? 'grid-cols-3' : 'grid-cols-2'}`}
|
||||
>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button onClick={handleSave} disabled={isSaving}>
|
||||
{isSaving && <Loader2 className='mr-2 h-4 w-4 animate-spin' />}
|
||||
{isSaving ? t('Saving...') : t('Save Models')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</>
|
||||
)}
|
||||
</DialogContent>
|
||||
<TabsTrigger value='new' disabled={newModels.length === 0}>
|
||||
{t('New Models ({{count}})', { count: newModels.length })}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value='existing'
|
||||
disabled={existingFilteredModels.length === 0}
|
||||
>
|
||||
{t('Existing Models ({{count}})', {
|
||||
count: existingFilteredModels.length,
|
||||
})}
|
||||
</TabsTrigger>
|
||||
{removedModels.length > 0 && (
|
||||
<TabsTrigger value='removed'>
|
||||
{t('Removed Models ({{count}})', {
|
||||
count: removedModels.length,
|
||||
})}
|
||||
</TabsTrigger>
|
||||
)}
|
||||
</TabsList>
|
||||
|
||||
<TabsContent
|
||||
value='new'
|
||||
className='max-h-96 space-y-2 overflow-y-auto'
|
||||
>
|
||||
{getSortedCategoryEntries(newModelsByCategory).map(
|
||||
([category, models]) => renderModelCategory(category, models)
|
||||
)}
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent
|
||||
value='existing'
|
||||
className='max-h-96 space-y-2 overflow-y-auto'
|
||||
>
|
||||
{getSortedCategoryEntries(existingModelsByCategory).map(
|
||||
([category, models]) => renderModelCategory(category, models)
|
||||
)}
|
||||
</TabsContent>
|
||||
|
||||
{removedModels.length > 0 && (
|
||||
<TabsContent
|
||||
value='removed'
|
||||
className='max-h-96 space-y-2 overflow-y-auto'
|
||||
>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t(
|
||||
'These models are still in your selection but were not returned by the upstream listing. Entries that are only model_mapping source aliases are omitted. Toggle to adjust before saving.'
|
||||
)}
|
||||
</p>
|
||||
{renderModelCategory(t('Removed'), removedModels)}
|
||||
</TabsContent>
|
||||
)}
|
||||
</Tabs>
|
||||
|
||||
{/* Selection Summary */}
|
||||
<div className='bg-muted/50 rounded-lg border p-3 text-sm'>
|
||||
{t('{{n}} model(s) selected', { n: selectedModels.length })}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
+202
-208
@@ -22,13 +22,6 @@ import { Loader2, RefreshCw, Trash2, Power, PowerOff } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { toast } from 'sonner'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
@@ -38,15 +31,9 @@ import {
|
||||
SelectValue,
|
||||
} from '@/components/ui/select'
|
||||
import { Separator } from '@/components/ui/separator'
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from '@/components/ui/table'
|
||||
import { ConfirmDialog } from '@/components/confirm-dialog'
|
||||
import { StaticDataTable } from '@/components/data-table'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import { StatusBadge } from '@/components/status-badge'
|
||||
import {
|
||||
getMultiKeyStatus,
|
||||
@@ -228,215 +215,222 @@ export function MultiKeyManageDialog({
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className='flex max-h-[90vh] max-w-5xl flex-col'>
|
||||
<DialogHeader>
|
||||
<DialogTitle className='flex items-center gap-2'>
|
||||
{t('Multi-Key Management')}
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
title={
|
||||
<>
|
||||
{t('Multi-Key Management')}
|
||||
<StatusBadge
|
||||
label={currentRow.name}
|
||||
variant='neutral'
|
||||
copyable={false}
|
||||
/>
|
||||
{currentRow.channel_info?.multi_key_mode && (
|
||||
<StatusBadge
|
||||
label={currentRow.name}
|
||||
label={
|
||||
currentRow.channel_info.multi_key_mode === 'random'
|
||||
? t('Random')
|
||||
: t('Polling')
|
||||
}
|
||||
variant='neutral'
|
||||
copyable={false}
|
||||
/>
|
||||
{currentRow.channel_info?.multi_key_mode && (
|
||||
<StatusBadge
|
||||
label={
|
||||
currentRow.channel_info.multi_key_mode === 'random'
|
||||
? t('Random')
|
||||
: t('Polling')
|
||||
}
|
||||
variant='neutral'
|
||||
copyable={false}
|
||||
/>
|
||||
)}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Manage multi-key status and configuration for this channel')}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
)}
|
||||
</>
|
||||
}
|
||||
description={t(
|
||||
'Manage multi-key status and configuration for this channel'
|
||||
)}
|
||||
contentClassName='flex max-h-[90vh] max-w-5xl flex-col'
|
||||
titleClassName='flex items-center gap-2'
|
||||
contentHeight='min(72vh, 720px)'
|
||||
bodyClassName='space-y-4'
|
||||
>
|
||||
<div className='flex min-h-0 flex-1 flex-col space-y-4 overflow-hidden'>
|
||||
{/* Statistics */}
|
||||
<div className='grid shrink-0 grid-cols-3 gap-3'>
|
||||
<StatisticsCard
|
||||
label={t('Enabled')}
|
||||
count={enabledCount}
|
||||
total={total}
|
||||
/>
|
||||
<StatisticsCard
|
||||
label={t('Manual Disabled')}
|
||||
count={manualDisabledCount}
|
||||
total={total}
|
||||
/>
|
||||
<StatisticsCard
|
||||
label={t('Auto Disabled')}
|
||||
count={autoDisabledCount}
|
||||
total={total}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className='flex min-h-0 flex-1 flex-col space-y-4 overflow-hidden'>
|
||||
{/* Statistics */}
|
||||
<div className='grid shrink-0 grid-cols-3 gap-3'>
|
||||
<StatisticsCard
|
||||
label={t('Enabled')}
|
||||
count={enabledCount}
|
||||
total={total}
|
||||
/>
|
||||
<StatisticsCard
|
||||
label={t('Manual Disabled')}
|
||||
count={manualDisabledCount}
|
||||
total={total}
|
||||
/>
|
||||
<StatisticsCard
|
||||
label={t('Auto Disabled')}
|
||||
count={autoDisabledCount}
|
||||
total={total}
|
||||
/>
|
||||
</div>
|
||||
<Separator className='shrink-0' />
|
||||
|
||||
<Separator className='shrink-0' />
|
||||
{/* Toolbar */}
|
||||
<div className='flex shrink-0 items-center justify-between'>
|
||||
<Select
|
||||
items={[
|
||||
...MULTI_KEY_FILTER_OPTIONS.map((option) => ({
|
||||
value: option.value,
|
||||
label: t(option.label),
|
||||
})),
|
||||
]}
|
||||
value={statusFilter === null ? 'all' : statusFilter.toString()}
|
||||
onValueChange={(v) => v !== null && handleStatusFilterChange(v)}
|
||||
>
|
||||
<SelectTrigger className='w-40'>
|
||||
<SelectValue placeholder={t('All Status')} />
|
||||
</SelectTrigger>
|
||||
<SelectContent alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
{MULTI_KEY_FILTER_OPTIONS.map((option) => (
|
||||
<SelectItem key={option.value} value={option.value}>
|
||||
{t(option.label)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
{/* Toolbar */}
|
||||
<div className='flex shrink-0 items-center justify-between'>
|
||||
<Select
|
||||
items={[
|
||||
...MULTI_KEY_FILTER_OPTIONS.map((option) => ({
|
||||
value: option.value,
|
||||
label: t(option.label),
|
||||
})),
|
||||
]}
|
||||
value={statusFilter === null ? 'all' : statusFilter.toString()}
|
||||
onValueChange={(v) => v !== null && handleStatusFilterChange(v)}
|
||||
<div className='flex items-center gap-2'>
|
||||
<Button
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => loadKeyStatus()}
|
||||
disabled={isLoading}
|
||||
>
|
||||
<SelectTrigger className='w-40'>
|
||||
<SelectValue placeholder={t('All Status')} />
|
||||
</SelectTrigger>
|
||||
<SelectContent alignItemWithTrigger={false}>
|
||||
<SelectGroup>
|
||||
{MULTI_KEY_FILTER_OPTIONS.map((option) => (
|
||||
<SelectItem key={option.value} value={option.value}>
|
||||
{t(option.label)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<RefreshCw className='h-4 w-4' />
|
||||
</Button>
|
||||
|
||||
<div className='flex items-center gap-2'>
|
||||
{manualDisabledCount + autoDisabledCount > 0 && (
|
||||
<Button
|
||||
variant='default'
|
||||
size='sm'
|
||||
onClick={() => setConfirmAction({ type: 'enable-all' })}
|
||||
>
|
||||
<Power className='mr-2 h-4 w-4' />
|
||||
{t('Enable All')}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{enabledCount > 0 && (
|
||||
<Button
|
||||
variant='destructive'
|
||||
size='sm'
|
||||
onClick={() => setConfirmAction({ type: 'disable-all' })}
|
||||
>
|
||||
<PowerOff className='mr-2 h-4 w-4' />
|
||||
{t('Disable All')}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{autoDisabledCount > 0 && (
|
||||
<Button
|
||||
variant='destructive'
|
||||
size='sm'
|
||||
onClick={() => setConfirmAction({ type: 'delete-disabled' })}
|
||||
>
|
||||
<Trash2 className='mr-2 h-4 w-4' />
|
||||
{t('Delete Auto-Disabled')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Table */}
|
||||
<div className='min-h-0 flex-1 overflow-auto rounded-md border'>
|
||||
{isLoading ? (
|
||||
<div className='flex items-center justify-center py-12'>
|
||||
<Loader2 className='text-muted-foreground h-8 w-8 animate-spin' />
|
||||
</div>
|
||||
) : keys.length === 0 ? (
|
||||
<div className='text-muted-foreground py-12 text-center'>
|
||||
{t('No keys found')}
|
||||
</div>
|
||||
) : (
|
||||
<StaticDataTable
|
||||
className='rounded-none border-0'
|
||||
tableClassName='min-w-[800px]'
|
||||
data={keys}
|
||||
getRowKey={(key) => key.index}
|
||||
columns={[
|
||||
{
|
||||
id: 'index',
|
||||
header: t('Index'),
|
||||
className: 'w-20',
|
||||
cellClassName: 'font-mono text-sm',
|
||||
cell: (key) => `#${key.index + 1}`,
|
||||
},
|
||||
{
|
||||
id: 'status',
|
||||
header: t('Status'),
|
||||
className: 'w-32',
|
||||
cell: (key) => renderStatusBadge(key.status),
|
||||
},
|
||||
{
|
||||
id: 'reason',
|
||||
header: t('Disabled Reason'),
|
||||
className: 'min-w-[200px]',
|
||||
cellClassName: 'max-w-xs truncate text-sm',
|
||||
cell: (key) => key.reason || '-',
|
||||
},
|
||||
{
|
||||
id: 'disabled-time',
|
||||
header: t('Disabled Time'),
|
||||
className: 'w-44',
|
||||
cellClassName: 'text-muted-foreground text-sm',
|
||||
cell: (key) => formatKeyTimestamp(key.disabled_time),
|
||||
},
|
||||
{
|
||||
id: 'actions',
|
||||
header: t('Actions'),
|
||||
className: 'w-44 text-right',
|
||||
cell: (key) => (
|
||||
<MultiKeyTableRowActions
|
||||
keyIndex={key.index}
|
||||
status={key.status}
|
||||
onAction={setConfirmAction}
|
||||
/>
|
||||
),
|
||||
},
|
||||
]}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Pagination */}
|
||||
{totalPages > 1 && (
|
||||
<div className='flex shrink-0 items-center justify-between'>
|
||||
<div className='text-muted-foreground text-sm'>
|
||||
{t('Page {{current}} of {{total}}', {
|
||||
current: currentPage,
|
||||
total: totalPages,
|
||||
})}
|
||||
</div>
|
||||
<div className='flex gap-2'>
|
||||
<Button
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => loadKeyStatus()}
|
||||
disabled={isLoading}
|
||||
onClick={() => handlePageChange(currentPage - 1)}
|
||||
disabled={currentPage === 1 || isLoading}
|
||||
>
|
||||
<RefreshCw className='h-4 w-4' />
|
||||
{t('Previous')}
|
||||
</Button>
|
||||
<Button
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => handlePageChange(currentPage + 1)}
|
||||
disabled={currentPage >= totalPages || isLoading}
|
||||
>
|
||||
{t('Next')}
|
||||
</Button>
|
||||
|
||||
{manualDisabledCount + autoDisabledCount > 0 && (
|
||||
<Button
|
||||
variant='default'
|
||||
size='sm'
|
||||
onClick={() => setConfirmAction({ type: 'enable-all' })}
|
||||
>
|
||||
<Power className='mr-2 h-4 w-4' />
|
||||
{t('Enable All')}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{enabledCount > 0 && (
|
||||
<Button
|
||||
variant='destructive'
|
||||
size='sm'
|
||||
onClick={() => setConfirmAction({ type: 'disable-all' })}
|
||||
>
|
||||
<PowerOff className='mr-2 h-4 w-4' />
|
||||
{t('Disable All')}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{autoDisabledCount > 0 && (
|
||||
<Button
|
||||
variant='destructive'
|
||||
size='sm'
|
||||
onClick={() =>
|
||||
setConfirmAction({ type: 'delete-disabled' })
|
||||
}
|
||||
>
|
||||
<Trash2 className='mr-2 h-4 w-4' />
|
||||
{t('Delete Auto-Disabled')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Table */}
|
||||
<div className='min-h-0 flex-1 overflow-auto rounded-md border'>
|
||||
{isLoading ? (
|
||||
<div className='flex items-center justify-center py-12'>
|
||||
<Loader2 className='text-muted-foreground h-8 w-8 animate-spin' />
|
||||
</div>
|
||||
) : keys.length === 0 ? (
|
||||
<div className='text-muted-foreground py-12 text-center'>
|
||||
{t('No keys found')}
|
||||
</div>
|
||||
) : (
|
||||
<div className='min-w-[800px]'>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead className='w-20'>{t('Index')}</TableHead>
|
||||
<TableHead className='w-32'>{t('Status')}</TableHead>
|
||||
<TableHead className='min-w-[200px]'>
|
||||
{t('Disabled Reason')}
|
||||
</TableHead>
|
||||
<TableHead className='w-44'>
|
||||
{t('Disabled Time')}
|
||||
</TableHead>
|
||||
<TableHead className='w-44 text-right'>
|
||||
{t('Actions')}
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{keys.map((key) => (
|
||||
<TableRow key={key.index}>
|
||||
<TableCell className='font-mono text-sm'>
|
||||
#{key.index + 1}
|
||||
</TableCell>
|
||||
<TableCell>{renderStatusBadge(key.status)}</TableCell>
|
||||
<TableCell className='max-w-xs truncate text-sm'>
|
||||
{key.reason || '-'}
|
||||
</TableCell>
|
||||
<TableCell className='text-muted-foreground text-sm'>
|
||||
{formatKeyTimestamp(key.disabled_time)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<MultiKeyTableRowActions
|
||||
keyIndex={key.index}
|
||||
status={key.status}
|
||||
onAction={setConfirmAction}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Pagination */}
|
||||
{totalPages > 1 && (
|
||||
<div className='flex shrink-0 items-center justify-between'>
|
||||
<div className='text-muted-foreground text-sm'>
|
||||
{t('Page {{current}} of {{total}}', {
|
||||
current: currentPage,
|
||||
total: totalPages,
|
||||
})}
|
||||
</div>
|
||||
<div className='flex gap-2'>
|
||||
<Button
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => handlePageChange(currentPage - 1)}
|
||||
disabled={currentPage === 1 || isLoading}
|
||||
>
|
||||
{t('Previous')}
|
||||
</Button>
|
||||
<Button
|
||||
variant='outline'
|
||||
size='sm'
|
||||
onClick={() => handlePageChange(currentPage + 1)}
|
||||
disabled={currentPage >= totalPages || isLoading}
|
||||
>
|
||||
{t('Next')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</DialogContent>
|
||||
)}
|
||||
</div>
|
||||
</Dialog>
|
||||
|
||||
{/* Confirmation Dialog */}
|
||||
|
||||
+192
-199
@@ -34,18 +34,11 @@ import {
|
||||
} from '@/components/ui/alert-dialog'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { Separator } from '@/components/ui/separator'
|
||||
import { Dialog } from '@/components/dialog'
|
||||
import {
|
||||
deleteOllamaModel,
|
||||
fetchModels as fetchModelsFromEndpoint,
|
||||
@@ -375,203 +368,203 @@ export function OllamaModelsDialog({
|
||||
if (!open) return null
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={close}>
|
||||
<DialogContent className='max-h-[90vh] overflow-hidden sm:max-w-3xl'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('Ollama Models')}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('Manage local models for:')} <strong>{currentRow?.name}</strong>
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{!isOllamaChannel ? (
|
||||
<div className='text-muted-foreground py-8 text-center'>
|
||||
{t('This channel is not an Ollama channel.')}
|
||||
</div>
|
||||
) : (
|
||||
<div className='max-h-[78vh] space-y-4 overflow-y-auto py-2 pr-1'>
|
||||
<div className='flex flex-col gap-3 sm:flex-row sm:items-end sm:justify-between'>
|
||||
<div className='flex-1 space-y-2'>
|
||||
<Label htmlFor='ollama-pull'>{t('Pull model')}</Label>
|
||||
<div className='flex gap-2'>
|
||||
<Input
|
||||
id='ollama-pull'
|
||||
placeholder={t('e.g. llama3.1:8b')}
|
||||
value={pullName}
|
||||
onChange={(e) => setPullName(e.target.value)}
|
||||
disabled={!channelId || isPulling}
|
||||
/>
|
||||
<Button
|
||||
onClick={() => void pullModel()}
|
||||
disabled={!channelId || isPulling}
|
||||
>
|
||||
{isPulling ? (
|
||||
<>
|
||||
<Loader2 className='mr-2 h-4 w-4 animate-spin' />
|
||||
{t('Pulling...')}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Download className='mr-2 h-4 w-4' />
|
||||
{t('Pull')}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{pullProgress && (
|
||||
<div className='space-y-2'>
|
||||
<div className='text-muted-foreground text-xs'>
|
||||
{t('Status:')} {String(pullProgress.status || '-')}
|
||||
</div>
|
||||
<Progress
|
||||
value={
|
||||
typeof pullProgress.completed === 'number' &&
|
||||
typeof pullProgress.total === 'number' &&
|
||||
pullProgress.total > 0
|
||||
? Math.min(
|
||||
100,
|
||||
Math.round(
|
||||
(pullProgress.completed / pullProgress.total) *
|
||||
100
|
||||
)
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={close}
|
||||
title={t('Ollama Models')}
|
||||
description={
|
||||
<>
|
||||
{t('Manage local models for:')} <strong>{currentRow?.name}</strong>
|
||||
</>
|
||||
}
|
||||
contentClassName='sm:max-w-3xl'
|
||||
contentHeight='auto'
|
||||
bodyClassName='space-y-4'
|
||||
footer={
|
||||
<Button variant='outline' onClick={close}>
|
||||
{t('Close')}
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{!isOllamaChannel ? (
|
||||
<div className='text-muted-foreground py-8 text-center'>
|
||||
{t('This channel is not an Ollama channel.')}
|
||||
</div>
|
||||
) : (
|
||||
<div className='space-y-4 py-2 pr-1'>
|
||||
<div className='flex flex-col gap-3 sm:flex-row sm:items-end sm:justify-between'>
|
||||
<div className='flex-1 space-y-2'>
|
||||
<Label htmlFor='ollama-pull'>{t('Pull model')}</Label>
|
||||
<div className='flex gap-2'>
|
||||
<Input
|
||||
id='ollama-pull'
|
||||
placeholder={t('e.g. llama3.1:8b')}
|
||||
value={pullName}
|
||||
onChange={(e) => setPullName(e.target.value)}
|
||||
disabled={!channelId || isPulling}
|
||||
/>
|
||||
<Button
|
||||
onClick={() => void pullModel()}
|
||||
disabled={!channelId || isPulling}
|
||||
>
|
||||
{isPulling ? (
|
||||
<>
|
||||
<Loader2 className='mr-2 h-4 w-4 animate-spin' />
|
||||
{t('Pulling...')}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Download className='mr-2 h-4 w-4' />
|
||||
{t('Pull')}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{pullProgress && (
|
||||
<div className='space-y-2'>
|
||||
<div className='text-muted-foreground text-xs'>
|
||||
{t('Status:')} {String(pullProgress.status || '-')}
|
||||
</div>
|
||||
<Progress
|
||||
value={
|
||||
typeof pullProgress.completed === 'number' &&
|
||||
typeof pullProgress.total === 'number' &&
|
||||
pullProgress.total > 0
|
||||
? Math.min(
|
||||
100,
|
||||
Math.round(
|
||||
(pullProgress.completed / pullProgress.total) *
|
||||
100
|
||||
)
|
||||
: 0
|
||||
}
|
||||
/>
|
||||
)
|
||||
: 0
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className='flex gap-2'>
|
||||
<Button
|
||||
variant='outline'
|
||||
onClick={() => void fetchOllamaModels()}
|
||||
disabled={!channelId || isFetching}
|
||||
>
|
||||
{isFetching ? (
|
||||
<Loader2 className='mr-2 h-4 w-4 animate-spin' />
|
||||
) : (
|
||||
<RefreshCw className='mr-2 h-4 w-4' />
|
||||
)}
|
||||
{t('Refresh')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className='space-y-3'>
|
||||
<div className='flex flex-col gap-2 sm:flex-row sm:items-center sm:justify-between'>
|
||||
<div>
|
||||
<p className='text-sm font-medium'>{t('Local models')}</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('Select models and apply to channel models list.')}
|
||||
</p>
|
||||
</div>
|
||||
<div className='relative sm:w-72'>
|
||||
<Search className='text-muted-foreground absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2' />
|
||||
<Input
|
||||
placeholder={t('Search models...')}
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
className='pl-9'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='flex flex-wrap gap-2'>
|
||||
<Button variant='outline' size='sm' onClick={selectAllFiltered}>
|
||||
{t('Select all (filtered)')}
|
||||
</Button>
|
||||
<Button variant='outline' size='sm' onClick={clearSelection}>
|
||||
{t('Clear selection')}
|
||||
</Button>
|
||||
<Button
|
||||
size='sm'
|
||||
onClick={() => void applySelection('append')}
|
||||
disabled={!selected.length}
|
||||
>
|
||||
{t('Append to channel')}
|
||||
</Button>
|
||||
<Button
|
||||
variant='secondary'
|
||||
size='sm'
|
||||
onClick={() => void applySelection('replace')}
|
||||
disabled={!selected.length}
|
||||
>
|
||||
{t('Replace channel models')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className='overflow-hidden rounded-md border'>
|
||||
<div className='max-h-[420px] overflow-y-auto'>
|
||||
{filteredModels.length === 0 ? (
|
||||
<div className='text-muted-foreground p-6 text-center text-sm'>
|
||||
{t('No models found.')}
|
||||
</div>
|
||||
) : (
|
||||
<div className='divide-y'>
|
||||
{filteredModels.map((m) => {
|
||||
const checked = selected.includes(m.id)
|
||||
return (
|
||||
<div
|
||||
key={m.id}
|
||||
className='flex items-center justify-between gap-3 p-3'
|
||||
>
|
||||
<div className='flex min-w-0 items-start gap-3'>
|
||||
<Checkbox
|
||||
checked={checked}
|
||||
onCheckedChange={(v) => toggleSelected(m.id, !!v)}
|
||||
aria-label={`Select model ${m.id}`}
|
||||
/>
|
||||
<div className='min-w-0'>
|
||||
<div className='truncate font-mono text-sm'>
|
||||
{m.id}
|
||||
</div>
|
||||
<div className='text-muted-foreground flex flex-wrap gap-x-3 gap-y-1 text-xs'>
|
||||
<span>
|
||||
{t('Size:')} {formatBytes(m.size)}
|
||||
</span>
|
||||
{m.digest && (
|
||||
<span className='truncate'>
|
||||
{t('Digest:')} {String(m.digest)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='sm'
|
||||
className='text-destructive hover:text-destructive'
|
||||
onClick={() => {
|
||||
setDeleteTarget(m.id)
|
||||
setDeleteOpen(true)
|
||||
}}
|
||||
disabled={!channelId}
|
||||
>
|
||||
<Trash2 className='h-4 w-4' />
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className='flex gap-2'>
|
||||
<Button
|
||||
variant='outline'
|
||||
onClick={() => void fetchOllamaModels()}
|
||||
disabled={!channelId || isFetching}
|
||||
>
|
||||
{isFetching ? (
|
||||
<Loader2 className='mr-2 h-4 w-4 animate-spin' />
|
||||
) : (
|
||||
<RefreshCw className='mr-2 h-4 w-4' />
|
||||
)}
|
||||
{t('Refresh')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className='space-y-3'>
|
||||
<div className='flex flex-col gap-2 sm:flex-row sm:items-center sm:justify-between'>
|
||||
<div>
|
||||
<p className='text-sm font-medium'>{t('Local models')}</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{t('Select models and apply to channel models list.')}
|
||||
</p>
|
||||
</div>
|
||||
<div className='relative sm:w-72'>
|
||||
<Search className='text-muted-foreground absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2' />
|
||||
<Input
|
||||
placeholder={t('Search models...')}
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
className='pl-9'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='flex flex-wrap gap-2'>
|
||||
<Button variant='outline' size='sm' onClick={selectAllFiltered}>
|
||||
{t('Select all (filtered)')}
|
||||
</Button>
|
||||
<Button variant='outline' size='sm' onClick={clearSelection}>
|
||||
{t('Clear selection')}
|
||||
</Button>
|
||||
<Button
|
||||
size='sm'
|
||||
onClick={() => void applySelection('append')}
|
||||
disabled={!selected.length}
|
||||
>
|
||||
{t('Append to channel')}
|
||||
</Button>
|
||||
<Button
|
||||
variant='secondary'
|
||||
size='sm'
|
||||
onClick={() => void applySelection('replace')}
|
||||
disabled={!selected.length}
|
||||
>
|
||||
{t('Replace channel models')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className='overflow-hidden rounded-md border'>
|
||||
<div className='max-h-[420px] overflow-y-auto'>
|
||||
{filteredModels.length === 0 ? (
|
||||
<div className='text-muted-foreground p-6 text-center text-sm'>
|
||||
{t('No models found.')}
|
||||
</div>
|
||||
) : (
|
||||
<div className='divide-y'>
|
||||
{filteredModels.map((m) => {
|
||||
const checked = selected.includes(m.id)
|
||||
return (
|
||||
<div
|
||||
key={m.id}
|
||||
className='flex items-center justify-between gap-3 p-3'
|
||||
>
|
||||
<div className='flex min-w-0 items-start gap-3'>
|
||||
<Checkbox
|
||||
checked={checked}
|
||||
onCheckedChange={(v) =>
|
||||
toggleSelected(m.id, !!v)
|
||||
}
|
||||
aria-label={`Select model ${m.id}`}
|
||||
/>
|
||||
<div className='min-w-0'>
|
||||
<div className='truncate font-mono text-sm'>
|
||||
{m.id}
|
||||
</div>
|
||||
<div className='text-muted-foreground flex flex-wrap gap-x-3 gap-y-1 text-xs'>
|
||||
<span>
|
||||
{t('Size:')} {formatBytes(m.size)}
|
||||
</span>
|
||||
{m.digest && (
|
||||
<span className='truncate'>
|
||||
{t('Digest:')} {String(m.digest)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='sm'
|
||||
className='text-destructive hover:text-destructive'
|
||||
onClick={() => {
|
||||
setDeleteTarget(m.id)
|
||||
setDeleteOpen(true)
|
||||
}}
|
||||
disabled={!channelId}
|
||||
>
|
||||
<Trash2 className='h-4 w-4' />
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant='outline' onClick={close}>
|
||||
{t('Close')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AlertDialog
|
||||
open={deleteOpen}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user