47cddcbeaf
- Add check_jwt_secret() that refuses default JWT secret when APP_ENV != development - Make CORS origins configurable via CORS_ORIGINS env var (comma-separated) - Replace hardcoded postgres credentials in docker-compose.yml with env var references - Add APP_ENV and cors_origins to config.py - Update .env.example with all required variables and documentation - Add tests for JWT startup guard and CORS configuration Closes leeworks-agents/SPARC#4 Closes leeworks-agents/SPARC#5 Closes leeworks-agents/SPARC#6 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
225 lines
5.6 KiB
Python
225 lines
5.6 KiB
Python
"""JWT authentication utilities for SPARC API."""
|
|
|
|
import os
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
import jwt
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from pydantic import BaseModel
|
|
|
|
from SPARC import config
|
|
from SPARC.database import DatabaseClient
|
|
|
|
# JWT configuration.
# Fallback signing key used only when the JWT_SECRET environment variable is unset;
# check_jwt_secret() rejects it outside development.
_DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
# Signing key: taken from the environment, otherwise the fallback above.
JWT_SECRET = os.environ.get("JWT_SECRET", _DEFAULT_JWT_SECRET)
# Algorithm passed to jwt.encode / jwt.decode.
JWT_ALGORITHM = "HS256"
# Token lifetimes (access tokens in minutes, refresh tokens in days).
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
|
|
|
|
|
def check_jwt_secret() -> None:
    """Abort startup when the built-in JWT secret is used outside development.

    Raises:
        RuntimeError: If JWT_SECRET still equals the default value while
            config.app_env is anything other than 'development'.
    """
    # A custom secret is always acceptable, whatever the environment.
    if JWT_SECRET != _DEFAULT_JWT_SECRET:
        return

    # The default secret is tolerated only for local development.
    if config.app_env == "development":
        return

    raise RuntimeError(
        f"FATAL: JWT_SECRET is set to the default value and APP_ENV={config.app_env!r}. "
        "Set a secure JWT_SECRET environment variable before running in non-development environments."
    )
|
|
|
|
# Bearer-credentials scheme instance; injected into get_current_user via Depends(security).
security = HTTPBearer()
|
|
|
|
|
|
class TokenPayload(BaseModel):
    """Claims carried inside a SPARC JWT."""

    # Subject claim: the user id rendered as a string
    # (JWT RFC 7519 requires sub to be a string).
    sub: str
    email: str
    role: str
    # Expiry timestamp of the token.
    exp: datetime
    # Token flavour: "access" or "refresh".
    type: str

    @property
    def user_id(self) -> int:
        """Return the subject claim converted to an integer user id."""
        subject = self.sub
        return int(subject)
|
|
|
|
|
|
class TokenResponse(BaseModel):
    """Access/refresh token pair returned by create_tokens()."""

    # Encoded access JWT.
    access_token: str
    # Encoded refresh JWT.
    refresh_token: str
    # Token scheme label; defaults to "bearer".
    token_type: str = "bearer"
|
|
|
|
|
|
class UserResponse(BaseModel):
    """User record shape returned by the authentication dependencies."""

    # Numeric user id.
    id: int
    email: str
    # Role name; get_current_admin compares it against "admin".
    role: str
    created_at: datetime
|
|
|
|
|
|
def create_access_token(user_id: int, email: str, role: str) -> str:
    """Build a signed access JWT for a user.

    The token expires ACCESS_TOKEN_EXPIRE_MINUTES minutes after creation and
    carries ``type="access"``.

    Args:
        user_id: User ID (stored as a string in the ``sub`` claim)
        email: User email
        role: User role

    Returns:
        Encoded JWT token
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = dict(
        sub=str(user_id),
        email=email,
        role=role,
        exp=expires_at,
        type="access",
    )
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(user_id: int, email: str, role: str) -> str:
    """Build a signed refresh JWT for a user.

    The token expires REFRESH_TOKEN_EXPIRE_DAYS days after creation and
    carries ``type="refresh"``.

    Args:
        user_id: User ID (stored as a string in the ``sub`` claim)
        email: User email
        role: User role

    Returns:
        Encoded JWT token
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    claims = dict(
        sub=str(user_id),
        email=email,
        role=role,
        exp=expires_at,
        type="refresh",
    )
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
def create_tokens(user_id: int, email: str, role: str) -> TokenResponse:
    """Issue an access token and a refresh token for the same user.

    Args:
        user_id: User ID
        email: User email
        role: User role

    Returns:
        TokenResponse with both tokens
    """
    identity = (user_id, email, role)
    # The access token is created first, then the refresh token.
    access = create_access_token(*identity)
    refresh = create_refresh_token(*identity)
    return TokenResponse(access_token=access, refresh_token=refresh)
|
|
|
|
|
|
def decode_token(token: str) -> Optional[TokenPayload]:
    """Decode and validate a JWT token.

    Args:
        token: JWT token string

    Returns:
        TokenPayload if the token is correctly signed, unexpired and carries
        every claim TokenPayload requires; None otherwise.
    """
    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError:
        # Also covers expiry: jwt.ExpiredSignatureError subclasses InvalidTokenError.
        return None

    try:
        return TokenPayload(**claims)
    except ValueError:
        # A correctly signed token can still lack or mistype claims (e.g. no
        # "type"). pydantic's ValidationError is a ValueError subclass; treat it
        # as an invalid token rather than letting it escape to the caller.
        return None
|
|
|
|
|
|
def get_db_client() -> DatabaseClient:
    """Create a DatabaseClient for config.database_url, connect it and return it."""
    db = DatabaseClient(config.database_url)
    db.connect()
    return db
|
|
|
|
|
|
def _unauthorized(detail: str) -> HTTPException:
    """Build the 401 error shared by every authentication failure."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> UserResponse:
    """Get the current authenticated user from JWT token.

    Args:
        credentials: Bearer token from request

    Returns:
        UserResponse with user details

    Raises:
        HTTPException: 401 if the token is invalid or expired, is not an
            access token, has a non-integer subject, or names a user that
            cannot be found.
    """
    payload = decode_token(credentials.credentials)
    if payload is None:
        raise _unauthorized("Invalid or expired token")

    # Refresh tokens must not be accepted as API credentials.
    if payload.type != "access":
        raise _unauthorized("Invalid token type")

    try:
        user_id = payload.user_id
    except ValueError:
        # `sub` is not an integer string: reject the token instead of letting
        # the conversion error surface as an unhandled server error.
        raise _unauthorized("Invalid or expired token") from None

    # NOTE(review): a new client is created and connected on every call and is
    # never released in this function; confirm whether DatabaseClient offers a
    # close/release API that should be invoked here.
    db = get_db_client()
    user = db.get_user_by_id(user_id)
    if not user:
        raise _unauthorized("User not found")

    return UserResponse(
        id=user["id"],
        email=user["email"],
        role=user["role"],
        created_at=user["created_at"],
    )
|
|
|
|
|
|
async def get_current_admin(
    current_user: UserResponse = Depends(get_current_user),
) -> UserResponse:
    """Allow the request through only when the authenticated user is an admin.

    Args:
        current_user: Current authenticated user

    Returns:
        UserResponse if admin

    Raises:
        HTTPException: If user is not admin
    """
    if current_user.role == "admin":
        return current_user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin access required",
    )
|