Preserve local user table for superDream-specific features while syncing user lifecycle, API key CRUD and usage queries through sub2api. Admin token handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB) handle key writes that admin endpoints do not expose. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
74 lines
2.7 KiB
Python
74 lines
2.7 KiB
Python
"""Helpers to obtain a fresh sub2api user access token from stored state.

The User row carries:

- sub2api_access_token + sub2api_access_expires_at (cache)
- sub2api_refresh_token_enc (Fernet-encrypted refresh_token)

``ensure_access_token`` returns a usable access_token, refreshing via sub2api
``/auth/refresh`` when the cached one is missing or (nearly) expired. On refresh
failure the caller receives ``Sub2APIReauthRequired`` and should surface 401 so
the frontend can prompt re-login.
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.crypto import decrypt_token, encrypt_token
|
|
from app.integrations.sub2api import user as sub2api_user
|
|
from app.integrations.sub2api.client import Sub2APIError, Sub2APIReauthRequired
|
|
from app.models import User
|
|
|
|
# Safety margin: a cached access token is treated as expired this long before
# its recorded expiry, so it is not handed out just as it is about to lapse.
_CLOCK_SKEW = timedelta(minutes=1)
|
|
|
|
|
|
def _utcnow() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    return datetime.now(timezone.utc)
|
|
|
|
|
|
def _is_access_fresh(user: User) -> bool:
    """Report whether ``user``'s cached sub2api access token is still usable.

    Returns False when either the cached token or its expiry is unset.
    Otherwise the token is usable only while its expiry, minus
    ``_CLOCK_SKEW``, still lies in the future. A naive expiry timestamp is
    interpreted as UTC so it can be compared with the aware ``_utcnow()``.
    """
    if not user.sub2api_access_token:
        return False
    deadline = user.sub2api_access_expires_at
    if not deadline:
        return False
    if deadline.tzinfo is None:
        # Comparing naive and aware datetimes raises TypeError, so pin to UTC.
        deadline = deadline.replace(tzinfo=timezone.utc)
    usable_until = deadline - _CLOCK_SKEW
    return usable_until > _utcnow()
|
|
|
|
|
|
async def store_tokens(db: AsyncSession, user: User, token_response: dict) -> None:
    """Persist a fresh token pair (from /auth/login or /auth/refresh).

    Writes the access token and its absolute expiry onto ``user``. When the
    response includes a refresh token, it is stored encrypted via
    ``encrypt_token``. Then ``db`` is committed and ``user`` is reloaded.

    Args:
        db: Session used to commit the update.
        user: Row whose sub2api token columns are overwritten.
        token_response: Token payload. Reads ``access_token``,
            ``refresh_token`` and ``expires_in`` (seconds until expiry).
    """
    access = token_response.get("access_token") or ""
    refresh = token_response.get("refresh_token") or ""

    # Fall back to a one-hour lifetime only when expires_in is absent or not a
    # number. The previous ``value or 3600`` also turned an explicit 0 into an
    # hour, caching a token the server already reports as expired, and crashed
    # the whole persist on a non-numeric value.
    raw_expires_in = token_response.get("expires_in")
    try:
        expires_in = int(raw_expires_in)
    except (TypeError, ValueError, OverflowError):
        expires_in = 3600

    user.sub2api_access_token = access
    user.sub2api_access_expires_at = _utcnow() + timedelta(seconds=expires_in)
    if refresh:
        # Keep the previously stored refresh token when the response omits one.
        user.sub2api_refresh_token_enc = encrypt_token(refresh)
    await db.commit()
    await db.refresh(user)
|
|
|
|
|
|
async def ensure_access_token(db: AsyncSession, user: User) -> str:
    """Return a usable sub2api access_token, refreshing if necessary.

    A cached token that ``_is_access_fresh`` accepts is returned as-is.
    Otherwise the stored refresh token is decrypted and exchanged through
    ``sub2api_user.refresh_tokens``. The new pair is then persisted with
    ``store_tokens``, which commits ``db``.

    Raises:
        Sub2APIReauthRequired: no refresh token is stored, it cannot be
            decrypted, the refresh call fails with ``Sub2APIError``, or the
            refresh response carries no access token.
    """
    if _is_access_fresh(user):
        return user.sub2api_access_token  # type: ignore[return-value]

    if not user.sub2api_refresh_token_enc:
        raise Sub2APIReauthRequired(401, "no refresh token stored", "NO_REFRESH_TOKEN")

    try:
        refresh_plain = decrypt_token(user.sub2api_refresh_token_enc)
    except ValueError as exc:
        raise Sub2APIReauthRequired(401, str(exc), "DECRYPT_FAILED") from exc

    try:
        token_response = await sub2api_user.refresh_tokens(refresh_plain)
    except Sub2APIError as exc:
        raise Sub2APIReauthRequired(
            exc.code, exc.message, exc.reason or "REFRESH_FAILED", exc.metadata
        ) from exc

    # Persist before validating so a rotated refresh token in the response is
    # not lost even if the access token turns out to be missing.
    await store_tokens(db, user, token_response)

    access_token = user.sub2api_access_token
    if not access_token:
        # store_tokens records "" when the response omits access_token.
        # Returning it would hand callers an empty bearer token instead of
        # prompting re-login.
        raise Sub2APIReauthRequired(
            401, "refresh response missing access_token", "REFRESH_FAILED"
        )
    return access_token