Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company 3dfa651f2d Add rate limiting dashboard to admin panel
- Enhance GET /admin/rate-limits with per-IP breakdown, 24h throttled
  count, and hourly time-series of rejected requests
- Add _rejected_log deque for time-series tracking of throttled requests
- Add AdminRateLimits React page with auto-refresh (configurable 15s/30s/1m),
  summary cards, throttled-over-time bar chart, endpoint table, per-IP table
- Add TypeScript types (RateLimitStatsResponse) and adminApi.getRateLimits()
- Wire up /admin/rate-limits route and nav link (admin-only)
- Expand unit tests to 10 cases: auth, empty state, per-IP breakdown,
  throttled_24h count, time-series structure, response shape contract

Closes leeworks-agents/SPARC#1686

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:39:45 +00:00
7 changed files with 468 additions and 528 deletions
+115 -152
View File
@@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios.
from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING:
@@ -248,6 +249,9 @@ app.state.limiter = limiter
# In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {}
# Time-series log of rejected requests (capped to last 24 h worth of entries).
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint."""
@@ -262,6 +266,11 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) ->
_rate_limit_stats[key]["total_requests"] += 1
if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1
_rejected_log.append({
"endpoint": endpoint,
"ip": ip,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0}
@@ -507,10 +516,12 @@ async def get_rate_limit_stats(
"""Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics
for all rate-limited endpoints.
for all rate-limited endpoints, including per-IP breakdown and
a time-series of throttled (rejected) requests in the last 24 hours.
Returns:
List of rate limit stats per endpoint with total/rejected counts
Rate limit stats per endpoint, per-IP breakdown, and throttled
request history bucketed by hour.
"""
rate_limits_config = {
"/auth/register": {"limit": "5/minute"},
@@ -520,14 +531,45 @@ async def get_rate_limit_stats(
results = []
for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {})
by_ip_raw = stats.get("by_ip", {})
by_ip = [
{"ip": ip, "total": counts["total"], "rejected": counts["rejected"]}
for ip, counts in by_ip_raw.items()
]
results.append({
"endpoint": endpoint,
"limit": conf["limit"],
"total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0),
"by_ip": by_ip,
})
return {"rate_limits": results}
# Build hourly buckets of throttled requests for the last 24 hours
now = datetime.now(timezone.utc)
cutoff = now - timedelta(hours=24)
hourly_buckets: dict[str, int] = {}
throttled_24h = 0
for entry in _rejected_log:
ts_str = entry["timestamp"]
try:
ts = datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
continue
if ts >= cutoff:
throttled_24h += 1
bucket = ts.strftime("%Y-%m-%dT%H:00:00Z")
hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1
throttled_over_time = [
{"timestamp": k, "count": v}
for k, v in sorted(hourly_buckets.items())
]
return {
"rate_limits": results,
"throttled_24h": throttled_24h,
"throttled_over_time": throttled_over_time,
}
@app.get("/admin/alerts", tags=["Admin"])
@@ -675,25 +717,27 @@ async def get_analytics_trends(
# ============== Export Endpoints ==============
class BatchExportRequest(BaseModel):
"""Request model for batch ZIP export of analysis results."""
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
companies: list[CompanyName] = Field(
..., min_length=1, max_length=50, description="List of company names to export"
)
format: str = Field(
default="csv",
pattern="^(csv|pdf)$",
description="Export format: 'csv' or 'pdf'",
)
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
def _fetch_company_rows(db, company_name: str) -> list:
"""Fetch all non-cached analysis rows for *company_name* from the DB.
Returns a list of tuples: (company_name, analysis_type, model, response, timestamp).
Returns an empty list when no results exist.
Returns:
CSV file download
"""
import csv
import io
db = get_db_client()
# Query all non-cached analysis results for this company
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
@@ -705,24 +749,43 @@ def _fetch_company_rows(db, company_name: str) -> list:
""",
(company_name,),
)
return cur.fetchall()
rows = cur.fetchall()
def _build_company_csv(rows) -> bytes:
"""Render *rows* as CSV bytes."""
import csv
import io
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
for row in rows:
writer.writerow(row)
return output.getvalue().encode("utf-8")
output.seek(0)
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
def _build_company_pdf(rows, company_name: str) -> bytes:
"""Render *rows* as PDF bytes using reportlab."""
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
import io
from reportlab.lib import colors
@@ -737,6 +800,23 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
TableStyle,
)
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
""",
(company_name,),
)
rows = cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
@@ -779,11 +859,13 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
elements = []
display_name = rows[0][0]
# Title and date
display_name = rows[0][0] # Use the casing from the database
analysis_date = datetime.now().strftime("%Y-%m-%d")
elements.append(Paragraph(f"SPARC Analysis Report: {display_name}", title_style))
elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))
# Summary table
summary_data = [
["Total Analyses", str(len(rows))],
["Analysis Types", ", ".join(sorted(set(r[1] for r in rows)))],
@@ -805,6 +887,7 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
elements.append(summary_table)
elements.append(Spacer(1, 16))
# Individual analysis sections
for i, row in enumerate(rows, 1):
_, analysis_type, model, response, timestamp = row
ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)
@@ -816,11 +899,13 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
Paragraph(f"<i>Performed: {ts_str}</i>", body_style)
)
# Wrap long response text into paragraphs, escaping XML special chars
safe_response = (
response.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
# Split into manageable paragraphs to avoid overflow
for line in safe_response.split("\n"):
if line.strip():
elements.append(Paragraph(line, body_style))
@@ -831,133 +916,11 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
doc.build(elements)
buffer.seek(0)
return buffer.getvalue()
@app.post("/export/batch", tags=["Export"])
async def export_batch_zip(
request: BatchExportRequest,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for multiple companies as a ZIP archive.
For each company in the request, fetches all stored analysis records and
adds a per-company file (CSV or PDF) to the archive. Companies with no
stored results are skipped; a ``manifest.json`` inside the ZIP lists both
the exported and skipped companies.
Args:
request: List of company names and desired export format ('csv' or 'pdf')
Returns:
ZIP archive download containing one file per found company plus a manifest
"""
import io
import json
import zipfile
db = get_db_client()
export_date = datetime.now().strftime("%Y-%m-%d")
fmt = request.format
exported: list[str] = []
skipped: list[str] = []
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
for company_name in request.companies:
rows = _fetch_company_rows(db, company_name)
if not rows:
skipped.append(company_name)
continue
safe_name = company_name.replace(" ", "_").lower()
if fmt == "pdf":
file_bytes = _build_company_pdf(rows, company_name)
filename = f"{safe_name}-analysis-{export_date}.pdf"
else:
file_bytes = _build_company_csv(rows)
filename = f"sparc_{safe_name}_export.csv"
zf.writestr(filename, file_bytes)
exported.append(company_name)
# Always include a manifest
manifest = {
"export_date": export_date,
"format": fmt,
"exported": exported,
"skipped": skipped,
}
zf.writestr("manifest.json", json.dumps(manifest, indent=2))
zip_buffer.seek(0)
zip_filename = f"sparc-export-{export_date}.zip"
return StreamingResponse(
iter([zip_buffer.getvalue()]),
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'},
)
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
Returns:
CSV file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([_build_company_csv(rows)]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
analysis_date = datetime.now().strftime("%Y-%m-%d")
filename = f"{safe_name}-analysis-{analysis_date}.pdf"
return StreamingResponse(
iter([_build_company_pdf(rows, company_name)]),
iter([buffer.getvalue()]),
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
+9
View File
@@ -11,6 +11,7 @@ import { Batch } from './pages/Batch';
import { AnalyticsPage } from './pages/Analytics';
import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers';
import { AdminRateLimits } from './pages/AdminRateLimits';
import { Compare } from './pages/Compare';
const queryClient = new QueryClient({
@@ -56,6 +57,14 @@ function App() {
</ProtectedRoute>
}
/>
<Route
path="/admin/rate-limits"
element={
<ProtectedRoute requireAdmin>
<AdminRateLimits />
</ProtectedRoute>
}
/>
</Route>
{/* Default redirect */}
+31
View File
@@ -201,6 +201,32 @@ export const analyticsApi = {
},
};
// Rate limit types
export interface RateLimitIpEntry {
ip: string;
total: number;
rejected: number;
}
export interface RateLimitEndpointStats {
endpoint: string;
limit: string;
total_requests: number;
rejected_requests: number;
by_ip: RateLimitIpEntry[];
}
export interface ThrottledBucket {
timestamp: string;
count: number;
}
export interface RateLimitStatsResponse {
rate_limits: RateLimitEndpointStats[];
throttled_24h: number;
throttled_over_time: ThrottledBucket[];
}
// Admin API
export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
@@ -216,6 +242,11 @@ export const adminApi = {
deleteUser: async (userId: number): Promise<void> => {
await api.delete(`/admin/users/${userId}`);
},
getRateLimits: async (): Promise<RateLimitStatsResponse> => {
const response = await api.get<RateLimitStatsResponse>('/admin/rate-limits');
return response.data;
},
};
export default api;
+2 -1
View File
@@ -1,7 +1,7 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext';
import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, ShieldAlert } from 'lucide-react';
export function Layout() {
const { user, isAdmin, logout } = useAuth();
@@ -23,6 +23,7 @@ export function Layout() {
if (isAdmin) {
navItems.push({ to: '/admin/users', icon: Users, label: 'Users' });
navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' });
}
return (
+240
View File
@@ -0,0 +1,240 @@
import { useState } from 'react';
import { useQuery } from '@tanstack/react-query';
import { adminApi } from '../api/client';
import type { RateLimitStatsResponse } from '../api/client';
import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react';
const REFRESH_OPTIONS = [
{ label: '15s', value: 15_000 },
{ label: '30s', value: 30_000 },
{ label: '1m', value: 60_000 },
{ label: 'Off', value: 0 },
];
export function AdminRateLimits() {
const [refreshInterval, setRefreshInterval] = useState(30_000);
const { data, isLoading, isError, dataUpdatedAt } = useQuery<RateLimitStatsResponse>({
queryKey: ['admin-rate-limits'],
queryFn: () => adminApi.getRateLimits(),
refetchInterval: refreshInterval || false,
});
if (isLoading) {
return (
<div className="flex items-center justify-center min-h-[400px]">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div>
</div>
);
}
if (isError) {
return (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load rate limit statistics.</span>
</div>
);
}
const maxThrottledCount = data?.throttled_over_time?.length
? Math.max(...data.throttled_over_time.map((b) => b.count))
: 0;
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between flex-wrap gap-4">
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Rate Limiting Dashboard
</h2>
<p className="text-text-secondary">Monitor API rate limits and throttled requests.</p>
</div>
<div className="flex items-center gap-3">
{/* Last updated */}
{dataUpdatedAt > 0 && (
<span className="text-xs text-text-secondary flex items-center gap-1">
<Clock size={12} />
Updated {new Date(dataUpdatedAt).toLocaleTimeString()}
</span>
)}
{/* Refresh interval selector */}
<div className="flex items-center gap-1 bg-bg-card/60 border border-primary/15 rounded-xl p-1">
<RefreshCw size={14} className="text-text-secondary ml-2" />
{REFRESH_OPTIONS.map((opt) => (
<button
key={opt.value}
onClick={() => setRefreshInterval(opt.value)}
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
refreshInterval === opt.value
? 'bg-primary text-white'
: 'text-text-secondary hover:text-text-primary hover:bg-bg-card-hover'
}`}
>
{opt.label}
</button>
))}
</div>
</div>
</div>
{/* Summary cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<Activity size={18} className="text-primary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-error/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-error" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Throttled (24h)
</span>
</div>
<div className="text-3xl font-bold text-error">
{data?.throttled_24h ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-secondary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-secondary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rate-Limited Endpoints
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.length ?? 0}
</div>
</div>
</div>
{/* Throttled over time chart (simple bar chart) */}
{data?.throttled_over_time && data.throttled_over_time.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider mb-4">
Throttled Requests Over Time (Last 24h)
</h3>
<div className="flex items-end gap-1 h-32">
{data.throttled_over_time.map((bucket) => {
const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0;
const hour = new Date(bucket.timestamp).getHours();
return (
<div key={bucket.timestamp} className="flex-1 flex flex-col items-center gap-1">
<span className="text-xs text-text-secondary">{bucket.count}</span>
<div
className="w-full bg-error/70 rounded-t-sm min-h-[2px] transition-all"
style={{ height: `${Math.max(height, 2)}%` }}
title={`${bucket.timestamp}: ${bucket.count} throttled`}
/>
<span className="text-[10px] text-text-secondary">{hour}:00</span>
</div>
);
})}
</div>
</div>
)}
{/* Per-endpoint table */}
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Limit
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data?.rate_limits.map((rl) => (
<tr key={rl.endpoint} className="hover:bg-bg-card-hover/50 transition-colors">
<td className="px-6 py-4 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-4">
<span className="inline-flex px-2 py-0.5 rounded-full text-xs font-medium bg-primary/10 text-primary border border-primary/20">
{rl.limit}
</span>
</td>
<td className="px-6 py-4 text-right text-text-primary font-semibold">
{rl.total_requests}
</td>
<td className="px-6 py-4 text-right">
<span className={rl.rejected_requests > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{rl.rejected_requests}
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
{/* Per-IP breakdown */}
{data?.rate_limits.some((rl) => rl.by_ip.length > 0) && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="px-6 py-4 border-b border-primary/10">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Per-IP Breakdown
</h3>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
IP Address
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data.rate_limits.flatMap((rl) =>
rl.by_ip.map((ipEntry) => (
<tr
key={`${rl.endpoint}-${ipEntry.ip}`}
className="hover:bg-bg-card-hover/50 transition-colors"
>
<td className="px-6 py-3 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-3 font-mono text-sm text-text-secondary">{ipEntry.ip}</td>
<td className="px-6 py-3 text-right text-text-primary">{ipEntry.total}</td>
<td className="px-6 py-3 text-right">
<span className={ipEntry.rejected > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{ipEntry.rejected}
</span>
</td>
</tr>
))
)}
</tbody>
</table>
</div>
</div>
)}
</div>
);
}
-373
View File
@@ -1,373 +0,0 @@
"""Tests for POST /export/batch endpoint (issue #1674).
Covers:
- Single company export (CSV + PDF)
- Multiple company export
- All-missing companies (every requested company is skipped)
- Unauthenticated / invalid-token requests
- Manifest content validation
- Invalid format rejection
"""
import io
import json
import zipfile
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create a FastAPI test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock database client for all tests in this module."""
db = MagicMock()
# Auth: user always exists
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
# Default cursor mock (overridden per-test via side_effect or return_value)
mock_cursor = MagicMock()
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn)
db.get_conn.return_value.__exit__ = MagicMock(return_value=False)
db._mock_cursor = mock_cursor
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _rows_for(company_name: str):
"""Return a single sample row for the given company."""
return [
(
company_name,
"company_analysis",
"anthropic/claude-3.5-sonnet",
f"Strong patent portfolio for {company_name}.",
datetime(2025, 6, 15, 10, 30, 0),
)
]
def _open_zip(content: bytes) -> zipfile.ZipFile:
"""Helper: wrap response bytes as a ZipFile."""
return zipfile.ZipFile(io.BytesIO(content))
# ---------------------------------------------------------------------------
# Authentication
# ---------------------------------------------------------------------------
class TestBatchExportAuth:
"""Unauthenticated and invalid-token requests must be rejected."""
def test_unauthenticated_returns_401(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
)
assert response.status_code == 401
def test_invalid_token_returns_401(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers={"Authorization": "Bearer totally.invalid.token"},
)
assert response.status_code == 401
# ---------------------------------------------------------------------------
# Single company
# ---------------------------------------------------------------------------
class TestBatchExportSingleCompany:
"""POST /export/batch with a single company name."""
def test_single_company_csv_returns_zip(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/zip"
assert "attachment" in response.headers["content-disposition"]
assert "sparc-export-" in response.headers["content-disposition"]
assert response.headers["content-disposition"].endswith('.zip"')
def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
names = zf.namelist()
csv_files = [n for n in names if n.endswith(".csv")]
assert len(csv_files) == 1
assert "nvidia" in csv_files[0]
def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
csv_name = [n for n in zf.namelist() if n.endswith(".csv")][0]
csv_text = zf.read(csv_name).decode("utf-8")
lines = csv_text.strip().split("\n")
assert lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
assert "NVIDIA" in lines[1]
def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "pdf"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
pdf_files = [n for n in zf.namelist() if n.endswith(".pdf")]
assert len(pdf_files) == 1
# Verify it is actually a PDF (starts with %PDF)
pdf_bytes = zf.read(pdf_files[0])
assert pdf_bytes[:4] == b"%PDF"
# ---------------------------------------------------------------------------
# Multiple companies
# ---------------------------------------------------------------------------
class TestBatchExportMultipleCompanies:
"""POST /export/batch with several companies."""
def test_multiple_companies_each_gets_a_file(self, client, mock_db):
companies = ["NVIDIA", "Intel", "AMD"]
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
_rows_for("Intel"),
_rows_for("AMD"),
]
response = client.post(
"/export/batch",
json={"companies": companies, "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
assert len(csv_files) == 3
def test_multiple_companies_manifest_lists_all_exported(self, client, mock_db):
companies = ["NVIDIA", "Intel"]
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
_rows_for("Intel"),
]
response = client.post(
"/export/batch",
json={"companies": companies, "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
assert manifest["skipped"] == []
assert manifest["format"] == "csv"
def test_partial_missing_companies_skipped(self, client, mock_db):
"""Companies with no data are skipped; others are exported."""
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
[], # no data for "UnknownCo"
]
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA", "UnknownCo"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["exported"] == ["NVIDIA"]
assert manifest["skipped"] == ["UnknownCo"]
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
assert len(csv_files) == 1
# ---------------------------------------------------------------------------
# All-missing companies
# ---------------------------------------------------------------------------
class TestBatchExportAllMissing:
"""When every requested company has no data, the ZIP still returns 200
with only a manifest (no per-company files, all listed in skipped)."""
def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = []
response = client.post(
"/export/batch",
json={"companies": ["GhostCo", "PhantomInc"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
assert "manifest.json" in zf.namelist()
manifest = json.loads(zf.read("manifest.json"))
assert manifest["exported"] == []
assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}
def test_all_missing_zip_has_no_data_files(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = []
response = client.post(
"/export/batch",
json={"companies": ["GhostCo"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
data_files = [n for n in zf.namelist() if n != "manifest.json"]
assert data_files == []
# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------
class TestBatchExportManifest:
"""The manifest.json inside every ZIP must be well-formed."""
def test_manifest_always_present(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
assert "manifest.json" in zf.namelist()
def test_manifest_contains_required_keys(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert "export_date" in manifest
assert "format" in manifest
assert "exported" in manifest
assert "skipped" in manifest
def test_manifest_format_field_matches_request(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "pdf"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["format"] == "pdf"
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
class TestBatchExportInputValidation:
"""Invalid request bodies must return 422."""
def test_invalid_format_returns_422(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "xlsx"},
headers=_auth_header(),
)
assert response.status_code == 422
def test_empty_companies_list_returns_422(self, client):
response = client.post(
"/export/batch",
json={"companies": [], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 422
def test_default_format_is_csv(self, client, mock_db):
"""Omitting `format` should default to CSV."""
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"]},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["format"] == "csv"
+71 -2
View File
@@ -20,8 +20,10 @@ def client():
def reset_stats():
"""Reset rate limit stats between tests."""
api._rate_limit_stats.clear()
api._rejected_log.clear()
yield
api._rate_limit_stats.clear()
api._rejected_log.clear()
def _mock_admin():
@@ -50,8 +52,7 @@ class TestRateLimitAdminEndpoint:
app.dependency_overrides.clear()
def test_non_admin_rejected(self, client):
"""Non-admin users should get 403."""
# Without overriding the dependency, it should fail auth
"""Non-admin users should get 401/403."""
response = client.get("/admin/rate-limits")
assert response.status_code in (401, 403)
@@ -77,6 +78,9 @@ class TestRateLimitAdminEndpoint:
for rl in data["rate_limits"]:
assert rl["total_requests"] == 0
assert rl["rejected_requests"] == 0
assert rl["by_ip"] == []
assert data["throttled_24h"] == 0
assert data["throttled_over_time"] == []
finally:
app.dependency_overrides.clear()
@@ -107,3 +111,68 @@ class TestRateLimitAdminEndpoint:
assert isinstance(rl["limit"], str)
finally:
app.dependency_overrides.clear()
def test_per_ip_breakdown(self, client):
"""Stats should include per-IP breakdown with total and rejected counts."""
api._track_rate_limit_request("/auth/login", "10.0.0.1")
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/login", "10.0.0.2")
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
by_ip = login_stats["by_ip"]
assert len(by_ip) == 2
ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1")
assert ip1["total"] == 2
assert ip1["rejected"] == 1
ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2")
assert ip2["total"] == 1
assert ip2["rejected"] == 0
finally:
app.dependency_overrides.clear()
def test_throttled_24h_count(self, client):
"""Should report total throttled requests in the last 24 hours."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert data["throttled_24h"] == 2
finally:
app.dependency_overrides.clear()
def test_throttled_over_time_structure(self, client):
"""Throttled-over-time should be a list of {timestamp, count} buckets."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert len(data["throttled_over_time"]) >= 1
entry = data["throttled_over_time"][0]
assert "timestamp" in entry
assert "count" in entry
assert entry["count"] >= 1
finally:
app.dependency_overrides.clear()
def test_response_shape_matches_contract(self, client):
"""The full response should match the expected shape for the frontend."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
# Top-level keys
assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"}
# Each rate_limit entry
for rl in data["rate_limits"]:
assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"}
finally:
app.dependency_overrides.clear()