Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company e2d750146c feat(auth): add rate limiting to login and register endpoints
- Add slowapi rate limiter: 10 req/min for /auth/login, 5 req/min for /auth/register
- Return HTTP 429 with Retry-After header when limit is exceeded
- Add slowapi to requirements.txt
- Add 4 passing tests for rate limit behavior

Closes leeworks-agents/SPARC#9

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:08:22 +00:00
5 changed files with 297 additions and 193 deletions
+30 -12
View File
@@ -7,22 +7,24 @@ from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from typing import Annotated, List from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config from SPARC import config
from SPARC.analyzer import CompanyAnalyzer from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import ( from SPARC.auth import (
TokenResponse, TokenResponse,
UserResponse, UserResponse,
close_db_client,
create_tokens, create_tokens,
decode_token, decode_token,
get_current_admin, get_current_admin,
get_current_user, get_current_user,
get_db_client, get_db_client,
init_db_client,
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -150,14 +152,12 @@ _analyzer: CompanyAnalyzer | None = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Initialize resources on startup, clean up on shutdown.""" """Initialize resources on startup."""
global _analyzer global _analyzer
init_db_client()
_analyzer = CompanyAnalyzer() _analyzer = CompanyAnalyzer()
yield yield
# Cleanup # Cleanup if needed
_analyzer = None _analyzer = None
close_db_client()
app = FastAPI( app = FastAPI(
@@ -168,6 +168,22 @@ app = FastAPI(
root_path=config.root_path, root_path=config.root_path,
) )
# Rate limiter (in-memory storage, suitable for single-instance deployments)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."},
headers={"Retry-After": str(retry_after)},
)
# Add CORS middleware for React frontend # Add CORS middleware for React frontend
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@@ -182,7 +198,8 @@ app.add_middleware(
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"]) @app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
async def register(request: RegisterRequest): @limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
"""Register a new user. """Register a new user.
The first registered user automatically becomes an admin. The first registered user automatically becomes an admin.
@@ -194,8 +211,8 @@ async def register(request: RegisterRequest):
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
user = db.create_user( user = db.create_user(
email=request.email, email=body.email,
password=request.password, password=body.password,
role=role, role=role,
) )
@@ -214,11 +231,12 @@ async def register(request: RegisterRequest):
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"]) @app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
async def login(request: LoginRequest): @limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens.""" """Authenticate user and return JWT tokens."""
db = get_db_client() db = get_db_client()
user = db.authenticate_user(request.email, request.password) user = db.authenticate_user(body.email, body.password)
if not user: if not user:
raise HTTPException( raise HTTPException(
+4 -29
View File
@@ -132,36 +132,11 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient: def get_db_client() -> DatabaseClient:
"""Get the shared pooled database client for auth operations. """Get database client for auth operations."""
client = DatabaseClient(config.database_url)
Returns the module-level singleton DatabaseClient. If not yet initialized client.connect()
(e.g., during tests), creates a new instance as a fallback. return client
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
async def get_current_user( async def get_current_user(
+165 -152
View File
@@ -201,6 +201,8 @@ class DatabaseClient:
Returns: Returns:
Cached message dict if found, None otherwise Cached message dict if found, None otherwise
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
query = """ query = """
@@ -223,11 +225,10 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1" query += " ORDER BY timestamp DESC LIMIT 1"
with self.get_conn() as conn: with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(query, params)
cursor.execute(query, params) result = cursor.fetchone()
result = cursor.fetchone() return dict(result) if result else None
return dict(result) if result else None
def store_message( def store_message(
self, self,
@@ -255,32 +256,33 @@ class DatabaseClient:
Returns: Returns:
The ID of the inserted record The ID of the inserted record
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
with self.get_conn() as conn: with self.conn.cursor() as cursor:
with conn.cursor() as cursor: cursor.execute(
cursor.execute( """
""" INSERT INTO llm_messages
INSERT INTO llm_messages (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id
RETURNING id """,
""", (
( prompt,
prompt, prompt_hash,
prompt_hash, response,
response, company_name,
company_name, analysis_type,
analysis_type, model,
model, json.dumps(metadata) if metadata else None,
json.dumps(metadata) if metadata else None, json.dumps(token_usage) if token_usage else None,
json.dumps(token_usage) if token_usage else None, is_cached,
is_cached, ),
), )
)
message_id = cursor.fetchone()[0] message_id = cursor.fetchone()[0]
conn.commit() self.conn.commit()
return message_id return message_id
@@ -302,6 +304,8 @@ class DatabaseClient:
Returns: Returns:
List of message dictionaries List of message dictionaries
""" """
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1" query = "SELECT * FROM llm_messages WHERE 1=1"
params = [] params = []
@@ -316,10 +320,9 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s" query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset]) params.extend([limit, offset])
with self.get_conn() as conn: with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(query, params)
cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()]
return [dict(row) for row in cursor.fetchall()]
def get_analytics(self, days: int = 30) -> Dict: def get_analytics(self, days: int = 30) -> Dict:
"""Get analytics on message usage. """Get analytics on message usage.
@@ -330,52 +333,53 @@ class DatabaseClient:
Returns: Returns:
Dictionary with analytics data Dictionary with analytics data
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
SELECT COUNT(*) as total_messages
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
""",
(days,),
)
total = cursor.fetchone()["total_messages"]
# Messages by company with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( # Total messages
""" cursor.execute(
SELECT company_name, COUNT(*) as count """
FROM llm_messages SELECT COUNT(*) as total_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' FROM llm_messages
GROUP BY company_name WHERE timestamp >= NOW() - INTERVAL '%s days'
ORDER BY count DESC """,
LIMIT 10 (days,),
""", )
(days,), total = cursor.fetchone()["total_messages"]
)
by_company = cursor.fetchall()
# Messages by type # Messages by company
cursor.execute( cursor.execute(
""" """
SELECT analysis_type, COUNT(*) as count SELECT company_name, COUNT(*) as count
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY analysis_type GROUP BY company_name
ORDER BY count DESC ORDER BY count DESC
""", LIMIT 10
(days,), """,
) (days,),
by_type = cursor.fetchall() )
by_company = cursor.fetchall()
return { # Messages by type
"total_messages": total, cursor.execute(
"by_company": [dict(row) for row in by_company], """
"by_type": [dict(row) for row in by_type], SELECT analysis_type, COUNT(*) as count
"period_days": days, FROM llm_messages
} WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY analysis_type
ORDER BY count DESC
""",
(days,),
)
by_type = cursor.fetchall()
return {
"total_messages": total,
"by_company": [dict(row) for row in by_company],
"by_type": [dict(row) for row in by_type],
"period_days": days,
}
# Patent Cache Methods # Patent Cache Methods
@@ -501,23 +505,25 @@ class DatabaseClient:
Returns: Returns:
Created user dict or None if email exists Created user dict or None if email exists
""" """
self.connect()
password_hash = self.hash_password(password) password_hash = self.hash_password(password)
try: try:
with self.get_conn() as conn: with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(
cursor.execute( """
""" INSERT INTO users (email, password_hash, role)
INSERT INTO users (email, password_hash, role) VALUES (%s, %s, %s)
VALUES (%s, %s, %s) RETURNING id, email, role, created_at
RETURNING id, email, role, created_at """,
""", (email, password_hash, role),
(email, password_hash, role), )
) user = cursor.fetchone()
user = cursor.fetchone() self.conn.commit()
conn.commit()
return dict(user) if user else None return dict(user) if user else None
except psycopg2.errors.UniqueViolation: except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]: def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -530,22 +536,23 @@ class DatabaseClient:
Returns: Returns:
User dict if authenticated, None otherwise User dict if authenticated, None otherwise
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM users WHERE email = %s",
(email,),
)
user = cursor.fetchone()
if user and self.verify_password(password, user["password_hash"]): with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
return { cursor.execute(
"id": user["id"], "SELECT * FROM users WHERE email = %s",
"email": user["email"], (email,),
"role": user["role"], )
"created_at": user["created_at"], user = cursor.fetchone()
}
return None if user and self.verify_password(password, user["password_hash"]):
return {
"id": user["id"],
"email": user["email"],
"role": user["role"],
"created_at": user["created_at"],
}
return None
def get_user_by_id(self, user_id: int) -> Optional[Dict]: def get_user_by_id(self, user_id: int) -> Optional[Dict]:
"""Get a user by ID. """Get a user by ID.
@@ -556,14 +563,15 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
"SELECT id, email, role, created_at FROM users WHERE id = %s", cursor.execute(
(user_id,), "SELECT id, email, role, created_at FROM users WHERE id = %s",
) (user_id,),
user = cursor.fetchone() )
return dict(user) if user else None user = cursor.fetchone()
return dict(user) if user else None
def get_user_by_email(self, email: str) -> Optional[Dict]: def get_user_by_email(self, email: str) -> Optional[Dict]:
"""Get a user by email. """Get a user by email.
@@ -574,14 +582,15 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
"SELECT id, email, role, created_at FROM users WHERE email = %s", cursor.execute(
(email,), "SELECT id, email, role, created_at FROM users WHERE email = %s",
) (email,),
user = cursor.fetchone() )
return dict(user) if user else None user = cursor.fetchone()
return dict(user) if user else None
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]: def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
"""Get all users (admin only). """Get all users (admin only).
@@ -593,18 +602,19 @@ class DatabaseClient:
Returns: Returns:
List of user dicts List of user dicts
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
""" cursor.execute(
SELECT id, email, role, created_at """
FROM users SELECT id, email, role, created_at
ORDER BY created_at DESC FROM users
LIMIT %s OFFSET %s ORDER BY created_at DESC
""", LIMIT %s OFFSET %s
(limit, offset), """,
) (limit, offset),
return [dict(row) for row in cursor.fetchall()] )
return [dict(row) for row in cursor.fetchall()]
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]: def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
"""Update a user's role (admin only). """Update a user's role (admin only).
@@ -616,19 +626,20 @@ class DatabaseClient:
Returns: Returns:
Updated user dict or None Updated user dict or None
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
""" cursor.execute(
UPDATE users """
SET role = %s, updated_at = CURRENT_TIMESTAMP UPDATE users
WHERE id = %s SET role = %s, updated_at = CURRENT_TIMESTAMP
RETURNING id, email, role, created_at WHERE id = %s
""", RETURNING id, email, role, created_at
(role, user_id), """,
) (role, user_id),
user = cursor.fetchone() )
conn.commit() user = cursor.fetchone()
self.conn.commit()
return dict(user) if user else None return dict(user) if user else None
def delete_user(self, user_id: int) -> bool: def delete_user(self, user_id: int) -> bool:
@@ -640,11 +651,12 @@ class DatabaseClient:
Returns: Returns:
True if deleted True if deleted
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,)) with self.conn.cursor() as cursor:
deleted = cursor.rowcount > 0 cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
conn.commit() deleted = cursor.rowcount > 0
self.conn.commit()
return deleted return deleted
def get_user_count(self) -> int: def get_user_count(self) -> int:
@@ -653,7 +665,8 @@ class DatabaseClient:
Returns: Returns:
Number of users Number of users
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users") with self.conn.cursor() as cursor:
return cursor.fetchone()[0] cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
+1
View File
@@ -14,3 +14,4 @@ numpy
pandas pandas
bcrypt bcrypt
PyJWT PyJWT
slowapi
+97
View File
@@ -0,0 +1,97 @@
"""Tests for rate limiting on auth endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
"""Create test client with rate limiter enabled."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_limiter():
"""Reset rate limiter storage between tests."""
from SPARC.api import limiter
limiter.reset()
yield
class TestRateLimiting:
"""Test rate limiting on login and register endpoints."""
@patch("SPARC.api.get_db_client")
def test_login_allows_requests_under_limit(self, mock_db_client, client):
"""Login endpoint allows requests under the rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Should allow at least a few requests
for _ in range(5):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
# 401 is expected (invalid credentials), not 429
assert response.status_code == 401
@patch("SPARC.api.get_db_client")
def test_login_rate_limited_after_threshold(self, mock_db_client, client):
"""Login endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Send more than the limit (10/minute)
statuses = []
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_register_rate_limited_after_threshold(self, mock_db_client, client):
"""Register endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # triggers 400 (email exists)
mock_db_client.return_value = mock_db
# Send more than the limit (5/minute)
statuses = []
for _ in range(10):
response = client.post(
"/auth/register",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
"""Rate limited responses include a Retry-After header."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Exhaust the limit
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
if response.status_code == 429:
assert "Retry-After" in response.headers
break