integrate sub2api as upstream for auth/keys/usage via FastAPI BFF
Preserve local user table for superDream-specific features while syncing user lifecycle, API key CRUD and usage queries through sub2api. Admin token handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB) handle key writes that admin endpoints do not expose. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,11 +1,41 @@
|
||||
"""Auth service: local user table + sub2api sync.
|
||||
|
||||
Registration and login keep two systems in sync:
|
||||
- superDream local User row (owns password hash, superDream JWT)
|
||||
- sub2api User (owns API keys, usage, balance)
|
||||
|
||||
On registration, sub2api user is created via admin API; then we login to sub2api
|
||||
with the same password to obtain and store a refresh_token. Any sub2api failure
|
||||
rolls back the local insert so the two systems never drift apart silently.
|
||||
|
||||
On login, after local password verification we also refresh the stored sub2api
|
||||
tokens (by re-logging in to sub2api) so that subsequent proxied calls have a
|
||||
valid refresh_token without ever persisting the plaintext password.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import User
|
||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||
from app.core.exceptions import BadRequestError, UnauthorizedError
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_reset_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from app.integrations.sub2api import admin as sub2api_admin
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import Sub2APIError, Sub2APITransportError
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import store_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_password(password: str) -> None:
|
||||
@@ -17,8 +47,14 @@ def _validate_password(password: str) -> None:
|
||||
raise BadRequestError("密码需同时包含字母和数字")
|
||||
|
||||
|
||||
class AuthService:
|
||||
def _upstream_failure(action: str, exc: Exception) -> BadRequestError:
|
||||
logger.error("sub2api %s failed: %s", action, exc)
|
||||
if isinstance(exc, Sub2APIError):
|
||||
return BadRequestError(f"上游同步失败:{exc.message or exc.reason or 'unknown'}")
|
||||
return BadRequestError("上游服务不可用,请稍后重试")
|
||||
|
||||
|
||||
class AuthService:
|
||||
@staticmethod
|
||||
async def register(db: AsyncSession, email: str, password: str) -> User:
|
||||
_validate_password(password)
|
||||
@@ -27,10 +63,44 @@ class AuthService:
|
||||
if existing.scalar_one_or_none():
|
||||
raise BadRequestError("Email already registered")
|
||||
|
||||
user = User(email=email, password_hash=hash_password(password))
|
||||
# 1. Create the sub2api user first (so a conflict on their side halts us
|
||||
# before we touch our DB).
|
||||
try:
|
||||
remote = await sub2api_admin.create_user(email=email, password=password)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _upstream_failure("create_user", exc) from exc
|
||||
|
||||
remote_user_id = int(remote.get("id"))
|
||||
|
||||
# 2. Grab a token pair by logging in as the new user.
|
||||
try:
|
||||
tokens = await sub2api_user.login(email, password)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
# Roll back the remote user so re-registration works.
|
||||
try:
|
||||
await sub2api_admin.delete_user(remote_user_id)
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
logger.exception("failed to roll back sub2api user %s", remote_user_id)
|
||||
raise _upstream_failure("login_after_register", exc) from exc
|
||||
|
||||
# 3. Persist the local user with sub2api references.
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=hash_password(password),
|
||||
sub2api_user_id=remote_user_id,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
try:
|
||||
await db.flush()
|
||||
await store_tokens(db, user, tokens)
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
try:
|
||||
await sub2api_admin.delete_user(remote_user_id)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("failed to roll back sub2api user %s", remote_user_id)
|
||||
raise
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@@ -42,6 +112,18 @@ class AuthService:
|
||||
if user.status != "active":
|
||||
raise UnauthorizedError("Account is disabled")
|
||||
|
||||
# Refresh stored sub2api tokens with the same password. If sub2api is
|
||||
# down we still allow local login; proxied calls will surface 503 later.
|
||||
try:
|
||||
tokens = await sub2api_user.login(email, password)
|
||||
if not user.sub2api_user_id and tokens.get("user", {}).get("id"):
|
||||
user.sub2api_user_id = int(tokens["user"]["id"])
|
||||
await store_tokens(db, user, tokens)
|
||||
except Sub2APIError as exc:
|
||||
logger.warning("sub2api login failed for %s: %s", email, exc)
|
||||
except Sub2APITransportError as exc:
|
||||
logger.warning("sub2api unreachable during login: %s", exc)
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user.id),
|
||||
"refresh_token": create_refresh_token(user.id),
|
||||
@@ -65,14 +147,10 @@ class AuthService:
|
||||
|
||||
@staticmethod
|
||||
async def forgot_password(db: AsyncSession, email: str) -> str:
|
||||
"""Generate a password reset token. MVP: returns the token directly (production: send via email)."""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
# Don't reveal whether email exists
|
||||
return ""
|
||||
|
||||
from app.core.security import create_reset_token
|
||||
return create_reset_token(user.id)
|
||||
|
||||
@staticmethod
|
||||
@@ -88,4 +166,16 @@ class AuthService:
|
||||
raise BadRequestError("User not found")
|
||||
|
||||
user.password_hash = hash_password(new_password)
|
||||
await db.commit()
|
||||
|
||||
# Sync the new password to sub2api so subsequent proxied calls continue
|
||||
# to work. If this fails the user can still log in locally but won't be
|
||||
# able to mutate keys until the systems re-converge.
|
||||
if user.sub2api_user_id:
|
||||
try:
|
||||
await sub2api_admin.update_user(user.sub2api_user_id, password=new_password)
|
||||
tokens = await sub2api_user.login(user.email, new_password)
|
||||
await store_tokens(db, user, tokens)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
logger.error("sub2api password sync failed for user %s: %s", user.id, exc)
|
||||
|
||||
await db.commit()
|
||||
@@ -1,84 +1,111 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
"""API Key service: proxied to sub2api.
|
||||
|
||||
from sqlalchemy import select, func
|
||||
- Reads use the admin API key to call ``/admin/users/:id/api-keys``.
|
||||
- Writes (create/update/delete) require the user's own sub2api JWT; we fetch
|
||||
one via ``ensure_access_token`` using the stored refresh token.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ApiKey
|
||||
from app.core.exceptions import BadRequestError, NotFoundError
|
||||
from app.core.exceptions import BadRequestError
|
||||
from app.integrations.sub2api import admin as sub2api_admin
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import (
|
||||
Sub2APIError,
|
||||
Sub2APIReauthRequired,
|
||||
Sub2APITransportError,
|
||||
)
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import ensure_access_token
|
||||
|
||||
MAX_KEYS_PER_USER = 5
|
||||
KEY_PREFIX = "sk-sd-"
|
||||
|
||||
def _require_sub2api_binding(user: User) -> int:
|
||||
if not user.sub2api_user_id:
|
||||
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
|
||||
return user.sub2api_user_id
|
||||
|
||||
|
||||
def _translate_upstream(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, Sub2APIReauthRequired):
|
||||
return HTTPException(status_code=401, detail="sub2api_reauth_required")
|
||||
if isinstance(exc, Sub2APIError):
|
||||
status = exc.http_status or 502
|
||||
if status < 400 or status >= 600:
|
||||
status = 502
|
||||
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
|
||||
return HTTPException(status_code=504, detail="upstream_timeout")
|
||||
|
||||
|
||||
class KeyService:
|
||||
@staticmethod
|
||||
async def list_keys(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
*,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> dict[str, Any]:
|
||||
uid = _require_sub2api_binding(user)
|
||||
try:
|
||||
return await sub2api_admin.list_user_api_keys(
|
||||
uid,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
def _generate_key() -> str:
|
||||
return KEY_PREFIX + secrets.token_urlsafe(36)
|
||||
async def create_key(db: AsyncSession, user: User, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.create_key(token, payload)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
def _hash_key(raw_key: str) -> str:
|
||||
return hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
async def update_key(
|
||||
db: AsyncSession, user: User, key_id: int, payload: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.update_key(token, key_id, payload)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def list_keys(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
.order_by(ApiKey.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
async def delete_key(db: AsyncSession, user: User, key_id: int) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.delete_key(token, key_id)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def create_key(db: AsyncSession, user_id: str, name: str = "") -> dict:
|
||||
# Check limit
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
)
|
||||
count = count_result.scalar()
|
||||
if count >= MAX_KEYS_PER_USER:
|
||||
raise BadRequestError(f"最多创建 {MAX_KEYS_PER_USER} 个 Key")
|
||||
|
||||
raw_key = KeyService._generate_key()
|
||||
key_hash = KeyService._hash_key(raw_key)
|
||||
|
||||
# prefix/suffix for masked display (after "sk-sd-")
|
||||
body = raw_key[len(KEY_PREFIX):]
|
||||
key_prefix = body[:4]
|
||||
key_suffix = body[-4:]
|
||||
|
||||
api_key = ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
key_suffix=key_suffix,
|
||||
)
|
||||
db.add(api_key)
|
||||
await db.commit()
|
||||
await db.refresh(api_key)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"name": api_key.name,
|
||||
"key": raw_key, # only returned once
|
||||
"key_prefix": key_prefix,
|
||||
"key_suffix": key_suffix,
|
||||
"created_at": api_key.created_at,
|
||||
}
|
||||
async def get_key(db: AsyncSession, user: User, key_id: int) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.get_key(token, key_id)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def delete_key(db: AsyncSession, user_id: str, key_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
|
||||
)
|
||||
api_key = result.scalar_one_or_none()
|
||||
if not api_key:
|
||||
raise NotFoundError("Key not found")
|
||||
|
||||
api_key.status = "revoked"
|
||||
await db.commit()
|
||||
async def available_groups(db: AsyncSession, user: User) -> list[dict[str, Any]]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.list_available_groups(token)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
74
app/services/sub2api_session.py
Normal file
74
app/services/sub2api_session.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Helpers to obtain a fresh sub2api user access token from stored state.
|
||||
|
||||
The User row carries:
|
||||
- sub2api_access_token + sub2api_access_expires_at (cache)
|
||||
- sub2api_refresh_token_enc (Fernet-encrypted refresh_token)
|
||||
|
||||
``ensure_access_token`` returns a usable access_token, refreshing via sub2api
|
||||
``/auth/refresh`` when the cached one is expired. On refresh failure the caller
|
||||
receives ``Sub2APIReauthRequired`` and should surface 401 so the frontend can
|
||||
prompt re-login.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.crypto import decrypt_token, encrypt_token
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import Sub2APIError, Sub2APIReauthRequired
|
||||
from app.models import User
|
||||
|
||||
_CLOCK_SKEW = timedelta(seconds=60)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
def _is_access_fresh(user: User) -> bool:
|
||||
if not user.sub2api_access_token or not user.sub2api_access_expires_at:
|
||||
return False
|
||||
expires_at = user.sub2api_access_expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return expires_at - _CLOCK_SKEW > _utcnow()
|
||||
|
||||
|
||||
async def store_tokens(db: AsyncSession, user: User, token_response: dict) -> None:
|
||||
"""Persist a fresh token pair (from /auth/login or /auth/refresh)."""
|
||||
access = token_response.get("access_token") or ""
|
||||
refresh = token_response.get("refresh_token") or ""
|
||||
expires_in = int(token_response.get("expires_in") or 3600)
|
||||
|
||||
user.sub2api_access_token = access
|
||||
user.sub2api_access_expires_at = _utcnow() + timedelta(seconds=expires_in)
|
||||
if refresh:
|
||||
user.sub2api_refresh_token_enc = encrypt_token(refresh)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
async def ensure_access_token(db: AsyncSession, user: User) -> str:
|
||||
"""Return a usable sub2api access_token, refreshing if necessary."""
|
||||
if _is_access_fresh(user):
|
||||
return user.sub2api_access_token # type: ignore[return-value]
|
||||
|
||||
if not user.sub2api_refresh_token_enc:
|
||||
raise Sub2APIReauthRequired(401, "no refresh token stored", "NO_REFRESH_TOKEN")
|
||||
|
||||
try:
|
||||
refresh_plain = decrypt_token(user.sub2api_refresh_token_enc)
|
||||
except ValueError as exc:
|
||||
raise Sub2APIReauthRequired(401, str(exc), "DECRYPT_FAILED") from exc
|
||||
|
||||
try:
|
||||
token_response = await sub2api_user.refresh_tokens(refresh_plain)
|
||||
except Sub2APIError as exc:
|
||||
raise Sub2APIReauthRequired(
|
||||
exc.code, exc.message, exc.reason or "REFRESH_FAILED", exc.metadata
|
||||
) from exc
|
||||
|
||||
await store_tokens(db, user, token_response)
|
||||
return user.sub2api_access_token # type: ignore[return-value]
|
||||
@@ -1,139 +1,92 @@
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
"""Usage service: fully proxied to sub2api (user JWT for user-scoped endpoints)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select, func, cast, Date
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import UsageLog, ApiKey
|
||||
from app.core.exceptions import BadRequestError
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import (
|
||||
Sub2APIError,
|
||||
Sub2APIReauthRequired,
|
||||
Sub2APITransportError,
|
||||
)
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import ensure_access_token
|
||||
|
||||
|
||||
def _require_binding(user: User) -> int:
|
||||
if not user.sub2api_user_id:
|
||||
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
|
||||
return user.sub2api_user_id
|
||||
|
||||
|
||||
def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, Sub2APIReauthRequired):
|
||||
return HTTPException(status_code=401, detail="sub2api_reauth_required")
|
||||
if isinstance(exc, Sub2APIError):
|
||||
status = exc.http_status or 502
|
||||
if status < 400 or status >= 600:
|
||||
status = 502
|
||||
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
|
||||
return HTTPException(status_code=504, detail="upstream_timeout")
|
||||
|
||||
|
||||
class UsageService:
|
||||
@staticmethod
|
||||
async def list_logs(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.list_usage(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def summary(db: AsyncSession, user_id: str) -> dict:
|
||||
today_start = datetime.combine(date.today(), datetime.min.time())
|
||||
month_start = today_start.replace(day=1)
|
||||
|
||||
# Today
|
||||
today_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
|
||||
)).one()
|
||||
|
||||
# This month
|
||||
month_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
|
||||
)).one()
|
||||
|
||||
return {
|
||||
"today_tokens": int(today_row[0]),
|
||||
"today_cost": today_row[1],
|
||||
"month_tokens": int(month_row[0]),
|
||||
"month_cost": month_row[1],
|
||||
"total_requests": int(month_row[2]),
|
||||
}
|
||||
async def stats(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.usage_stats(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def daily(
|
||||
db: AsyncSession, user_id: str,
|
||||
start: Optional[date] = None, end: Optional[date] = None,
|
||||
) -> list:
|
||||
if not start:
|
||||
start = date.today() - timedelta(days=29)
|
||||
if not end:
|
||||
end = date.today()
|
||||
|
||||
day_col = cast(UsageLog.request_time, Date).label("day")
|
||||
result = await db.execute(
|
||||
select(
|
||||
day_col,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
UsageLog.user_id == user_id,
|
||||
cast(UsageLog.request_time, Date) >= start,
|
||||
cast(UsageLog.request_time, Date) <= end,
|
||||
)
|
||||
.group_by(day_col)
|
||||
.order_by(day_col)
|
||||
)
|
||||
return [
|
||||
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_stats(db: AsyncSession, user: User) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_stats(token)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def by_model(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.model,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.model)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_trend(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_trend(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def by_key(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.key_id,
|
||||
ApiKey.name,
|
||||
ApiKey.key_prefix,
|
||||
ApiKey.key_suffix,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.join(ApiKey, UsageLog.key_id == ApiKey.id)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{
|
||||
"key_id": row[0], "key_name": row[1] or "",
|
||||
"key_prefix": row[2], "key_suffix": row[3],
|
||||
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
|
||||
}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_models(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_models(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def logs(
|
||||
db: AsyncSession, user_id: str,
|
||||
page: int = 1, size: int = 20,
|
||||
model: Optional[str] = None,
|
||||
key_id: Optional[str] = None,
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
) -> list:
|
||||
q = select(UsageLog).where(UsageLog.user_id == user_id)
|
||||
if model:
|
||||
q = q.where(UsageLog.model == model)
|
||||
if key_id:
|
||||
q = q.where(UsageLog.key_id == key_id)
|
||||
if start:
|
||||
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
|
||||
if end:
|
||||
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
|
||||
|
||||
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
|
||||
result = await db.execute(q)
|
||||
return result.scalars().all()
|
||||
async def dashboard_api_keys_usage(
|
||||
db: AsyncSession, user: User, api_key_ids: list[int]
|
||||
) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_api_keys_usage(token, api_key_ids)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
Reference in New Issue
Block a user