Preserve local user table for superDream-specific features while syncing user lifecycle, API key CRUD and usage queries through sub2api. Admin token handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB) handle key writes that admin endpoints do not expose. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
181 lines
7.1 KiB
Python
181 lines
7.1 KiB
Python
"""Auth service: local user table + sub2api sync.
|
|
|
|
Registration and login keep two systems in sync:
|
|
- superDream local User row (owns password hash, superDream JWT)
|
|
- sub2api User (owns API keys, usage, balance)
|
|
|
|
On registration, the sub2api user is created via the admin API; then we log in to sub2api
with the same password to obtain and store a refresh_token. If any later step fails
(the sub2api login or the local insert), the new sub2api user is deleted and any pending
local insert is rolled back, so the two systems never drift apart silently.
|
|
|
|
On login, after local password verification we also refresh the stored sub2api
|
|
tokens (by re-logging in to sub2api) so that subsequent proxied calls have a
|
|
valid refresh_token without ever persisting the plaintext password.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.exceptions import BadRequestError, UnauthorizedError
|
|
from app.core.security import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
create_reset_token,
|
|
decode_token,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
from app.integrations.sub2api import admin as sub2api_admin
|
|
from app.integrations.sub2api import user as sub2api_user
|
|
from app.integrations.sub2api.client import Sub2APIError, Sub2APITransportError
|
|
from app.models import User
|
|
from app.services.sub2api_session import store_tokens
|
|
|
|
# Module logger: sub2api sync failures, login warnings and failed remote-user
# cleanups are reported here.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _validate_password(password: str) -> None:
|
|
if len(password) < 6:
|
|
raise BadRequestError("密码至少 6 位")
|
|
has_letter = bool(re.search(r"[a-zA-Z]", password))
|
|
has_digit = bool(re.search(r"[0-9]", password))
|
|
if not (has_letter and has_digit):
|
|
raise BadRequestError("密码需同时包含字母和数字")
|
|
|
|
|
|
def _upstream_failure(action: str, exc: Exception) -> BadRequestError:
    """Log a failed sub2api call and build the client-facing error for it.

    Args:
        action: Short label of the sub2api operation, used in the log line.
        exc: The exception raised while talking to sub2api.

    Returns:
        A ``BadRequestError`` (returned, not raised, so callers can write
        ``raise _upstream_failure(...) from exc``). For a ``Sub2APIError`` the
        message carries the upstream ``message`` (or ``reason``); any other
        exception yields a generic "upstream unavailable, retry later" message.
    """
    logger.error("sub2api %s failed: %s", action, exc)
    if not isinstance(exc, Sub2APIError):
        # Transport-level or otherwise unexpected failure: nothing upstream to echo.
        return BadRequestError("上游服务不可用,请稍后重试")
    upstream_detail = exc.message or exc.reason or "unknown"
    return BadRequestError(f"上游同步失败:{upstream_detail}")
|
|
|
|
|
|
class AuthService:
    """Authentication use-cases backed by the local User table and sub2api.

    Every method is a ``staticmethod`` working on the caller's ``AsyncSession``.
    Only ``reset_password`` commits; ``register`` and ``login`` add/flush/mutate
    rows and leave any commit to the caller.
    """

    @staticmethod
    async def _discard_remote_user(remote_user_id: int) -> None:
        """Best-effort delete of a sub2api user created by a failed registration.

        Errors are logged with traceback and swallowed so the failure that
        triggered the cleanup is the one the caller sees.
        """
        try:
            await sub2api_admin.delete_user(remote_user_id)
        except Exception:  # pragma: no cover - best-effort cleanup
            logger.exception("failed to roll back sub2api user %s", remote_user_id)

    @staticmethod
    async def register(db: AsyncSession, email: str, password: str) -> User:
        """Create a user in sub2api and locally, keeping both systems in sync.

        Steps: validate the password, reject an already-registered local email,
        create the sub2api user via the admin API, log in as that user to obtain
        tokens, then insert the local ``User`` row and persist the tokens with
        ``store_tokens``. If any step after the remote creation fails, the
        remote user is deleted (best effort) so re-registration stays possible.

        Args:
            db: Session the new row is added to and flushed on (not committed).
            email: Login email; duplicates are detected by exact match.
            password: Plaintext password; only its hash is stored locally.

        Returns:
            The flushed, not yet committed ``User``.

        Raises:
            BadRequestError: weak password, duplicate email, or a sub2api
                failure (including a create_user response without a usable id).
            Exception: errors from flushing / ``store_tokens`` are re-raised
                after the session is rolled back and the remote user deleted.
        """
        _validate_password(password)

        existing = await db.execute(select(User).where(User.email == email))
        if existing.scalar_one_or_none():
            raise BadRequestError("Email already registered")

        # 1. Create the sub2api user first (so a conflict on their side halts us
        #    before we touch our DB).
        try:
            remote = await sub2api_admin.create_user(email=email, password=password)
        except (Sub2APIError, Sub2APITransportError) as exc:
            raise _upstream_failure("create_user", exc) from exc

        # The response must carry the new user's id: without it we can neither
        # link the local row nor delete the remote user, so surface a controlled
        # upstream error instead of a bare TypeError/ValueError from int().
        try:
            remote_user_id = int(remote.get("id"))
        except (AttributeError, TypeError, ValueError) as exc:
            raise _upstream_failure("create_user (no usable user id)", exc) from exc

        # 2. Grab a token pair by logging in as the new user.
        try:
            tokens = await sub2api_user.login(email, password)
        except (Sub2APIError, Sub2APITransportError) as exc:
            # Roll back the remote user so re-registration works.
            await AuthService._discard_remote_user(remote_user_id)
            raise _upstream_failure("login_after_register", exc) from exc

        # 3. Persist the local user with sub2api references.
        user = User(
            email=email,
            password_hash=hash_password(password),
            sub2api_user_id=remote_user_id,
        )
        db.add(user)
        try:
            await db.flush()
            await store_tokens(db, user, tokens)
        except Exception:
            try:
                await db.rollback()
            finally:
                # Delete the remote user even if the local rollback itself fails,
                # otherwise it would be orphaned and block re-registration.
                await AuthService._discard_remote_user(remote_user_id)
            raise

        return user

    @staticmethod
    async def login(db: AsyncSession, email: str, password: str) -> dict:
        """Verify credentials locally and issue superDream JWTs.

        After local verification the stored sub2api tokens are refreshed by
        logging in to sub2api with the same password. sub2api errors are logged
        and ignored so local login keeps working while sub2api is down.

        Returns:
            ``{"access_token": ..., "refresh_token": ..., "token_type": "bearer"}``.

        Raises:
            UnauthorizedError: unknown email, wrong password, or inactive account.
        """
        result = await db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if not user or not verify_password(password, user.password_hash):
            raise UnauthorizedError("Invalid email or password")
        if user.status != "active":
            raise UnauthorizedError("Account is disabled")

        # Refresh stored sub2api tokens with the same password. If sub2api is
        # down we still allow local login; proxied calls will surface 503 later.
        try:
            tokens = await sub2api_user.login(email, password)
            # `or {}` also covers an explicit "user": null in the response, which
            # would otherwise raise AttributeError and abort the local login.
            remote_user = tokens.get("user") or {}
            if not user.sub2api_user_id and remote_user.get("id"):
                # Backfill the link for rows that have no sub2api id yet.
                user.sub2api_user_id = int(remote_user["id"])
            await store_tokens(db, user, tokens)
        except Sub2APIError as exc:
            logger.warning("sub2api login failed for %s: %s", email, exc)
        except Sub2APITransportError as exc:
            logger.warning("sub2api unreachable during login: %s", exc)

        return {
            "access_token": create_access_token(user.id),
            "refresh_token": create_refresh_token(user.id),
            "token_type": "bearer",
        }

    @staticmethod
    async def refresh(db: AsyncSession, refresh_token: str) -> dict:
        """Exchange a superDream refresh token for a new access token.

        Returns:
            ``{"access_token": ..., "token_type": "bearer"}``.

        Raises:
            UnauthorizedError: the token is invalid, not a refresh token, or has
                no subject; or the user is missing or not active.
        """
        payload = decode_token(refresh_token)
        # Reject subject-less tokens here instead of failing with KeyError below.
        if (
            not payload
            or payload.get("type") != "refresh"
            or payload.get("sub") is None
        ):
            raise UnauthorizedError("Invalid refresh token")

        user = await db.get(User, payload["sub"])
        if not user or user.status != "active":
            raise UnauthorizedError("User not found or disabled")

        return {
            "access_token": create_access_token(user.id),
            "token_type": "bearer",
        }

    @staticmethod
    async def forgot_password(db: AsyncSession, email: str) -> str:
        """Return a password-reset token for ``email``, or ``""`` if no user has it.

        NOTE(review): the empty-string result presumably lets the caller answer
        identically for unknown emails (no account enumeration) — confirm.
        """
        result = await db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if not user:
            return ""
        return create_reset_token(user.id)

    @staticmethod
    async def reset_password(db: AsyncSession, token: str, new_password: str) -> None:
        """Set a new password from a reset token and sync it to sub2api.

        The local hash is updated and the session committed; pushing the new
        password to sub2api and refreshing its tokens is best effort (sub2api
        failures are logged, not raised).

        Raises:
            BadRequestError: weak password; invalid, expired, non-reset or
                subject-less token; or unknown user.
        """
        _validate_password(new_password)

        payload = decode_token(token)
        # Reject subject-less tokens here instead of failing with KeyError below.
        if (
            not payload
            or payload.get("type") != "reset"
            or payload.get("sub") is None
        ):
            raise BadRequestError("Invalid or expired reset token")

        user = await db.get(User, payload["sub"])
        if not user:
            raise BadRequestError("User not found")

        # NOTE(review): nothing here invalidates the reset token after use, so it
        # appears reusable until it expires — confirm whether that is intended.
        user.password_hash = hash_password(new_password)

        # Sync the new password to sub2api so subsequent proxied calls continue
        # to work. If this fails the user can still log in locally but won't be
        # able to mutate keys until the systems re-converge.
        if user.sub2api_user_id:
            try:
                await sub2api_admin.update_user(user.sub2api_user_id, password=new_password)
                tokens = await sub2api_user.login(user.email, new_password)
                await store_tokens(db, user, tokens)
            except (Sub2APIError, Sub2APITransportError) as exc:
                logger.error("sub2api password sync failed for user %s: %s", user.id, exc)

        await db.commit()