first commit
This commit is contained in:
0
app/services/__init__.py
Normal file
0
app/services/__init__.py
Normal file
91
app/services/auth_service.py
Normal file
91
app/services/auth_service.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import re
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import User
|
||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||
from app.core.exceptions import BadRequestError, UnauthorizedError
|
||||
|
||||
|
||||
def _validate_password(password: str) -> None:
|
||||
if len(password) < 6:
|
||||
raise BadRequestError("密码至少 6 位")
|
||||
has_letter = bool(re.search(r"[a-zA-Z]", password))
|
||||
has_digit = bool(re.search(r"[0-9]", password))
|
||||
if not (has_letter and has_digit):
|
||||
raise BadRequestError("密码需同时包含字母和数字")
|
||||
|
||||
|
||||
class AuthService:
|
||||
|
||||
@staticmethod
|
||||
async def register(db: AsyncSession, email: str, password: str) -> User:
|
||||
_validate_password(password)
|
||||
|
||||
existing = await db.execute(select(User).where(User.email == email))
|
||||
if existing.scalar_one_or_none():
|
||||
raise BadRequestError("Email already registered")
|
||||
|
||||
user = User(email=email, password_hash=hash_password(password))
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def login(db: AsyncSession, email: str, password: str) -> dict:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
raise UnauthorizedError("Invalid email or password")
|
||||
if user.status != "active":
|
||||
raise UnauthorizedError("Account is disabled")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user.id),
|
||||
"refresh_token": create_refresh_token(user.id),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def refresh(db: AsyncSession, refresh_token: str) -> dict:
|
||||
payload = decode_token(refresh_token)
|
||||
if not payload or payload.get("type") != "refresh":
|
||||
raise UnauthorizedError("Invalid refresh token")
|
||||
|
||||
user = await db.get(User, payload["sub"])
|
||||
if not user or user.status != "active":
|
||||
raise UnauthorizedError("User not found or disabled")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user.id),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def forgot_password(db: AsyncSession, email: str) -> str:
|
||||
"""Generate a password reset token. MVP: returns the token directly (production: send via email)."""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
# Don't reveal whether email exists
|
||||
return ""
|
||||
|
||||
from app.core.security import create_reset_token
|
||||
return create_reset_token(user.id)
|
||||
|
||||
@staticmethod
|
||||
async def reset_password(db: AsyncSession, token: str, new_password: str) -> None:
|
||||
_validate_password(new_password)
|
||||
|
||||
payload = decode_token(token)
|
||||
if not payload or payload.get("type") != "reset":
|
||||
raise BadRequestError("Invalid or expired reset token")
|
||||
|
||||
user = await db.get(User, payload["sub"])
|
||||
if not user:
|
||||
raise BadRequestError("User not found")
|
||||
|
||||
user.password_hash = hash_password(new_password)
|
||||
await db.commit()
|
||||
24
app/services/example_service.py
Normal file
24
app/services/example_service.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class ExampleService:
|
||||
def __init__(self):
|
||||
self._store: Dict[str, dict] = {}
|
||||
|
||||
async def list_all(self) -> List[dict]:
|
||||
return list(self._store.values())
|
||||
|
||||
async def get_by_id(self, example_id: str) -> Optional[dict]:
|
||||
return self._store.get(example_id)
|
||||
|
||||
async def create(self, data: dict) -> dict:
|
||||
example_id = str(uuid.uuid4())
|
||||
item = {
|
||||
"id": example_id,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
**data,
|
||||
}
|
||||
self._store[example_id] = item
|
||||
return item
|
||||
84
app/services/key_service.py
Normal file
84
app/services/key_service.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ApiKey
|
||||
from app.core.exceptions import BadRequestError, NotFoundError
|
||||
|
||||
MAX_KEYS_PER_USER = 5
|
||||
KEY_PREFIX = "sk-sd-"
|
||||
|
||||
|
||||
class KeyService:
|
||||
|
||||
@staticmethod
|
||||
def _generate_key() -> str:
|
||||
return KEY_PREFIX + secrets.token_urlsafe(36)
|
||||
|
||||
@staticmethod
|
||||
def _hash_key(raw_key: str) -> str:
|
||||
return hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
async def list_keys(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
.order_by(ApiKey.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
async def create_key(db: AsyncSession, user_id: str, name: str = "") -> dict:
|
||||
# Check limit
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
)
|
||||
count = count_result.scalar()
|
||||
if count >= MAX_KEYS_PER_USER:
|
||||
raise BadRequestError(f"最多创建 {MAX_KEYS_PER_USER} 个 Key")
|
||||
|
||||
raw_key = KeyService._generate_key()
|
||||
key_hash = KeyService._hash_key(raw_key)
|
||||
|
||||
# prefix/suffix for masked display (after "sk-sd-")
|
||||
body = raw_key[len(KEY_PREFIX):]
|
||||
key_prefix = body[:4]
|
||||
key_suffix = body[-4:]
|
||||
|
||||
api_key = ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
key_suffix=key_suffix,
|
||||
)
|
||||
db.add(api_key)
|
||||
await db.commit()
|
||||
await db.refresh(api_key)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"name": api_key.name,
|
||||
"key": raw_key, # only returned once
|
||||
"key_prefix": key_prefix,
|
||||
"key_suffix": key_suffix,
|
||||
"created_at": api_key.created_at,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def delete_key(db: AsyncSession, user_id: str, key_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
|
||||
)
|
||||
api_key = result.scalar_one_or_none()
|
||||
if not api_key:
|
||||
raise NotFoundError("Key not found")
|
||||
|
||||
api_key.status = "revoked"
|
||||
await db.commit()
|
||||
16
app/services/model_service.py
Normal file
16
app/services/model_service.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ModelPricing
|
||||
|
||||
|
||||
class ModelService:
|
||||
|
||||
@staticmethod
|
||||
async def list_models(db: AsyncSession) -> list:
|
||||
result = await db.execute(
|
||||
select(ModelPricing)
|
||||
.where(ModelPricing.status == "available")
|
||||
.order_by(ModelPricing.provider, ModelPricing.model_name)
|
||||
)
|
||||
return result.scalars().all()
|
||||
139
app/services/usage_service.py
Normal file
139
app/services/usage_service.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func, cast, Date
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import UsageLog, ApiKey
|
||||
|
||||
|
||||
class UsageService:
|
||||
|
||||
@staticmethod
|
||||
async def summary(db: AsyncSession, user_id: str) -> dict:
|
||||
today_start = datetime.combine(date.today(), datetime.min.time())
|
||||
month_start = today_start.replace(day=1)
|
||||
|
||||
# Today
|
||||
today_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
|
||||
)).one()
|
||||
|
||||
# This month
|
||||
month_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
|
||||
)).one()
|
||||
|
||||
return {
|
||||
"today_tokens": int(today_row[0]),
|
||||
"today_cost": today_row[1],
|
||||
"month_tokens": int(month_row[0]),
|
||||
"month_cost": month_row[1],
|
||||
"total_requests": int(month_row[2]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def daily(
|
||||
db: AsyncSession, user_id: str,
|
||||
start: Optional[date] = None, end: Optional[date] = None,
|
||||
) -> list:
|
||||
if not start:
|
||||
start = date.today() - timedelta(days=29)
|
||||
if not end:
|
||||
end = date.today()
|
||||
|
||||
day_col = cast(UsageLog.request_time, Date).label("day")
|
||||
result = await db.execute(
|
||||
select(
|
||||
day_col,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
UsageLog.user_id == user_id,
|
||||
cast(UsageLog.request_time, Date) >= start,
|
||||
cast(UsageLog.request_time, Date) <= end,
|
||||
)
|
||||
.group_by(day_col)
|
||||
.order_by(day_col)
|
||||
)
|
||||
return [
|
||||
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def by_model(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.model,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.model)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def by_key(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.key_id,
|
||||
ApiKey.name,
|
||||
ApiKey.key_prefix,
|
||||
ApiKey.key_suffix,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.join(ApiKey, UsageLog.key_id == ApiKey.id)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{
|
||||
"key_id": row[0], "key_name": row[1] or "",
|
||||
"key_prefix": row[2], "key_suffix": row[3],
|
||||
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
|
||||
}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def logs(
|
||||
db: AsyncSession, user_id: str,
|
||||
page: int = 1, size: int = 20,
|
||||
model: Optional[str] = None,
|
||||
key_id: Optional[str] = None,
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
) -> list:
|
||||
q = select(UsageLog).where(UsageLog.user_id == user_id)
|
||||
if model:
|
||||
q = q.where(UsageLog.model == model)
|
||||
if key_id:
|
||||
q = q.where(UsageLog.key_id == key_id)
|
||||
if start:
|
||||
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
|
||||
if end:
|
||||
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
|
||||
|
||||
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
|
||||
result = await db.execute(q)
|
||||
return result.scalars().all()
|
||||
60
app/services/wallet_service.py
Normal file
60
app/services/wallet_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import User, Transaction
|
||||
from app.core.exceptions import BadRequestError
|
||||
|
||||
# MVP: hardcoded redeem codes → amount mapping
|
||||
REDEEM_CODES = {
|
||||
"SUPERDREAM10": Decimal("10"),
|
||||
"SUPERDREAM50": Decimal("50"),
|
||||
"SUPERDREAM100": Decimal("100"),
|
||||
}
|
||||
|
||||
|
||||
class WalletService:
|
||||
|
||||
@staticmethod
|
||||
async def get_balance(db: AsyncSession, user_id: str) -> Decimal:
|
||||
user = await db.get(User, user_id)
|
||||
return user.balance
|
||||
|
||||
@staticmethod
|
||||
async def redeem_code(db: AsyncSession, user_id: str, code: str) -> Transaction:
|
||||
amount = REDEEM_CODES.get(code.upper())
|
||||
if not amount:
|
||||
raise BadRequestError("无效的兑换码")
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
user.balance += amount
|
||||
new_balance = user.balance
|
||||
|
||||
txn = Transaction(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
type="topup",
|
||||
amount=amount,
|
||||
balance_after=new_balance,
|
||||
reference_id=f"redeem:{code.upper()}",
|
||||
)
|
||||
db.add(txn)
|
||||
await db.commit()
|
||||
await db.refresh(txn)
|
||||
return txn
|
||||
|
||||
@staticmethod
|
||||
async def list_transactions(
|
||||
db: AsyncSession, user_id: str, page: int = 1, size: int = 20
|
||||
) -> list:
|
||||
offset = (page - 1) * size
|
||||
result = await db.execute(
|
||||
select(Transaction)
|
||||
.where(Transaction.user_id == user_id)
|
||||
.order_by(Transaction.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)
|
||||
return result.scalars().all()
|
||||
Reference in New Issue
Block a user