integrate sub2api as upstream for auth/keys/usage via FastAPI BFF

Preserve local user table for superDream-specific features while syncing
user lifecycle, API key CRUD and usage queries through sub2api. Admin token
handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB)
handle key writes that admin endpoints do not expose.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
xuyong
2026-04-17 21:23:08 +08:00
parent 20e842a60a
commit 35c0b7de16
30 changed files with 1707 additions and 803 deletions

View File

@@ -1,11 +1,41 @@
"""Auth service: local user table + sub2api sync.
Registration and login keep two systems in sync:
- superDream local User row (owns password hash, superDream JWT)
- sub2api User (owns API keys, usage, balance)
On registration, sub2api user is created via admin API; then we login to sub2api
with the same password to obtain and store a refresh_token. Any sub2api failure
rolls back the local insert so the two systems never drift apart silently.
On login, after local password verification we also refresh the stored sub2api
tokens (by re-logging in to sub2api) so that subsequent proxied calls have a
valid refresh_token without ever persisting the plaintext password.
"""
from __future__ import annotations
import logging
import re
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.exceptions import BadRequestError, UnauthorizedError
from app.core.security import (
create_access_token,
create_refresh_token,
create_reset_token,
decode_token,
hash_password,
verify_password,
)
from app.integrations.sub2api import admin as sub2api_admin
from app.integrations.sub2api import user as sub2api_user
from app.integrations.sub2api.client import Sub2APIError, Sub2APITransportError
from app.models import User
from app.services.sub2api_session import store_tokens
logger = logging.getLogger(__name__)
def _validate_password(password: str) -> None:
@@ -17,8 +47,14 @@ def _validate_password(password: str) -> None:
raise BadRequestError("密码需同时包含字母和数字")
class AuthService:
def _upstream_failure(action: str, exc: Exception) -> BadRequestError:
logger.error("sub2api %s failed: %s", action, exc)
if isinstance(exc, Sub2APIError):
return BadRequestError(f"上游同步失败:{exc.message or exc.reason or 'unknown'}")
return BadRequestError("上游服务不可用,请稍后重试")
class AuthService:
@staticmethod
async def register(db: AsyncSession, email: str, password: str) -> User:
_validate_password(password)
@@ -27,10 +63,44 @@ class AuthService:
if existing.scalar_one_or_none():
raise BadRequestError("Email already registered")
user = User(email=email, password_hash=hash_password(password))
# 1. Create the sub2api user first (so a conflict on their side halts us
# before we touch our DB).
try:
remote = await sub2api_admin.create_user(email=email, password=password)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _upstream_failure("create_user", exc) from exc
remote_user_id = int(remote.get("id"))
# 2. Grab a token pair by logging in as the new user.
try:
tokens = await sub2api_user.login(email, password)
except (Sub2APIError, Sub2APITransportError) as exc:
# Roll back the remote user so re-registration works.
try:
await sub2api_admin.delete_user(remote_user_id)
except Exception: # pragma: no cover - best-effort cleanup
logger.exception("failed to roll back sub2api user %s", remote_user_id)
raise _upstream_failure("login_after_register", exc) from exc
# 3. Persist the local user with sub2api references.
user = User(
email=email,
password_hash=hash_password(password),
sub2api_user_id=remote_user_id,
)
db.add(user)
await db.commit()
await db.refresh(user)
try:
await db.flush()
await store_tokens(db, user, tokens)
except Exception:
await db.rollback()
try:
await sub2api_admin.delete_user(remote_user_id)
except Exception: # pragma: no cover
logger.exception("failed to roll back sub2api user %s", remote_user_id)
raise
return user
@staticmethod
@@ -42,6 +112,18 @@ class AuthService:
if user.status != "active":
raise UnauthorizedError("Account is disabled")
# Refresh stored sub2api tokens with the same password. If sub2api is
# down we still allow local login; proxied calls will surface 503 later.
try:
tokens = await sub2api_user.login(email, password)
if not user.sub2api_user_id and tokens.get("user", {}).get("id"):
user.sub2api_user_id = int(tokens["user"]["id"])
await store_tokens(db, user, tokens)
except Sub2APIError as exc:
logger.warning("sub2api login failed for %s: %s", email, exc)
except Sub2APITransportError as exc:
logger.warning("sub2api unreachable during login: %s", exc)
return {
"access_token": create_access_token(user.id),
"refresh_token": create_refresh_token(user.id),
@@ -65,14 +147,10 @@ class AuthService:
@staticmethod
async def forgot_password(db: AsyncSession, email: str) -> str:
"""Generate a password reset token. MVP: returns the token directly (production: send via email)."""
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user:
# Don't reveal whether email exists
return ""
from app.core.security import create_reset_token
return create_reset_token(user.id)
@staticmethod
@@ -88,4 +166,16 @@ class AuthService:
raise BadRequestError("User not found")
user.password_hash = hash_password(new_password)
await db.commit()
# Sync the new password to sub2api so subsequent proxied calls continue
# to work. If this fails the user can still log in locally but won't be
# able to mutate keys until the systems re-converge.
if user.sub2api_user_id:
try:
await sub2api_admin.update_user(user.sub2api_user_id, password=new_password)
tokens = await sub2api_user.login(user.email, new_password)
await store_tokens(db, user, tokens)
except (Sub2APIError, Sub2APITransportError) as exc:
logger.error("sub2api password sync failed for user %s: %s", user.id, exc)
await db.commit()

View File

@@ -1,84 +1,111 @@
import hashlib
import secrets
import uuid
"""API Key service: proxied to sub2api.
from sqlalchemy import select, func
- Reads use the admin API key to call ``/admin/users/:id/api-keys``.
- Writes (create/update/delete) require the user's own sub2api JWT; we fetch
one via ``ensure_access_token`` using the stored refresh token.
"""
from __future__ import annotations
from typing import Any
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import ApiKey
from app.core.exceptions import BadRequestError, NotFoundError
from app.core.exceptions import BadRequestError
from app.integrations.sub2api import admin as sub2api_admin
from app.integrations.sub2api import user as sub2api_user
from app.integrations.sub2api.client import (
Sub2APIError,
Sub2APIReauthRequired,
Sub2APITransportError,
)
from app.models import User
from app.services.sub2api_session import ensure_access_token
MAX_KEYS_PER_USER = 5
KEY_PREFIX = "sk-sd-"
def _require_sub2api_binding(user: User) -> int:
if not user.sub2api_user_id:
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
return user.sub2api_user_id
def _translate_upstream(exc: Exception) -> HTTPException:
if isinstance(exc, Sub2APIReauthRequired):
return HTTPException(status_code=401, detail="sub2api_reauth_required")
if isinstance(exc, Sub2APIError):
status = exc.http_status or 502
if status < 400 or status >= 600:
status = 502
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
return HTTPException(status_code=504, detail="upstream_timeout")
class KeyService:
@staticmethod
async def list_keys(
db: AsyncSession,
user: User,
*,
page: int = 1,
page_size: int = 20,
sort_by: str = "created_at",
sort_order: str = "desc",
) -> dict[str, Any]:
uid = _require_sub2api_binding(user)
try:
return await sub2api_admin.list_user_api_keys(
uid,
page=page,
page_size=page_size,
sort_by=sort_by,
sort_order=sort_order,
)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate_upstream(exc) from exc
@staticmethod
def _generate_key() -> str:
return KEY_PREFIX + secrets.token_urlsafe(36)
async def create_key(db: AsyncSession, user: User, payload: dict[str, Any]) -> dict[str, Any]:
_require_sub2api_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.create_key(token, payload)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate_upstream(exc) from exc
@staticmethod
def _hash_key(raw_key: str) -> str:
return hashlib.sha256(raw_key.encode()).hexdigest()
async def update_key(
db: AsyncSession, user: User, key_id: int, payload: dict[str, Any]
) -> dict[str, Any]:
_require_sub2api_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.update_key(token, key_id, payload)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate_upstream(exc) from exc
@staticmethod
async def list_keys(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(ApiKey)
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
.order_by(ApiKey.created_at.desc())
)
return result.scalars().all()
async def delete_key(db: AsyncSession, user: User, key_id: int) -> dict[str, Any]:
_require_sub2api_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.delete_key(token, key_id)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate_upstream(exc) from exc
@staticmethod
async def create_key(db: AsyncSession, user_id: str, name: str = "") -> dict:
# Check limit
count_result = await db.execute(
select(func.count()).select_from(ApiKey)
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
)
count = count_result.scalar()
if count >= MAX_KEYS_PER_USER:
raise BadRequestError(f"最多创建 {MAX_KEYS_PER_USER} 个 Key")
raw_key = KeyService._generate_key()
key_hash = KeyService._hash_key(raw_key)
# prefix/suffix for masked display (after "sk-sd-")
body = raw_key[len(KEY_PREFIX):]
key_prefix = body[:4]
key_suffix = body[-4:]
api_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
key_suffix=key_suffix,
)
db.add(api_key)
await db.commit()
await db.refresh(api_key)
return {
"id": api_key.id,
"name": api_key.name,
"key": raw_key, # only returned once
"key_prefix": key_prefix,
"key_suffix": key_suffix,
"created_at": api_key.created_at,
}
async def get_key(db: AsyncSession, user: User, key_id: int) -> dict[str, Any]:
_require_sub2api_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.get_key(token, key_id)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate_upstream(exc) from exc
@staticmethod
async def delete_key(db: AsyncSession, user_id: str, key_id: str) -> None:
result = await db.execute(
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
)
api_key = result.scalar_one_or_none()
if not api_key:
raise NotFoundError("Key not found")
api_key.status = "revoked"
await db.commit()
async def available_groups(db: AsyncSession, user: User) -> list[dict[str, Any]]:
_require_sub2api_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.list_available_groups(token)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate_upstream(exc) from exc

View File

@@ -0,0 +1,74 @@
"""Helpers to obtain a fresh sub2api user access token from stored state.
The User row carries:
- sub2api_access_token + sub2api_access_expires_at (cache)
- sub2api_refresh_token_enc (Fernet-encrypted refresh_token)
``ensure_access_token`` returns a usable access_token, refreshing via sub2api
``/auth/refresh`` when the cached one is expired. On refresh failure the caller
receives ``Sub2APIReauthRequired`` and should surface 401 so the frontend can
prompt re-login.
"""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.crypto import decrypt_token, encrypt_token
from app.integrations.sub2api import user as sub2api_user
from app.integrations.sub2api.client import Sub2APIError, Sub2APIReauthRequired
from app.models import User
_CLOCK_SKEW = timedelta(seconds=60)
def _utcnow() -> datetime:
return datetime.now(tz=timezone.utc)
def _is_access_fresh(user: User) -> bool:
if not user.sub2api_access_token or not user.sub2api_access_expires_at:
return False
expires_at = user.sub2api_access_expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return expires_at - _CLOCK_SKEW > _utcnow()
async def store_tokens(db: AsyncSession, user: User, token_response: dict) -> None:
"""Persist a fresh token pair (from /auth/login or /auth/refresh)."""
access = token_response.get("access_token") or ""
refresh = token_response.get("refresh_token") or ""
expires_in = int(token_response.get("expires_in") or 3600)
user.sub2api_access_token = access
user.sub2api_access_expires_at = _utcnow() + timedelta(seconds=expires_in)
if refresh:
user.sub2api_refresh_token_enc = encrypt_token(refresh)
await db.commit()
await db.refresh(user)
async def ensure_access_token(db: AsyncSession, user: User) -> str:
"""Return a usable sub2api access_token, refreshing if necessary."""
if _is_access_fresh(user):
return user.sub2api_access_token # type: ignore[return-value]
if not user.sub2api_refresh_token_enc:
raise Sub2APIReauthRequired(401, "no refresh token stored", "NO_REFRESH_TOKEN")
try:
refresh_plain = decrypt_token(user.sub2api_refresh_token_enc)
except ValueError as exc:
raise Sub2APIReauthRequired(401, str(exc), "DECRYPT_FAILED") from exc
try:
token_response = await sub2api_user.refresh_tokens(refresh_plain)
except Sub2APIError as exc:
raise Sub2APIReauthRequired(
exc.code, exc.message, exc.reason or "REFRESH_FAILED", exc.metadata
) from exc
await store_tokens(db, user, token_response)
return user.sub2api_access_token # type: ignore[return-value]

View File

@@ -1,139 +1,92 @@
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Optional
"""Usage service: fully proxied to sub2api (user JWT for user-scoped endpoints)."""
from __future__ import annotations
from sqlalchemy import select, func, cast, Date
from typing import Any
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import UsageLog, ApiKey
from app.core.exceptions import BadRequestError
from app.integrations.sub2api import user as sub2api_user
from app.integrations.sub2api.client import (
Sub2APIError,
Sub2APIReauthRequired,
Sub2APITransportError,
)
from app.models import User
from app.services.sub2api_session import ensure_access_token
def _require_binding(user: User) -> int:
if not user.sub2api_user_id:
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
return user.sub2api_user_id
def _translate(exc: Exception) -> HTTPException:
if isinstance(exc, Sub2APIReauthRequired):
return HTTPException(status_code=401, detail="sub2api_reauth_required")
if isinstance(exc, Sub2APIError):
status = exc.http_status or 502
if status < 400 or status >= 600:
status = 502
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
return HTTPException(status_code=504, detail="upstream_timeout")
class UsageService:
@staticmethod
async def list_logs(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.list_usage(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def summary(db: AsyncSession, user_id: str) -> dict:
today_start = datetime.combine(date.today(), datetime.min.time())
month_start = today_start.replace(day=1)
# Today
today_row = (await db.execute(
select(
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
)).one()
# This month
month_row = (await db.execute(
select(
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
)).one()
return {
"today_tokens": int(today_row[0]),
"today_cost": today_row[1],
"month_tokens": int(month_row[0]),
"month_cost": month_row[1],
"total_requests": int(month_row[2]),
}
async def stats(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.usage_stats(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def daily(
db: AsyncSession, user_id: str,
start: Optional[date] = None, end: Optional[date] = None,
) -> list:
if not start:
start = date.today() - timedelta(days=29)
if not end:
end = date.today()
day_col = cast(UsageLog.request_time, Date).label("day")
result = await db.execute(
select(
day_col,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.where(
UsageLog.user_id == user_id,
cast(UsageLog.request_time, Date) >= start,
cast(UsageLog.request_time, Date) <= end,
)
.group_by(day_col)
.order_by(day_col)
)
return [
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
for row in result.all()
]
async def dashboard_stats(db: AsyncSession, user: User) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_stats(token)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def by_model(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(
UsageLog.model,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.where(UsageLog.user_id == user_id)
.group_by(UsageLog.model)
.order_by(func.sum(UsageLog.cost).desc())
)
return [
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
for row in result.all()
]
async def dashboard_trend(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_trend(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def by_key(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(
UsageLog.key_id,
ApiKey.name,
ApiKey.key_prefix,
ApiKey.key_suffix,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.join(ApiKey, UsageLog.key_id == ApiKey.id)
.where(UsageLog.user_id == user_id)
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
.order_by(func.sum(UsageLog.cost).desc())
)
return [
{
"key_id": row[0], "key_name": row[1] or "",
"key_prefix": row[2], "key_suffix": row[3],
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
}
for row in result.all()
]
async def dashboard_models(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_models(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def logs(
db: AsyncSession, user_id: str,
page: int = 1, size: int = 20,
model: Optional[str] = None,
key_id: Optional[str] = None,
start: Optional[date] = None,
end: Optional[date] = None,
) -> list:
q = select(UsageLog).where(UsageLog.user_id == user_id)
if model:
q = q.where(UsageLog.model == model)
if key_id:
q = q.where(UsageLog.key_id == key_id)
if start:
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
if end:
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
result = await db.execute(q)
return result.scalars().all()
async def dashboard_api_keys_usage(
db: AsyncSession, user: User, api_key_ids: list[int]
) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_api_keys_usage(token, api_key_ids)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc