Compare commits

..

8 Commits

Author SHA1 Message Date
agent-company 8f40109272 Add POST /export/batch endpoint for multi-company ZIP download
Implements issue #1674: a new authenticated POST /export/batch endpoint
that accepts a list of company names and an optional format (csv or pdf),
compiles per-company exports into a ZIP archive using Python's zipfile
module, and returns it as a streaming download.

Key changes:
- Extract `_fetch_company_rows`, `_build_company_csv`, `_build_company_pdf`
  helpers to eliminate duplication between the single-company endpoints and
  the new batch endpoint
- Refactor `export_company_csv` and `export_company_pdf` to delegate to the
  new helpers
- Add `BatchExportRequest` Pydantic model (companies list + format field)
- Add `POST /export/batch` which iterates over companies, skips those with
  no data, writes per-company files into the ZIP, and always includes a
  `manifest.json` listing exported and skipped companies
- Response header: `Content-Disposition: attachment; filename=sparc-export-<date>.zip`
- 17 new tests covering: single company (CSV + PDF), multiple companies,
  all-missing, unauthenticated, invalid-token, manifest structure, input
  validation

Closes leeworks-agents/SPARC#1674

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:21:09 +00:00
AI-Manager 313800215c Merge pull request 'Add rate limit stats to admin panel' (#1682) from feature/1675-rate-limit-admin into main
Merge PR #1682
2026-05-19 00:12:56 +00:00
AI-Manager 222f29deb1 Merge pull request 'Add cursor-based pagination to /analyze/batch and /jobs' (#1681) from feature/1669-cursor-pagination into main
Merge PR #1681
2026-05-19 00:12:48 +00:00
AI-Manager e6d95bbf57 Merge pull request 'Add stricter input validation for company names' (#1680) from feature/1670-company-name-validation into main
Merge PR #1680
2026-05-19 00:12:42 +00:00
AI-Manager 68484ef4b1 Merge pull request 'Update ROADMAP.md: mark completed P1 and P2 items as done' (#1679) from feature/1678-update-roadmap into main
Merge PR #1679
2026-05-19 00:12:34 +00:00
agent-company a0cb9a5773 Add rate limit status and usage statistics to admin panel
Add GET /admin/rate-limits endpoint (admin-only) that returns current
rate limit configuration and request statistics for all rate-limited
endpoints (/auth/register and /auth/login). Tracks total requests and
rejection counts via in-memory counters.

Includes tests for admin access, non-admin rejection, empty state,
request tracking, and configuration display.

Closes leeworks-agents/SPARC#1675

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:53:01 +00:00
agent-company a95129904e Add stricter input validation for company names on analysis endpoints
Add a CompanyName validated type enforcing 2-100 character length and
allowing only alphanumeric characters, spaces, hyphens, ampersands, and
periods. Applied to all endpoints accepting company names: /analyze,
/analyze/patent, /analyze/batch, /admin/tracked, and /export.

Includes unit tests covering too-short, too-long, special character,
leading-character, and valid edge cases for both single and batch
endpoints.

Closes leeworks-agents/SPARC#1670

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:38:44 +00:00
agent-company 7c6eed8d72 Update ROADMAP.md to mark completed P1 and P2 items as done
Move seven completed items from the P1 and P2 sections into the
Completed section: in-memory jobs persistence, export endpoint tests,
tracked company admin tests, webhook integration tests, S3 storage
tests, auto-download path tests, and scheduler DatabaseClient refactor.

The P2 section now only lists the two genuinely open items: cursor-based
pagination (Issue #1669) and request validation (Issue #1670).

Closes leeworks-agents/SPARC#1678

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:29:14 +00:00
5 changed files with 893 additions and 113 deletions
+30 -37
View File
@@ -81,57 +81,50 @@ Items that have been implemented and merged into main.
- ~~OpenAPI client generation.~~ TypeScript API client auto-generated from - ~~OpenAPI client generation.~~ TypeScript API client auto-generated from
FastAPI spec with CI freshness check. FastAPI spec with CI freshness check.
### Resilience
- ~~`_jobs` dict is in-memory only.~~ Database-backed job persistence
implemented using `db.list_jobs()` and `mark_stale_jobs_failed()`. The
in-memory `_jobs` dict has been removed.
### Test coverage (P1/P2)
- ~~Export endpoint tests.~~ Tests added for CSV and PDF export endpoints.
- ~~Tracked company admin endpoint tests.~~ Tests added for `/admin/tracked`
CRUD endpoints and scheduler integration.
- ~~Webhook integration tests.~~ Tests added for retry logic, Slack/Discord
payload format, and multi-URL dispatch.
- ~~S3/MinIO storage backend tests.~~ Unit tests added for the S3 backend
(read, write, exists, delete, error handling).
- ~~`analyze_single_patent` auto-download path tests.~~ Tests added for the
auto-download fallback (cache lookup, PDF download, FileNotFoundError).
### Code quality
- ~~Scheduler creates its own DatabaseClient.~~ Refactored to use the
application-level pooled `get_db_client()`.
--- ---
## P1 -- High Priority ## P1 -- High Priority
These items address correctness, reliability, and coverage gaps that should be No outstanding P1 items. All previously listed items have been completed and
resolved before broader production use. moved to the Completed section above.
### Resilience
- **`_jobs` dict is in-memory only.** Job state is lost on API restart.
Persist job status in PostgreSQL or Redis so async batch results survive
restarts.
### Test coverage gaps
- **Export endpoint tests.** The CSV and PDF export endpoints (`/export/`)
lack test coverage. Add tests covering auth, success, 404, and edge cases.
*(Issue #1655)*
- **Tracked company admin endpoint tests.** The `/admin/tracked` CRUD
endpoints and scheduler integration lack test coverage. *(Issue #1656)*
--- ---
## P2 -- Medium Priority ## P2 -- Medium Priority
Improvements to reliability, test coverage, and code quality. Improvements to the API surface.
### Test coverage
- **Webhook integration tests.** The retry logic, Slack/Discord payload
format, and multi-URL dispatch in `webhooks.py` need test coverage.
*(Issue #1657)*
- **S3/MinIO storage backend tests.** `storage.py` has local filesystem tests
but no unit tests for the S3 backend (read, write, exists, delete,
error handling). *(Issue #1660)*
- **`analyze_single_patent` auto-download path tests.** The auto-download
fallback (cache lookup, PDF download, FileNotFoundError) in
`analyzer.py` lacks test coverage. *(Issue #1661)*
### Code quality
- **Scheduler creates its own DatabaseClient.** `scheduler.py` bypasses the
application-level pooled client, creating a new connection on every tick.
Refactor to use `get_db_client()`. *(Issue #1658)*
### API improvements ### API improvements
- **API pagination.** The `/analyze/batch` and `/jobs` endpoints could benefit - **API pagination.** The `/analyze/batch` endpoint needs cursor-based
from cursor-based pagination for large result sets. pagination for large result sets. The `/jobs` endpoint already has cursor
pagination. *(Issue #1669)*
- **Request validation improvements.** Add stricter input validation for - **Request validation improvements.** Add stricter input validation for
company names (disallow special characters, enforce length limits). company names (disallow special characters, enforce length limits).
*(Issue #1670)*
--- ---
+224 -76
View File
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING: if TYPE_CHECKING:
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field, StringConstraints
from slowapi import Limiter from slowapi import Limiter
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
@@ -36,6 +36,16 @@ from SPARC.auth import (
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
CompanyName = Annotated[
str,
StringConstraints(
min_length=2,
max_length=100,
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
),
]
# Pydantic models for API # Pydantic models for API
class CompanyAnalysisResponse(BaseModel): class CompanyAnalysisResponse(BaseModel):
@@ -72,7 +82,7 @@ class CompanyAnalysisRequest(BaseModel):
class BatchAnalysisRequest(BaseModel): class BatchAnalysisRequest(BaseModel):
"""Request model for batch company analysis.""" """Request model for batch company analysis."""
companies: list[str] = Field( companies: list[CompanyName] = Field(
..., min_length=1, max_length=20, description="List of company names to analyze" ..., min_length=1, max_length=20, description="List of company names to analyze"
) )
max_workers: int = Field( max_workers: int = Field(
@@ -235,10 +245,37 @@ app = FastAPI(
limiter = Limiter(key_func=get_remote_address) limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter app.state.limiter = limiter
# In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {}
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint."""
key = endpoint
if key not in _rate_limit_stats:
_rate_limit_stats[key] = {
"endpoint": endpoint,
"total_requests": 0,
"rejected_requests": 0,
"by_ip": {},
}
_rate_limit_stats[key]["total_requests"] += 1
if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0}
ip_stats[ip]["total"] += 1
if rejected:
ip_stats[ip]["rejected"] += 1
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded): async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded.""" """Return 429 with Retry-After header when rate limit is exceeded."""
endpoint = request.url.path
ip = get_remote_address(request)
_track_rate_limit_request(endpoint, ip, rejected=True)
retry_after = getattr(exc, "retry_after", 60) retry_after = getattr(exc, "retry_after", 60)
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
@@ -267,6 +304,7 @@ async def register(request: Request, body: RegisterRequest):
The first registered user automatically becomes an admin. The first registered user automatically becomes an admin.
""" """
_track_rate_limit_request("/auth/register", get_remote_address(request))
db = get_db_client() db = get_db_client()
# First user becomes admin # First user becomes admin
@@ -297,6 +335,7 @@ async def register(request: Request, body: RegisterRequest):
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest): async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens.""" """Authenticate user and return JWT tokens."""
_track_rate_limit_request("/auth/login", get_remote_address(request))
db = get_db_client() db = get_db_client()
user = db.authenticate_user(body.email, body.password) user = db.authenticate_user(body.email, body.password)
@@ -423,7 +462,7 @@ async def delete_user(
class TrackCompanyRequest(BaseModel): class TrackCompanyRequest(BaseModel):
"""Request to add a company to tracking.""" """Request to add a company to tracking."""
company_name: str = Field(..., min_length=1, max_length=255) company_name: CompanyName = Field(...)
@app.get("/admin/tracked", tags=["Admin"]) @app.get("/admin/tracked", tags=["Admin"])
@@ -450,7 +489,7 @@ async def add_tracked_company(
@app.delete("/admin/tracked/{company_name}", tags=["Admin"]) @app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company( async def remove_tracked_company(
company_name: str, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""Remove a company from the tracked list (admin only).""" """Remove a company from the tracked list (admin only)."""
@@ -461,6 +500,36 @@ async def remove_tracked_company(
return {"message": f"Stopped tracking {company_name}"} return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/rate-limits", tags=["Admin"])
async def get_rate_limit_stats(
_: UserResponse = Depends(get_current_admin),
):
"""Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics
for all rate-limited endpoints.
Returns:
List of rate limit stats per endpoint with total/rejected counts
"""
rate_limits_config = {
"/auth/register": {"limit": "5/minute"},
"/auth/login": {"limit": "10/minute"},
}
results = []
for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {})
results.append({
"endpoint": endpoint,
"limit": conf["limit"],
"total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0),
})
return {"rate_limits": results}
@app.get("/admin/alerts", tags=["Admin"]) @app.get("/admin/alerts", tags=["Admin"])
async def list_alerts( async def list_alerts(
limit: int = Query(default=50, ge=1, le=200), limit: int = Query(default=50, ge=1, le=200),
@@ -606,27 +675,25 @@ async def get_analytics_trends(
# ============== Export Endpoints ============== # ============== Export Endpoints ==============
@app.get("/export/{company_name}", tags=["Export"]) class BatchExportRequest(BaseModel):
async def export_company_csv( """Request model for batch ZIP export of analysis results."""
company_name: str,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
Returns all stored analysis records for the given company, including companies: list[CompanyName] = Field(
analysis type, model used, response text, and timestamp. ..., min_length=1, max_length=50, description="List of company names to export"
)
format: str = Field(
default="csv",
pattern="^(csv|pdf)$",
description="Export format: 'csv' or 'pdf'",
)
Args:
company_name: Company name to export results for
Returns: def _fetch_company_rows(db, company_name: str) -> list:
CSV file download """Fetch all non-cached analysis rows for *company_name* from the DB.
Returns a list of tuples: (company_name, analysis_type, model, response, timestamp).
Returns an empty list when no results exist.
""" """
import csv
import io
db = get_db_client()
# Query all non-cached analysis results for this company
with db.get_conn() as conn: with db.get_conn() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute( cur.execute(
@@ -638,43 +705,24 @@ async def export_company_csv(
""", """,
(company_name,), (company_name,),
) )
rows = cur.fetchall() return cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'") def _build_company_csv(rows) -> bytes:
"""Render *rows* as CSV bytes."""
import csv
import io
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"]) writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
for row in rows: for row in rows:
writer.writerow(row) writer.writerow(row)
return output.getvalue().encode("utf-8")
output.seek(0)
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
@app.get("/export/{company_name}/pdf", tags=["Export"]) def _build_company_pdf(rows, company_name: str) -> bytes:
async def export_company_pdf( """Render *rows* as PDF bytes using reportlab."""
company_name: str,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
import io import io
from reportlab.lib import colors from reportlab.lib import colors
@@ -689,23 +737,6 @@ async def export_company_pdf(
TableStyle, TableStyle,
) )
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
""",
(company_name,),
)
rows = cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
buffer = io.BytesIO() buffer = io.BytesIO()
doc = SimpleDocTemplate( doc = SimpleDocTemplate(
buffer, buffer,
@@ -748,13 +779,11 @@ async def export_company_pdf(
elements = [] elements = []
# Title and date display_name = rows[0][0]
display_name = rows[0][0] # Use the casing from the database
analysis_date = datetime.now().strftime("%Y-%m-%d") analysis_date = datetime.now().strftime("%Y-%m-%d")
elements.append(Paragraph(f"SPARC Analysis Report: {display_name}", title_style)) elements.append(Paragraph(f"SPARC Analysis Report: {display_name}", title_style))
elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style)) elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))
# Summary table
summary_data = [ summary_data = [
["Total Analyses", str(len(rows))], ["Total Analyses", str(len(rows))],
["Analysis Types", ", ".join(sorted(set(r[1] for r in rows)))], ["Analysis Types", ", ".join(sorted(set(r[1] for r in rows)))],
@@ -776,7 +805,6 @@ async def export_company_pdf(
elements.append(summary_table) elements.append(summary_table)
elements.append(Spacer(1, 16)) elements.append(Spacer(1, 16))
# Individual analysis sections
for i, row in enumerate(rows, 1): for i, row in enumerate(rows, 1):
_, analysis_type, model, response, timestamp = row _, analysis_type, model, response, timestamp = row
ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp) ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)
@@ -788,13 +816,11 @@ async def export_company_pdf(
Paragraph(f"<i>Performed: {ts_str}</i>", body_style) Paragraph(f"<i>Performed: {ts_str}</i>", body_style)
) )
# Wrap long response text into paragraphs, escaping XML special chars
safe_response = ( safe_response = (
response.replace("&", "&amp;") response.replace("&", "&amp;")
.replace("<", "&lt;") .replace("<", "&lt;")
.replace(">", "&gt;") .replace(">", "&gt;")
) )
# Split into manageable paragraphs to avoid overflow
for line in safe_response.split("\n"): for line in safe_response.split("\n"):
if line.strip(): if line.strip():
elements.append(Paragraph(line, body_style)) elements.append(Paragraph(line, body_style))
@@ -805,11 +831,133 @@ async def export_company_pdf(
doc.build(elements) doc.build(elements)
buffer.seek(0) buffer.seek(0)
return buffer.getvalue()
@app.post("/export/batch", tags=["Export"])
async def export_batch_zip(
request: BatchExportRequest,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for multiple companies as a ZIP archive.
For each company in the request, fetches all stored analysis records and
adds a per-company file (CSV or PDF) to the archive. Companies with no
stored results are skipped; a ``manifest.json`` inside the ZIP lists both
the exported and skipped companies.
Args:
request: List of company names and desired export format ('csv' or 'pdf')
Returns:
ZIP archive download containing one file per found company plus a manifest
"""
import io
import json
import zipfile
db = get_db_client()
export_date = datetime.now().strftime("%Y-%m-%d")
fmt = request.format
exported: list[str] = []
skipped: list[str] = []
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
for company_name in request.companies:
rows = _fetch_company_rows(db, company_name)
if not rows:
skipped.append(company_name)
continue
safe_name = company_name.replace(" ", "_").lower() safe_name = company_name.replace(" ", "_").lower()
if fmt == "pdf":
file_bytes = _build_company_pdf(rows, company_name)
filename = f"{safe_name}-analysis-{export_date}.pdf"
else:
file_bytes = _build_company_csv(rows)
filename = f"sparc_{safe_name}_export.csv"
zf.writestr(filename, file_bytes)
exported.append(company_name)
# Always include a manifest
manifest = {
"export_date": export_date,
"format": fmt,
"exported": exported,
"skipped": skipped,
}
zf.writestr("manifest.json", json.dumps(manifest, indent=2))
zip_buffer.seek(0)
zip_filename = f"sparc-export-{export_date}.zip"
return StreamingResponse(
iter([zip_buffer.getvalue()]),
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'},
)
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
Returns:
CSV file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([_build_company_csv(rows)]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
analysis_date = datetime.now().strftime("%Y-%m-%d")
filename = f"{safe_name}-analysis-{analysis_date}.pdf" filename = f"{safe_name}-analysis-{analysis_date}.pdf"
return StreamingResponse( return StreamingResponse(
iter([buffer.getvalue()]), iter([_build_company_pdf(rows, company_name)]),
media_type="application/pdf", media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="{filename}"'}, headers={"Content-Disposition": f'attachment; filename="{filename}"'},
) )
@@ -834,7 +982,7 @@ async def health_check():
tags=["Analysis"], tags=["Analysis"],
) )
async def analyze_company( async def analyze_company(
company_name: str, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
@@ -864,7 +1012,7 @@ async def analyze_company(
) )
async def analyze_single_patent( async def analyze_single_patent(
patent_id: str, patent_id: str,
company_name: str = Query(description="Company name for analysis context"), company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Analyze a single patent by its publication ID. """Analyze a single patent by its publication ID.
+373
View File
@@ -0,0 +1,373 @@
"""Tests for POST /export/batch endpoint (issue #1674).
Covers:
- Single company export (CSV + PDF)
- Multiple company export
- All-missing companies (every requested company is skipped)
- Unauthenticated / invalid-token requests
- Manifest content validation
- Invalid format rejection
"""
import io
import json
import zipfile
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create a FastAPI test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock database client for all tests in this module."""
db = MagicMock()
# Auth: user always exists
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
# Default cursor mock (overridden per-test via side_effect or return_value)
mock_cursor = MagicMock()
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn)
db.get_conn.return_value.__exit__ = MagicMock(return_value=False)
db._mock_cursor = mock_cursor
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _rows_for(company_name: str):
"""Return a single sample row for the given company."""
return [
(
company_name,
"company_analysis",
"anthropic/claude-3.5-sonnet",
f"Strong patent portfolio for {company_name}.",
datetime(2025, 6, 15, 10, 30, 0),
)
]
def _open_zip(content: bytes) -> zipfile.ZipFile:
"""Helper: wrap response bytes as a ZipFile."""
return zipfile.ZipFile(io.BytesIO(content))
# ---------------------------------------------------------------------------
# Authentication
# ---------------------------------------------------------------------------
class TestBatchExportAuth:
"""Unauthenticated and invalid-token requests must be rejected."""
def test_unauthenticated_returns_401(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
)
assert response.status_code == 401
def test_invalid_token_returns_401(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers={"Authorization": "Bearer totally.invalid.token"},
)
assert response.status_code == 401
# ---------------------------------------------------------------------------
# Single company
# ---------------------------------------------------------------------------
class TestBatchExportSingleCompany:
"""POST /export/batch with a single company name."""
def test_single_company_csv_returns_zip(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/zip"
assert "attachment" in response.headers["content-disposition"]
assert "sparc-export-" in response.headers["content-disposition"]
assert response.headers["content-disposition"].endswith('.zip"')
def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
names = zf.namelist()
csv_files = [n for n in names if n.endswith(".csv")]
assert len(csv_files) == 1
assert "nvidia" in csv_files[0]
def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
csv_name = [n for n in zf.namelist() if n.endswith(".csv")][0]
csv_text = zf.read(csv_name).decode("utf-8")
lines = csv_text.strip().split("\n")
assert lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
assert "NVIDIA" in lines[1]
def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "pdf"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
pdf_files = [n for n in zf.namelist() if n.endswith(".pdf")]
assert len(pdf_files) == 1
# Verify it is actually a PDF (starts with %PDF)
pdf_bytes = zf.read(pdf_files[0])
assert pdf_bytes[:4] == b"%PDF"
# ---------------------------------------------------------------------------
# Multiple companies
# ---------------------------------------------------------------------------
class TestBatchExportMultipleCompanies:
"""POST /export/batch with several companies."""
def test_multiple_companies_each_gets_a_file(self, client, mock_db):
companies = ["NVIDIA", "Intel", "AMD"]
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
_rows_for("Intel"),
_rows_for("AMD"),
]
response = client.post(
"/export/batch",
json={"companies": companies, "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
assert len(csv_files) == 3
def test_multiple_companies_manifest_lists_all_exported(self, client, mock_db):
companies = ["NVIDIA", "Intel"]
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
_rows_for("Intel"),
]
response = client.post(
"/export/batch",
json={"companies": companies, "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
assert manifest["skipped"] == []
assert manifest["format"] == "csv"
def test_partial_missing_companies_skipped(self, client, mock_db):
"""Companies with no data are skipped; others are exported."""
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
[], # no data for "UnknownCo"
]
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA", "UnknownCo"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["exported"] == ["NVIDIA"]
assert manifest["skipped"] == ["UnknownCo"]
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
assert len(csv_files) == 1
# ---------------------------------------------------------------------------
# All-missing companies
# ---------------------------------------------------------------------------
class TestBatchExportAllMissing:
"""When every requested company has no data, the ZIP still returns 200
with only a manifest (no per-company files, all listed in skipped)."""
def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = []
response = client.post(
"/export/batch",
json={"companies": ["GhostCo", "PhantomInc"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
assert "manifest.json" in zf.namelist()
manifest = json.loads(zf.read("manifest.json"))
assert manifest["exported"] == []
assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}
def test_all_missing_zip_has_no_data_files(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = []
response = client.post(
"/export/batch",
json={"companies": ["GhostCo"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
data_files = [n for n in zf.namelist() if n != "manifest.json"]
assert data_files == []
# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------
class TestBatchExportManifest:
"""The manifest.json inside every ZIP must be well-formed."""
def test_manifest_always_present(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
assert "manifest.json" in zf.namelist()
def test_manifest_contains_required_keys(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert "export_date" in manifest
assert "format" in manifest
assert "exported" in manifest
assert "skipped" in manifest
def test_manifest_format_field_matches_request(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "pdf"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["format"] == "pdf"
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
class TestBatchExportInputValidation:
"""Invalid request bodies must return 422."""
def test_invalid_format_returns_422(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "xlsx"},
headers=_auth_header(),
)
assert response.status_code == 422
def test_empty_companies_list_returns_422(self, client):
response = client.post(
"/export/batch",
json={"companies": [], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 422
def test_default_format_is_csv(self, client, mock_db):
"""Omitting `format` should default to CSV."""
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"]},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["format"] == "csv"
+157
View File
@@ -0,0 +1,157 @@
"""Tests for company name input validation on analysis endpoints."""
from datetime import datetime
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer so valid requests succeed."""
mock = Mock()
mock._analyze_company_safe.return_value = CompanyAnalysisResult(
company_name="nvidia",
analysis="Test analysis",
patent_count=1,
success=True,
timestamp=datetime.now(),
)
mocker.patch("SPARC.api._analyzer", mock)
return mock
class TestCompanyNameValidation:
"""Test that company names are validated on analysis endpoints."""
# --- Too short ---
def test_single_char_rejected(self, client, mock_analyzer):
"""A one-character company name should be rejected."""
response = client.get("/analyze/X")
assert response.status_code == 422
# --- Too long ---
def test_over_100_chars_rejected(self, client, mock_analyzer):
"""A company name longer than 100 characters should be rejected."""
long_name = "A" * 101
response = client.get(f"/analyze/{long_name}")
assert response.status_code == 422
# --- Special characters ---
@pytest.mark.parametrize(
"bad_name",
[
"nvidia!",
"intel@corp",
"test#company",
"foo$bar",
"a%b",
"x^y",
"semi;colon",
"drop'table",
'say"hello',
"path/traversal",
"back\\slash",
"pipe|char",
"star*glob",
"question?mark",
"<script>",
"curly{brace}",
"equal=sign",
"plus+plus",
"comma,separated",
],
)
def test_special_chars_rejected(self, client, mock_analyzer, bad_name):
"""Company names with disallowed special characters should be rejected."""
response = client.get(f"/analyze/{bad_name}")
assert response.status_code == 422
# --- Valid names ---
@pytest.mark.parametrize(
"valid_name",
[
"nvidia",
"Intel",
"TSMC",
"Texas Instruments",
"Johnson-Johnson",
"AT&T",
"St. Jude Medical",
"3M",
"21st Century Fox",
"ab", # minimum length
"A" * 100, # maximum length
],
)
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
"""Valid company names should be accepted (200, not 422)."""
response = client.get(f"/analyze/{valid_name}")
# Should not be a validation error; 200 or other non-422 status is fine
assert response.status_code != 422
# --- Batch endpoint validation ---
def test_batch_too_short_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too short."""
response = client.post(
"/analyze/batch",
json={"companies": ["X"]},
)
assert response.status_code == 422
def test_batch_too_long_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too long."""
response = client.post(
"/analyze/batch",
json={"companies": ["A" * 101]},
)
assert response.status_code == 422
def test_batch_special_chars_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names with special chars."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia!", "intel"]},
)
assert response.status_code == 422
def test_batch_valid_names_accepted(self, client, mock_analyzer):
"""Batch endpoint should accept valid company names."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "Intel", "AT&T"]},
)
assert response.status_code != 422
# --- Name must start with alphanumeric ---
def test_leading_space_rejected(self, client, mock_analyzer):
"""Company name starting with a space should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": [" nvidia"]},
)
assert response.status_code == 422
def test_leading_hyphen_rejected(self, client, mock_analyzer):
"""Company name starting with a hyphen should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": ["-nvidia"]},
)
assert response.status_code == 422
+109
View File
@@ -0,0 +1,109 @@
"""Tests for the /admin/rate-limits endpoint."""
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from SPARC import api
from SPARC.api import app
from SPARC.auth import UserResponse
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_stats():
"""Reset rate limit stats between tests."""
api._rate_limit_stats.clear()
yield
api._rate_limit_stats.clear()
def _mock_admin():
"""Return a mock admin user."""
return UserResponse(id=1, email="admin@test.com", role="admin", created_at="2025-01-01T00:00:00")
def _mock_user():
"""Return a mock non-admin user."""
return UserResponse(id=2, email="user@test.com", role="user", created_at="2025-01-01T00:00:00")
class TestRateLimitAdminEndpoint:
"""Test GET /admin/rate-limits."""
def test_admin_can_access(self, client):
"""Admin users should be able to access the rate-limits endpoint."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
assert response.status_code == 200
data = response.json()
assert "rate_limits" in data
assert isinstance(data["rate_limits"], list)
finally:
app.dependency_overrides.clear()
def test_non_admin_rejected(self, client):
"""Non-admin users should get 403."""
# Without overriding the dependency, it should fail auth
response = client.get("/admin/rate-limits")
assert response.status_code in (401, 403)
def test_returns_configured_endpoints(self, client):
"""Should list all rate-limited endpoints."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
assert response.status_code == 200
data = response.json()
endpoints = [rl["endpoint"] for rl in data["rate_limits"]]
assert "/auth/register" in endpoints
assert "/auth/login" in endpoints
finally:
app.dependency_overrides.clear()
def test_empty_state_shows_zero_counts(self, client):
"""When no requests have been made, counts should be zero."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
for rl in data["rate_limits"]:
assert rl["total_requests"] == 0
assert rl["rejected_requests"] == 0
finally:
app.dependency_overrides.clear()
def test_tracks_requests(self, client):
"""After making requests, the stats should reflect them."""
api._track_rate_limit_request("/auth/login", "127.0.0.1")
api._track_rate_limit_request("/auth/login", "127.0.0.1")
api._track_rate_limit_request("/auth/login", "192.168.1.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
assert login_stats["total_requests"] == 3
assert login_stats["rejected_requests"] == 1
finally:
app.dependency_overrides.clear()
def test_includes_limit_config(self, client):
"""Each endpoint entry should include the rate limit config string."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
for rl in data["rate_limits"]:
assert "limit" in rl
assert isinstance(rl["limit"], str)
finally:
app.dependency_overrides.clear()