integrate sub2api as upstream for auth/keys/usage via FastAPI BFF
Preserve local user table for superDream-specific features while syncing user lifecycle, API key CRUD and usage queries through sub2api. Admin token handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB) handle key writes that admin endpoints do not expose. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,139 +1,92 @@
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
"""Usage service: fully proxied to sub2api (user JWT for user-scoped endpoints)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select, func, cast, Date
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import UsageLog, ApiKey
|
||||
from app.core.exceptions import BadRequestError
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import (
|
||||
Sub2APIError,
|
||||
Sub2APIReauthRequired,
|
||||
Sub2APITransportError,
|
||||
)
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import ensure_access_token
|
||||
|
||||
|
||||
def _require_binding(user: User) -> int:
|
||||
if not user.sub2api_user_id:
|
||||
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
|
||||
return user.sub2api_user_id
|
||||
|
||||
|
||||
def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, Sub2APIReauthRequired):
|
||||
return HTTPException(status_code=401, detail="sub2api_reauth_required")
|
||||
if isinstance(exc, Sub2APIError):
|
||||
status = exc.http_status or 502
|
||||
if status < 400 or status >= 600:
|
||||
status = 502
|
||||
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
|
||||
return HTTPException(status_code=504, detail="upstream_timeout")
|
||||
|
||||
|
||||
class UsageService:
|
||||
@staticmethod
|
||||
async def list_logs(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.list_usage(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def summary(db: AsyncSession, user_id: str) -> dict:
|
||||
today_start = datetime.combine(date.today(), datetime.min.time())
|
||||
month_start = today_start.replace(day=1)
|
||||
|
||||
# Today
|
||||
today_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
|
||||
)).one()
|
||||
|
||||
# This month
|
||||
month_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
|
||||
)).one()
|
||||
|
||||
return {
|
||||
"today_tokens": int(today_row[0]),
|
||||
"today_cost": today_row[1],
|
||||
"month_tokens": int(month_row[0]),
|
||||
"month_cost": month_row[1],
|
||||
"total_requests": int(month_row[2]),
|
||||
}
|
||||
async def stats(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.usage_stats(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def daily(
|
||||
db: AsyncSession, user_id: str,
|
||||
start: Optional[date] = None, end: Optional[date] = None,
|
||||
) -> list:
|
||||
if not start:
|
||||
start = date.today() - timedelta(days=29)
|
||||
if not end:
|
||||
end = date.today()
|
||||
|
||||
day_col = cast(UsageLog.request_time, Date).label("day")
|
||||
result = await db.execute(
|
||||
select(
|
||||
day_col,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
UsageLog.user_id == user_id,
|
||||
cast(UsageLog.request_time, Date) >= start,
|
||||
cast(UsageLog.request_time, Date) <= end,
|
||||
)
|
||||
.group_by(day_col)
|
||||
.order_by(day_col)
|
||||
)
|
||||
return [
|
||||
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_stats(db: AsyncSession, user: User) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_stats(token)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def by_model(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.model,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.model)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_trend(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_trend(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def by_key(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.key_id,
|
||||
ApiKey.name,
|
||||
ApiKey.key_prefix,
|
||||
ApiKey.key_suffix,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.join(ApiKey, UsageLog.key_id == ApiKey.id)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{
|
||||
"key_id": row[0], "key_name": row[1] or "",
|
||||
"key_prefix": row[2], "key_suffix": row[3],
|
||||
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
|
||||
}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_models(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_models(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def logs(
|
||||
db: AsyncSession, user_id: str,
|
||||
page: int = 1, size: int = 20,
|
||||
model: Optional[str] = None,
|
||||
key_id: Optional[str] = None,
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
) -> list:
|
||||
q = select(UsageLog).where(UsageLog.user_id == user_id)
|
||||
if model:
|
||||
q = q.where(UsageLog.model == model)
|
||||
if key_id:
|
||||
q = q.where(UsageLog.key_id == key_id)
|
||||
if start:
|
||||
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
|
||||
if end:
|
||||
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
|
||||
|
||||
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
|
||||
result = await db.execute(q)
|
||||
return result.scalars().all()
|
||||
async def dashboard_api_keys_usage(
|
||||
db: AsyncSession, user: User, api_key_ids: list[int]
|
||||
) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_api_keys_usage(token, api_key_ids)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
Reference in New Issue
Block a user