"""Tests for the /admin/rate-limits endpoint.""" from unittest.mock import patch import pytest from fastapi.testclient import TestClient from SPARC import api from SPARC.api import app from SPARC.auth import UserResponse @pytest.fixture def client(): """Create test client.""" return TestClient(app) @pytest.fixture(autouse=True) def reset_stats(): """Reset rate limit stats between tests.""" api._rate_limit_stats.clear() api._rejected_log.clear() yield api._rate_limit_stats.clear() api._rejected_log.clear() def _mock_admin(): """Return a mock admin user.""" return UserResponse(id=1, email="admin@test.com", role="admin", created_at="2025-01-01T00:00:00") def _mock_user(): """Return a mock non-admin user.""" return UserResponse(id=2, email="user@test.com", role="user", created_at="2025-01-01T00:00:00") class TestRateLimitAdminEndpoint: """Test GET /admin/rate-limits.""" def test_admin_can_access(self, client): """Admin users should be able to access the rate-limits endpoint.""" app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") assert response.status_code == 200 data = response.json() assert "rate_limits" in data assert isinstance(data["rate_limits"], list) finally: app.dependency_overrides.clear() def test_non_admin_rejected(self, client): """Non-admin users should get 401/403.""" response = client.get("/admin/rate-limits") assert response.status_code in (401, 403) def test_returns_configured_endpoints(self, client): """Should list all rate-limited endpoints.""" app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") assert response.status_code == 200 data = response.json() endpoints = [rl["endpoint"] for rl in data["rate_limits"]] assert "/auth/register" in endpoints assert "/auth/login" in endpoints finally: app.dependency_overrides.clear() def test_empty_state_shows_zero_counts(self, client): """When no requests have been made, counts should be zero.""" app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() for rl in data["rate_limits"]: assert rl["total_requests"] == 0 assert rl["rejected_requests"] == 0 assert rl["by_ip"] == [] assert data["throttled_24h"] == 0 assert data["throttled_over_time"] == [] finally: app.dependency_overrides.clear() def test_tracks_requests(self, client): """After making requests, the stats should reflect them.""" api._track_rate_limit_request("/auth/login", "127.0.0.1") api._track_rate_limit_request("/auth/login", "127.0.0.1") api._track_rate_limit_request("/auth/login", "192.168.1.1", rejected=True) app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login") assert login_stats["total_requests"] == 3 assert login_stats["rejected_requests"] == 1 finally: app.dependency_overrides.clear() def test_includes_limit_config(self, client): """Each endpoint entry should include the rate limit config string.""" app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() for rl in data["rate_limits"]: assert "limit" in rl assert isinstance(rl["limit"], str) finally: app.dependency_overrides.clear() def test_per_ip_breakdown(self, client): """Stats should include per-IP breakdown with total and rejected counts.""" api._track_rate_limit_request("/auth/login", "10.0.0.1") api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) api._track_rate_limit_request("/auth/login", "10.0.0.2") app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login") by_ip = login_stats["by_ip"] assert len(by_ip) == 2 ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1") assert ip1["total"] == 2 assert ip1["rejected"] == 1 ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2") assert ip2["total"] == 1 assert ip2["rejected"] == 0 finally: app.dependency_overrides.clear() def test_throttled_24h_count(self, client): """Should report total throttled requests in the last 24 hours.""" api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True) app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() assert data["throttled_24h"] == 2 finally: app.dependency_overrides.clear() def test_throttled_over_time_structure(self, client): """Throttled-over-time should be a list of {timestamp, count} buckets.""" api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() assert len(data["throttled_over_time"]) >= 1 entry = data["throttled_over_time"][0] assert "timestamp" in entry assert "count" in entry assert entry["count"] >= 1 finally: app.dependency_overrides.clear() def test_response_shape_matches_contract(self, client): """The full response should match the expected shape for the frontend.""" app.dependency_overrides[api.get_current_admin] = _mock_admin try: response = client.get("/admin/rate-limits") data = response.json() # Top-level keys assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"} # Each rate_limit entry for rl in data["rate_limits"]: assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"} finally: app.dependency_overrides.clear()