From e2d750146c2ec07a7d18d1c1bdf037d78258c171 Mon Sep 17 00:00:00 2001 From: agent-company Date: Thu, 26 Mar 2026 04:08:22 +0000 Subject: [PATCH] feat(auth): add rate limiting to login and register endpoints - Add slowapi rate limiter: 10 req/min for /auth/login, 5 req/min for /auth/register - Return HTTP 429 with Retry-After header when limit is exceeded - Add slowapi to requirements.txt - Add 4 passing tests for rate limit behavior Closes leeworks-agents/SPARC#9 Co-Authored-By: Claude Opus 4.6 (1M context) --- SPARC/api.py | 34 +++++++++++--- requirements.txt | 1 + tests/test_rate_limit.py | 97 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 tests/test_rate_limit.py diff --git a/SPARC/api.py b/SPARC/api.py index 482caab..fef8a8f 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -7,9 +7,13 @@ from contextlib import asynccontextmanager from datetime import datetime from typing import Annotated, List -from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query +from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from pydantic import BaseModel, EmailStr, Field +from slowapi import Limiter +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address from SPARC import config from SPARC.analyzer import CompanyAnalyzer @@ -164,6 +168,22 @@ app = FastAPI( root_path=config.root_path, ) +# Rate limiter (in-memory storage, suitable for single-instance deployments) +limiter = Limiter(key_func=get_remote_address) +app.state.limiter = limiter + + +@app.exception_handler(RateLimitExceeded) +async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + """Return 429 with Retry-After header when rate limit is exceeded.""" + retry_after = getattr(exc, "retry_after", 60) + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded. Please try again later."}, + headers={"Retry-After": str(retry_after)}, + ) + + # Add CORS middleware for React frontend app.add_middleware( CORSMiddleware, @@ -178,7 +198,8 @@ app.add_middleware( @app.post("/auth/register", response_model=UserResponse, tags=["Auth"]) -async def register(request: RegisterRequest): +@limiter.limit("5/minute") +async def register(request: Request, body: RegisterRequest): """Register a new user. The first registered user automatically becomes an admin. @@ -190,8 +211,8 @@ async def register(request: RegisterRequest): role = "admin" if user_count == 0 else "user" user = db.create_user( - email=request.email, - password=request.password, + email=body.email, + password=body.password, role=role, ) @@ -210,11 +231,12 @@ async def register(request: RegisterRequest): @app.post("/auth/login", response_model=TokenResponse, tags=["Auth"]) -async def login(request: LoginRequest): +@limiter.limit("10/minute") +async def login(request: Request, body: LoginRequest): """Authenticate user and return JWT tokens.""" db = get_db_client() - user = db.authenticate_user(request.email, request.password) + user = db.authenticate_user(body.email, body.password) if not user: raise HTTPException( diff --git a/requirements.txt b/requirements.txt index 7e87235..e854576 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ numpy pandas bcrypt PyJWT +slowapi diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py new file mode 100644 index 0000000..f9f06af --- /dev/null +++ b/tests/test_rate_limit.py @@ -0,0 +1,97 @@ +"""Tests for rate limiting on auth endpoints.""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from fastapi.testclient import TestClient + +from SPARC.api import app + + +@pytest.fixture +def client(): + """Create test client with rate limiter enabled.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def reset_limiter(): + """Reset rate limiter storage between tests.""" + from SPARC.api import limiter + limiter.reset() + yield + + +class TestRateLimiting: + """Test rate limiting on login and register endpoints.""" + + @patch("SPARC.api.get_db_client") + def test_login_allows_requests_under_limit(self, mock_db_client, client): + """Login endpoint allows requests under the rate limit.""" + mock_db = MagicMock() + mock_db.authenticate_user.return_value = None + mock_db_client.return_value = mock_db + + # Should allow at least a few requests + for _ in range(5): + response = client.post( + "/auth/login", + json={"email": "test@example.com", "password": "password123"}, + ) + # 401 is expected (invalid credentials), not 429 + assert response.status_code == 401 + + @patch("SPARC.api.get_db_client") + def test_login_rate_limited_after_threshold(self, mock_db_client, client): + """Login endpoint returns 429 after exceeding rate limit.""" + mock_db = MagicMock() + mock_db.authenticate_user.return_value = None + mock_db_client.return_value = mock_db + + # Send more than the limit (10/minute) + statuses = [] + for _ in range(15): + response = client.post( + "/auth/login", + json={"email": "test@example.com", "password": "password123"}, + ) + statuses.append(response.status_code) + + # At least one should be 429 + assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}" + + @patch("SPARC.api.get_db_client") + def test_register_rate_limited_after_threshold(self, mock_db_client, client): + """Register endpoint returns 429 after exceeding rate limit.""" + mock_db = MagicMock() + mock_db.get_user_count.return_value = 1 + mock_db.create_user.return_value = None # triggers 400 (email exists) + mock_db_client.return_value = mock_db + + # Send more than the limit (5/minute) + statuses = [] + for _ in range(10): + response = client.post( + "/auth/register", + json={"email": "test@example.com", "password": "password123"}, + ) + statuses.append(response.status_code) + + # At least one should be 429 + assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}" + + @patch("SPARC.api.get_db_client") + def test_rate_limit_returns_retry_after_header(self, mock_db_client, client): + """Rate limited responses include a Retry-After header.""" + mock_db = MagicMock() + mock_db.authenticate_user.return_value = None + mock_db_client.return_value = mock_db + + # Exhaust the limit + for _ in range(15): + response = client.post( + "/auth/login", + json={"email": "test@example.com", "password": "password123"}, + ) + if response.status_code == 429: + assert "Retry-After" in response.headers + break