diff --git a/SPARC/auth.py b/SPARC/auth.py new file mode 100644 index 0000000..285054a --- /dev/null +++ b/SPARC/auth.py @@ -0,0 +1,205 @@ +"""JWT authentication utilities for SPARC API.""" + +import os +from datetime import datetime, timedelta, timezone +from typing import Optional + +import jwt +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel + +from SPARC import config +from SPARC.database import DatabaseClient + +# JWT Configuration +JWT_SECRET = os.getenv("JWT_SECRET", "sparc-secret-key-change-in-production") +JWT_ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +REFRESH_TOKEN_EXPIRE_DAYS = 7 + +security = HTTPBearer() + + +class TokenPayload(BaseModel): + """JWT token payload.""" + + sub: int # user_id + email: str + role: str + exp: datetime + type: str # "access" or "refresh" + + +class TokenResponse(BaseModel): + """Token response model.""" + + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class UserResponse(BaseModel): + """User response model.""" + + id: int + email: str + role: str + created_at: datetime + + +def create_access_token(user_id: int, email: str, role: str) -> str: + """Create a JWT access token. + + Args: + user_id: User ID + email: User email + role: User role + + Returns: + Encoded JWT token + """ + expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + payload = { + "sub": user_id, + "email": email, + "role": role, + "exp": expire, + "type": "access", + } + return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + + +def create_refresh_token(user_id: int, email: str, role: str) -> str: + """Create a JWT refresh token. + + Args: + user_id: User ID + email: User email + role: User role + + Returns: + Encoded JWT token + """ + expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + payload = { + "sub": user_id, + "email": email, + "role": role, + "exp": expire, + "type": "refresh", + } + return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + + +def create_tokens(user_id: int, email: str, role: str) -> TokenResponse: + """Create both access and refresh tokens. + + Args: + user_id: User ID + email: User email + role: User role + + Returns: + TokenResponse with both tokens + """ + return TokenResponse( + access_token=create_access_token(user_id, email, role), + refresh_token=create_refresh_token(user_id, email, role), + ) + + +def decode_token(token: str) -> Optional[TokenPayload]: + """Decode and validate a JWT token. + + Args: + token: JWT token string + + Returns: + TokenPayload if valid, None otherwise + """ + try: + payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + return TokenPayload(**payload) + except jwt.ExpiredSignatureError: + return None + except jwt.InvalidTokenError: + return None + + +def get_db_client() -> DatabaseClient: + """Get database client for auth operations.""" + client = DatabaseClient(config.database_url) + client.connect() + return client + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), +) -> UserResponse: + """Get the current authenticated user from JWT token. + + Args: + credentials: Bearer token from request + + Returns: + UserResponse with user details + + Raises: + HTTPException: If token is invalid or expired + """ + token = credentials.credentials + payload = decode_token(token) + + if not payload: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if payload.type != "access": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token type", + headers={"WWW-Authenticate": "Bearer"}, + ) + + db = get_db_client() + user = db.get_user_by_id(payload.sub) + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return UserResponse( + id=user["id"], + email=user["email"], + role=user["role"], + created_at=user["created_at"], + ) + + +async def get_current_admin( + current_user: UserResponse = Depends(get_current_user), +) -> UserResponse: + """Require admin role for the current user. + + Args: + current_user: Current authenticated user + + Returns: + UserResponse if admin + + Raises: + HTTPException: If user is not admin + """ + if current_user.role != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin access required", + ) + return current_user