Files
SPARC/SPARC/auth.py
T
agent-company d366443b38 refactor(db): use shared pooled DatabaseClient singleton instead of per-call instances
- Replace get_db_client() creating new DatabaseClient on every call with a
  module-level singleton initialized once at startup via init_db_client()
- Add init_db_client() and close_db_client() lifecycle functions called
  from FastAPI lifespan handler
- Migrate all DatabaseClient methods from legacy self.connect()/self.conn
  to pooled self.get_conn() context manager for thread-safe connection reuse
- Pool is properly torn down on application shutdown

Closes leeworks-agents/SPARC#7

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 06:03:56 +00:00

250 lines
6.3 KiB
Python

"""JWT authentication utilities for SPARC API."""
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from SPARC import config
from SPARC.database import DatabaseClient
# JWT Configuration
_DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
# ``or`` instead of a getenv default: an exported-but-empty JWT_SECRET (e.g.
# ``JWT_SECRET=`` in an env file) falls back to the default, which
# check_jwt_secret() rejects outside development, rather than silently
# signing every token with an empty key.
JWT_SECRET = os.getenv("JWT_SECRET") or _DEFAULT_JWT_SECRET
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
def check_jwt_secret() -> None:
    """Refuse to start with an insecure JWT secret in non-development environments.

    A secret is considered insecure when it is the built-in default value or
    blank (empty or whitespace only). Tokens signed with a publicly known or
    empty key can be forged.

    Raises:
        RuntimeError: If JWT_SECRET is blank or the default value and
            APP_ENV is not 'development'.
    """
    insecure = not JWT_SECRET.strip() or JWT_SECRET == _DEFAULT_JWT_SECRET
    if insecure and config.app_env != "development":
        raise RuntimeError(
            f"FATAL: JWT_SECRET is empty or set to the default value and APP_ENV={config.app_env!r}. "
            "Set a secure JWT_SECRET environment variable before running in non-development environments."
        )
# Bearer-token extractor; used as the ``credentials`` dependency of
# get_current_user (via ``Depends(security)``).
security = HTTPBearer()
class TokenPayload(BaseModel):
    """Claim set of a decoded SPARC JWT.

    The fields match the claims written by ``create_access_token`` and
    ``create_refresh_token``. ``decode_token`` builds instances from the
    dictionary returned by ``jwt.decode``.
    """

    sub: str  # user_id as string (JWT RFC 7519 requires sub to be a string)
    email: str
    role: str
    exp: datetime  # token expiry time
    type: str  # "access" or "refresh"

    @property
    def user_id(self) -> int:
        """Return ``sub`` converted to an integer user id.

        Raises:
            ValueError: If ``sub`` is not an integer string.
        """
        return int(self.sub)
class TokenResponse(BaseModel):
    """Pair of encoded JWTs, as produced by ``create_tokens``."""

    access_token: str  # short-lived token of type "access"
    refresh_token: str  # longer-lived token of type "refresh"
    token_type: str = "bearer"  # scheme clients should use in the Authorization header
class UserResponse(BaseModel):
    """User details exposed to API handlers.

    ``get_current_user`` builds this from the row returned by
    ``DatabaseClient.get_user_by_id``.
    """

    id: int
    email: str
    role: str  # compared against "admin" by get_current_admin
    created_at: datetime
def _create_token(
    user_id: int,
    email: str,
    role: str,
    lifetime: timedelta,
    token_type: str,
) -> str:
    """Encode a signed JWT carrying the standard SPARC claims.

    Shared by the access- and refresh-token creators so both always emit the
    same claim set and differ only in lifetime and ``type``.

    Args:
        user_id: User ID. Stored as a string in ``sub`` because JWT RFC 7519
            requires ``sub`` to be a string.
        email: User email
        role: User role
        lifetime: How long the token stays valid, measured from now (UTC).
        token_type: Value of the ``type`` claim ("access" or "refresh").

    Returns:
        Encoded JWT token
    """
    payload = {
        "sub": str(user_id),
        "email": email,
        "role": role,
        "exp": datetime.now(timezone.utc) + lifetime,
        "type": token_type,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


def create_access_token(user_id: int, email: str, role: str) -> str:
    """Create a JWT access token valid for ACCESS_TOKEN_EXPIRE_MINUTES.

    Args:
        user_id: User ID
        email: User email
        role: User role

    Returns:
        Encoded JWT token
    """
    return _create_token(
        user_id,
        email,
        role,
        timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        "access",
    )


def create_refresh_token(user_id: int, email: str, role: str) -> str:
    """Create a JWT refresh token valid for REFRESH_TOKEN_EXPIRE_DAYS.

    Args:
        user_id: User ID
        email: User email
        role: User role

    Returns:
        Encoded JWT token
    """
    return _create_token(
        user_id,
        email,
        role,
        timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
        "refresh",
    )
def create_tokens(user_id: int, email: str, role: str) -> TokenResponse:
    """Issue a fresh access/refresh token pair for a single user.

    Args:
        user_id: User ID
        email: User email
        role: User role

    Returns:
        TokenResponse holding the new access token and refresh token.
    """
    # Mint the access token first, then the refresh token.
    access = create_access_token(user_id, email, role)
    refresh = create_refresh_token(user_id, email, role)
    return TokenResponse(access_token=access, refresh_token=refresh)
def decode_token(token: str) -> Optional[TokenPayload]:
    """Decode and validate a JWT token.

    The signature and expiry are verified with JWT_SECRET / JWT_ALGORITHM.
    The claims must then fit ``TokenPayload``, and ``sub`` must be an integer
    string, so ``payload.user_id`` is safe for callers to read.

    Args:
        token: JWT token string

    Returns:
        TokenPayload if valid. None otherwise: bad signature, expired or
        malformed token, missing or ill-typed claims, or non-numeric ``sub``.
    """
    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError:
        # Base class of ExpiredSignatureError, DecodeError, InvalidSignatureError, ...
        return None
    try:
        payload = TokenPayload(**claims)
        # Reject non-numeric subjects here instead of letting .user_id raise later.
        _ = payload.user_id
    except ValueError:
        # pydantic's ValidationError subclasses ValueError, so this also covers
        # missing or ill-typed claims that would otherwise escape as a 500.
        return None
    return payload
# Process-wide DatabaseClient used by the auth helpers. Populated by
# init_db_client() at startup (or lazily by get_db_client()), reset by
# close_db_client() at shutdown.
_db_client: Optional[DatabaseClient] = None
def init_db_client() -> None:
    """Initialize the shared database client. Call once at app startup.

    The new client is connected before it is stored in the module global. A
    failing ``connect()`` therefore leaves the previous state untouched,
    instead of installing an unconnected client that ``get_db_client()``
    would hand out. If a client already exists (e.g. created lazily by
    ``get_db_client()`` or by an earlier call), it is closed after being
    replaced rather than leaked.
    """
    global _db_client
    client = DatabaseClient(config.database_url)
    client.connect()
    previous, _db_client = _db_client, client
    if previous is not None:
        previous.close()
def close_db_client() -> None:
    """Close the shared database client. Call at app shutdown.

    The global is cleared before ``close()`` runs. If closing raises, later
    ``get_db_client()`` calls still will not receive the half-closed client.
    This is a no-op when no client was ever initialized.
    """
    global _db_client
    client, _db_client = _db_client, None
    # Explicit None check: a client object's truthiness must not decide
    # whether it gets closed.
    if client is not None:
        client.close()
def get_db_client() -> DatabaseClient:
    """Get the shared pooled database client for auth operations.

    Returns the module-level singleton DatabaseClient. If it is not yet
    initialized (e.g., during tests), a new instance is created and connected
    as a fallback. It is stored only after ``connect()`` succeeds, so a failed
    attempt is retried on the next call instead of caching an unconnected
    client.

    NOTE(review): the lazy path takes no lock, so concurrent first calls from
    different threads could each create a client. Calling init_db_client() at
    startup avoids relying on this path.
    """
    global _db_client
    if _db_client is None:
        client = DatabaseClient(config.database_url)
        client.connect()
        _db_client = client
    return _db_client
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> UserResponse:
    """Get the current authenticated user from JWT token.

    Args:
        credentials: Bearer token from request

    Returns:
        UserResponse with user details

    Raises:
        HTTPException: 401 if the token is invalid, expired, not an access
            token, has a non-integer subject, or names an unknown user.
    """
    payload = decode_token(credentials.credentials)
    if payload is None:
        raise _unauthorized("Invalid or expired token")
    if payload.type != "access":
        raise _unauthorized("Invalid token type")
    try:
        user_id = payload.user_id
    except ValueError:
        # A non-integer ``sub`` is a bad token, not a server error.
        raise _unauthorized("Invalid or expired token") from None

    db = get_db_client()
    # get_user_by_id is synchronous. Calling it directly inside this async
    # dependency would block the event loop for every authenticated request,
    # so it runs in the worker thread pool instead.
    # NOTE(review): relies on DatabaseClient being safe to use from worker
    # threads (pooled get_conn()) -- confirm.
    user = await run_in_threadpool(db.get_user_by_id, user_id)
    if not user:
        raise _unauthorized("User not found")
    return UserResponse(
        id=user["id"],
        email=user["email"],
        role=user["role"],
        created_at=user["created_at"],
    )
async def get_current_admin(
    current_user: UserResponse = Depends(get_current_user),
) -> UserResponse:
    """Dependency that only lets through users whose role is "admin".

    Args:
        current_user: User already authenticated by get_current_user

    Returns:
        The same UserResponse, unchanged, when the user is an admin.

    Raises:
        HTTPException: 403 when the user's role is anything other than "admin".
    """
    is_admin = current_user.role == "admin"
    if is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin access required",
    )