integrate sub2api as upstream for auth/keys/usage via FastAPI BFF

Preserve local user table for superDream-specific features while syncing
user lifecycle, API key CRUD and usage queries through sub2api. Admin token
handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB)
handle key writes that admin endpoints do not expose.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
xuyong
2026-04-17 21:23:08 +08:00
parent 20e842a60a
commit 35c0b7de16
30 changed files with 1707 additions and 803 deletions

View File

@@ -1,139 +1,92 @@
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Optional
"""Usage service: fully proxied to sub2api (user JWT for user-scoped endpoints)."""
from __future__ import annotations
from sqlalchemy import select, func, cast, Date
from typing import Any
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import UsageLog, ApiKey
from app.core.exceptions import BadRequestError
from app.integrations.sub2api import user as sub2api_user
from app.integrations.sub2api.client import (
Sub2APIError,
Sub2APIReauthRequired,
Sub2APITransportError,
)
from app.models import User
from app.services.sub2api_session import ensure_access_token
def _require_binding(user: User) -> int:
if not user.sub2api_user_id:
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
return user.sub2api_user_id
def _translate(exc: Exception) -> HTTPException:
if isinstance(exc, Sub2APIReauthRequired):
return HTTPException(status_code=401, detail="sub2api_reauth_required")
if isinstance(exc, Sub2APIError):
status = exc.http_status or 502
if status < 400 or status >= 600:
status = 502
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
return HTTPException(status_code=504, detail="upstream_timeout")
class UsageService:
@staticmethod
async def list_logs(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.list_usage(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def summary(db: AsyncSession, user_id: str) -> dict:
today_start = datetime.combine(date.today(), datetime.min.time())
month_start = today_start.replace(day=1)
# Today
today_row = (await db.execute(
select(
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
)).one()
# This month
month_row = (await db.execute(
select(
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
)).one()
return {
"today_tokens": int(today_row[0]),
"today_cost": today_row[1],
"month_tokens": int(month_row[0]),
"month_cost": month_row[1],
"total_requests": int(month_row[2]),
}
async def stats(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.usage_stats(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def daily(
db: AsyncSession, user_id: str,
start: Optional[date] = None, end: Optional[date] = None,
) -> list:
if not start:
start = date.today() - timedelta(days=29)
if not end:
end = date.today()
day_col = cast(UsageLog.request_time, Date).label("day")
result = await db.execute(
select(
day_col,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.where(
UsageLog.user_id == user_id,
cast(UsageLog.request_time, Date) >= start,
cast(UsageLog.request_time, Date) <= end,
)
.group_by(day_col)
.order_by(day_col)
)
return [
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
for row in result.all()
]
async def dashboard_stats(db: AsyncSession, user: User) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_stats(token)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def by_model(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(
UsageLog.model,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.where(UsageLog.user_id == user_id)
.group_by(UsageLog.model)
.order_by(func.sum(UsageLog.cost).desc())
)
return [
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
for row in result.all()
]
async def dashboard_trend(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_trend(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def by_key(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(
UsageLog.key_id,
ApiKey.name,
ApiKey.key_prefix,
ApiKey.key_suffix,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.join(ApiKey, UsageLog.key_id == ApiKey.id)
.where(UsageLog.user_id == user_id)
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
.order_by(func.sum(UsageLog.cost).desc())
)
return [
{
"key_id": row[0], "key_name": row[1] or "",
"key_prefix": row[2], "key_suffix": row[3],
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
}
for row in result.all()
]
async def dashboard_models(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_models(token, **params)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc
@staticmethod
async def logs(
db: AsyncSession, user_id: str,
page: int = 1, size: int = 20,
model: Optional[str] = None,
key_id: Optional[str] = None,
start: Optional[date] = None,
end: Optional[date] = None,
) -> list:
q = select(UsageLog).where(UsageLog.user_id == user_id)
if model:
q = q.where(UsageLog.model == model)
if key_id:
q = q.where(UsageLog.key_id == key_id)
if start:
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
if end:
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
result = await db.execute(q)
return result.scalars().all()
async def dashboard_api_keys_usage(
db: AsyncSession, user: User, api_key_ids: list[int]
) -> dict[str, Any]:
_require_binding(user)
try:
token = await ensure_access_token(db, user)
return await sub2api_user.dashboard_api_keys_usage(token, api_key_ids)
except (Sub2APIError, Sub2APITransportError) as exc:
raise _translate(exc) from exc