Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company 0e68e8c900 Add cursor-based pagination to /analyze/batch and /jobs endpoints
- Fix route ordering bug: GET /analyze/batch was shadowed by
  GET /analyze/{company_name} causing all GET requests to /analyze/batch
  to be erroneously handled as single-company analysis (503). Move
  /analyze/batch GET registration to before the {company_name} route.
- Update TypeScript schema.d.ts: add AnalysisRecord, PaginatedAnalysisResponse,
  PaginatedJobsResponse schemas; add GET /analyze/batch operation with
  cursor+limit+company_name params; update list_jobs_jobs_get to include
  cursor param and return PaginatedJobsResponse.
- Update frontend/src/api/client.ts: add listBatchAnalyses() method with
  cursor/limit support; update listJobs() to accept cursor and return
  PaginatedJobsResponse; default limit changed from 10 to 50.
- Update frontend/src/types/index.ts: export AnalysisRecord,
  PaginatedAnalysisResponse, PaginatedJobsResponse.
- Expand tests/test_pagination.py: add auth fixture so tests pass JWT
  validation; add 11 new /jobs tests covering first page, last page,
  subsequent pages, empty results, status filter, limit boundaries, cursor
  forwarding, and paginated response shape.

Closes leeworks-agents/SPARC#1684

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:34:18 +00:00
6 changed files with 472 additions and 600 deletions
+121 -200
View File
@@ -675,25 +675,27 @@ async def get_analytics_trends(
# ============== Export Endpoints ==============
class BatchExportRequest(BaseModel):
"""Request model for batch ZIP export of analysis results."""
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
companies: list[CompanyName] = Field(
..., min_length=1, max_length=50, description="List of company names to export"
)
format: str = Field(
default="csv",
pattern="^(csv|pdf)$",
description="Export format: 'csv' or 'pdf'",
)
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
def _fetch_company_rows(db, company_name: str) -> list:
"""Fetch all non-cached analysis rows for *company_name* from the DB.
Returns a list of tuples: (company_name, analysis_type, model, response, timestamp).
Returns an empty list when no results exist.
Returns:
CSV file download
"""
import csv
import io
db = get_db_client()
# Query all non-cached analysis results for this company
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
@@ -705,24 +707,43 @@ def _fetch_company_rows(db, company_name: str) -> list:
""",
(company_name,),
)
return cur.fetchall()
rows = cur.fetchall()
def _build_company_csv(rows) -> bytes:
"""Render *rows* as CSV bytes."""
import csv
import io
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
for row in rows:
writer.writerow(row)
return output.getvalue().encode("utf-8")
output.seek(0)
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
def _build_company_pdf(rows, company_name: str) -> bytes:
"""Render *rows* as PDF bytes using reportlab."""
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
import io
from reportlab.lib import colors
@@ -737,6 +758,23 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
TableStyle,
)
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
""",
(company_name,),
)
rows = cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
@@ -779,11 +817,13 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
elements = []
display_name = rows[0][0]
# Title and date
display_name = rows[0][0] # Use the casing from the database
analysis_date = datetime.now().strftime("%Y-%m-%d")
elements.append(Paragraph(f"SPARC Analysis Report: {display_name}", title_style))
elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))
# Summary table
summary_data = [
["Total Analyses", str(len(rows))],
["Analysis Types", ", ".join(sorted(set(r[1] for r in rows)))],
@@ -805,6 +845,7 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
elements.append(summary_table)
elements.append(Spacer(1, 16))
# Individual analysis sections
for i, row in enumerate(rows, 1):
_, analysis_type, model, response, timestamp = row
ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)
@@ -816,11 +857,13 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
Paragraph(f"<i>Performed: {ts_str}</i>", body_style)
)
# Wrap long response text into paragraphs, escaping XML special chars
safe_response = (
response.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
# Split into manageable paragraphs to avoid overflow
for line in safe_response.split("\n"):
if line.strip():
elements.append(Paragraph(line, body_style))
@@ -831,133 +874,11 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
doc.build(elements)
buffer.seek(0)
return buffer.getvalue()
@app.post("/export/batch", tags=["Export"])
async def export_batch_zip(
    request: BatchExportRequest,
    _: UserResponse = Depends(get_current_user),
):
    """Export analysis results for multiple companies as a ZIP archive.

    For each company in the request, fetches all stored analysis records and
    adds a per-company file (CSV or PDF) to the archive. Companies with no
    stored results are skipped; a ``manifest.json`` inside the ZIP lists both
    the exported and skipped companies.

    If two requested names normalise to the same file name (for example
    ``"NVIDIA"`` and ``"nvidia"``), later files receive a numeric suffix
    (``_2``, ``_3``, ...) so that no archive member shadows another.

    Args:
        request: List of company names and desired export format ('csv' or 'pdf')

    Returns:
        ZIP archive download containing one file per found company plus a manifest
    """
    import io
    import json
    import zipfile

    db = get_db_client()
    export_date = datetime.now().strftime("%Y-%m-%d")
    fmt = request.format

    exported: list[str] = []
    skipped: list[str] = []
    # Member names already written; ZipFile itself does not reject duplicates,
    # it only warns and stores both entries under the same name.
    used_filenames: set[str] = set()

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
        for company_name in request.companies:
            rows = _fetch_company_rows(db, company_name)
            if not rows:
                skipped.append(company_name)
                continue

            safe_name = company_name.replace(" ", "_").lower()
            if fmt == "pdf":
                file_bytes = _build_company_pdf(rows, company_name)
                stem, extension = f"{safe_name}-analysis-{export_date}", ".pdf"
            else:
                file_bytes = _build_company_csv(rows)
                stem, extension = f"sparc_{safe_name}_export", ".csv"

            # Disambiguate names that collide after normalisation.
            filename = f"{stem}{extension}"
            duplicate_index = 2
            while filename in used_filenames:
                filename = f"{stem}_{duplicate_index}{extension}"
                duplicate_index += 1
            used_filenames.add(filename)

            zf.writestr(filename, file_bytes)
            exported.append(company_name)

        # Always include a manifest, even when every company was skipped.
        manifest = {
            "export_date": export_date,
            "format": fmt,
            "exported": exported,
            "skipped": skipped,
        }
        zf.writestr("manifest.json", json.dumps(manifest, indent=2))

    zip_buffer.seek(0)
    zip_filename = f"sparc-export-{export_date}.zip"
    return StreamingResponse(
        iter([zip_buffer.getvalue()]),
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'},
    )
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_user),
):
    """Export analysis results for a company as a CSV file.

    Returns all stored analysis records for the given company, including
    analysis type, model used, response text, and timestamp.

    Args:
        company_name: Company name to export results for

    Returns:
        CSV file download
    """
    records = _fetch_company_rows(get_db_client(), company_name)
    if not records:
        raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")

    # File-name slug: spaces become underscores, everything lower-cased.
    slug = company_name.replace(" ", "_").lower()
    disposition = f'attachment; filename="sparc_{slug}_export.csv"'
    csv_payload = _build_company_csv(records)
    return StreamingResponse(
        iter([csv_payload]),
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
analysis_date = datetime.now().strftime("%Y-%m-%d")
filename = f"{safe_name}-analysis-{analysis_date}.pdf"
return StreamingResponse(
iter([_build_company_pdf(rows, company_name)]),
iter([buffer.getvalue()]),
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@@ -976,6 +897,58 @@ async def health_check():
)
@app.get(
    "/analyze/batch",
    response_model=PaginatedAnalysisResponse,
    tags=["Analysis"],
)
async def list_analysis_results(
    company_name: Annotated[
        str | None,
        Query(description="Filter results by company name"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 50,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """List stored analysis results with cursor-based pagination.

    Returns past analysis results ordered by timestamp descending. Use
    ``limit`` to control page size (default 50, max 200). The response
    includes a ``next_cursor`` field; pass it back as the ``cursor`` query
    parameter to fetch the next page. When ``next_cursor`` is ``null``,
    there are no more results.

    Args:
        company_name: Optional filter by company name
        limit: Maximum number of results to return (default 50, max 200)
        cursor: Opaque pagination cursor from a previous response

    Returns:
        Paginated list of analysis results
    """
    # Ask for one row more than requested: a surplus row means another page exists.
    fetched = _get_job_db().list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
    more_available = len(fetched) > limit
    page = fetched[:limit] if more_available else fetched

    records = [AnalysisRecord(**entry) for entry in page]
    if not (more_available and page):
        return PaginatedAnalysisResponse(items=records, next_cursor=None)

    # The cursor encodes the last returned row as "<timestamp>|<id>".
    tail = page[-1]
    stamp = tail["timestamp"]
    stamp_text = stamp.isoformat() if hasattr(stamp, "isoformat") else str(stamp)
    return PaginatedAnalysisResponse(items=records, next_cursor=f"{stamp_text}|{tail['id']}")
@app.get(
"/analyze/{company_name}",
response_model=CompanyAnalysisResponse,
@@ -1038,58 +1011,6 @@ async def analyze_single_patent(
raise HTTPException(status_code=404, detail=str(e))
@app.get(
    "/analyze/batch",
    response_model=PaginatedAnalysisResponse,
    tags=["Analysis"],
)
async def list_analysis_results(
    company_name: Annotated[
        str | None,
        Query(description="Filter results by company name"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 50,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """List stored analysis results with cursor-based pagination.

    Returns past analysis results ordered by timestamp descending. Use
    ``limit`` to control page size (default 50, max 200). The response
    includes a ``next_cursor`` field; pass it back as the ``cursor`` query
    parameter to fetch the next page. When ``next_cursor`` is ``null``,
    there are no more results.

    Args:
        company_name: Optional filter by company name
        limit: Maximum number of results to return (default 50, max 200)
        cursor: Opaque pagination cursor from a previous response

    Returns:
        Paginated list of analysis results
    """
    job_db = _get_job_db()
    # One extra row is requested purely as a "has next page" probe.
    candidates = job_db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)

    overflow = len(candidates) > limit
    if overflow:
        candidates = candidates[:limit]

    payload = [AnalysisRecord(**candidate) for candidate in candidates]

    following_cursor = None
    if overflow and candidates:
        final_row = candidates[-1]
        final_ts = final_row["timestamp"]
        # Prefer ISO-8601 when the value supports it; otherwise fall back to str().
        rendered_ts = final_ts.isoformat() if hasattr(final_ts, "isoformat") else str(final_ts)
        following_cursor = f"{rendered_ts}|{final_row['id']}"

    return PaginatedAnalysisResponse(items=payload, next_cursor=following_cursor)
@app.post(
"/analyze/batch",
response_model=BatchAnalysisResponse,
+79 -3
View File
@@ -1,5 +1,5 @@
import axios, { AxiosError, InternalAxiosRequestConfig } from 'axios';
import type { TokenResponse, User, CompanyAnalysis, BatchAnalysisResult, JobStatus, Analytics } from '../types';
import type { TokenResponse, User, CompanyAnalysis, BatchAnalysisResult, JobStatus, Analytics, PaginatedJobsResponse, PaginatedAnalysisResponse } from '../types';
const API_BASE_URL = import.meta.env.VITE_API_URL || '/api';
@@ -141,15 +141,60 @@ export const analysisApi = {
return response.data;
},
listJobs: async (status?: string, limit = 10): Promise<JobStatus[]> => {
listJobs: async (status?: string, limit = 50, cursor?: string): Promise<PaginatedJobsResponse> => {
const params = new URLSearchParams();
if (status) params.append('status', status);
params.append('limit', limit.toString());
const response = await api.get<JobStatus[]>(`/jobs?${params}`);
if (cursor) params.append('cursor', cursor);
const response = await api.get<PaginatedJobsResponse>(`/jobs?${params}`);
return response.data;
},
/**
 * Fetch one page of stored analyses from `GET /analyze/batch`.
 *
 * `company_name` and `cursor` are only sent when truthy; `limit` is always sent.
 */
listBatchAnalyses: async (companyName?: string, limit = 50, cursor?: string): Promise<PaginatedAnalysisResponse> => {
  const query = new URLSearchParams();
  if (companyName) {
    query.append('company_name', companyName);
  }
  query.append('limit', limit.toString());
  if (cursor) {
    query.append('cursor', cursor);
  }
  const { data } = await api.get<PaginatedAnalysisResponse>(`/analyze/batch?${query}`);
  return data;
},
/**
 * Fetch up to `limit` history entries for a company from
 * `GET /analyze/{company}/history`. The company name is URI-encoded.
 */
getCompanyHistory: async (companyName: string, limit = 20): Promise<AnalysisHistoryItem[]> => {
  const encodedName = encodeURIComponent(companyName);
  const { data } = await api.get<AnalysisHistoryItem[]>(`/analyze/${encodedName}/history?limit=${limit}`);
  return data;
},
/**
 * Request a diff between two analyses of one company from
 * `GET /analyze/{company}/diff`, passing the two ids as `from` and `to`.
 */
diffAnalyses: async (companyName: string, fromId: number, toId: number): Promise<AnalysisDiff> => {
  const encodedName = encodeURIComponent(companyName);
  const { data } = await api.get<AnalysisDiff>(`/analyze/${encodedName}/diff?from=${fromId}&to=${toId}`);
  return data;
},
};
// Analysis diff types
/** One element of the array returned by `analysisApi.getCompanyHistory`. */
export interface AnalysisHistoryItem {
  // Numeric identifier; presumably the value passed as fromId/toId to diffAnalyses — confirm against the API.
  id: number;
  // Nullable: the server may return null for analysis_type and model.
  analysis_type: string | null;
  model: string | null;
  // Transmitted as a string; the exact format is not visible here.
  timestamp: string;
}
/** Response body returned by `analysisApi.diffAnalyses`. */
export interface AnalysisDiff {
  company_name: string;
  // Ids and timestamps of the two compared analyses.
  from_id: number;
  to_id: number;
  from_timestamp: string;
  to_timestamp: string;
  // Presumably the "to" count minus the "from" count — sign convention to be confirmed with the backend.
  patent_count_delta: number;
  added_patents: string[];
  removed_patents: string[];
  // Keyed by field name; each entry holds the nullable before/after string values.
  changed_fields: Record<string, { from: string | null; to: string | null }>;
  summary: string;
}
// Export API
export const exportApi = {
exportCsv: async (companyName: string): Promise<void> => {
@@ -201,6 +246,32 @@ export const analyticsApi = {
},
};
// Rate limit types
/** Per-IP counters; used as the element type of `RateLimitEndpointStats.by_ip`. */
export interface RateLimitIpEntry {
  ip: string;
  total: number;
  rejected: number;
}
/** Rate-limit counters for one endpoint; element type of `RateLimitStatsResponse.rate_limits`. */
export interface RateLimitEndpointStats {
  endpoint: string;
  // The limit is carried as a string; its format is not visible in this file.
  limit: string;
  total_requests: number;
  rejected_requests: number;
  // Breakdown of the counters by client IP.
  by_ip: RateLimitIpEntry[];
}
/** One timestamped count; element type of `RateLimitStatsResponse.throttled_over_time`. */
export interface ThrottledBucket {
  timestamp: string;
  count: number;
}
/** Response body returned by `adminApi.getRateLimits` (`GET /admin/rate-limits`). */
export interface RateLimitStatsResponse {
  rate_limits: RateLimitEndpointStats[];
  // Presumably the number of throttled requests in the last 24 hours — confirm with the backend.
  throttled_24h: number;
  throttled_over_time: ThrottledBucket[];
}
// Admin API
export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
@@ -216,6 +287,11 @@ export const adminApi = {
deleteUser: async (userId: number): Promise<void> => {
await api.delete(`/admin/users/${userId}`);
},
/** Fetch rate-limit statistics from `GET /admin/rate-limits`. */
getRateLimits: async (): Promise<RateLimitStatsResponse> => {
  const { data } = await api.get<RateLimitStatsResponse>('/admin/rate-limits');
  return data;
},
};
export default api;
+96 -5
View File
@@ -222,7 +222,17 @@ export interface paths {
path?: never;
cookie?: never;
};
get?: never;
/**
* List Batch Analyses
* @description List stored analysis results with cursor-based pagination.
*
* Returns past analysis results ordered by timestamp descending. Use
* ``limit`` to control page size (default 50, max 200). The response
* includes a ``next_cursor`` field; pass it back as the ``cursor`` query
* parameter to fetch the next page. When ``next_cursor`` is ``null``,
* there are no more results.
*/
get: operations["list_batch_analyses_analyze_batch_get"];
put?: never;
/**
* Analyze Companies Batch
@@ -308,14 +318,15 @@ export interface paths {
};
/**
* List Jobs
* @description List all analysis jobs.
* @description List analysis jobs with cursor-based pagination.
*
* Args:
* status: Optional filter by job status
* limit: Maximum number of jobs to return (default 10, max 100)
* limit: Maximum number of jobs to return (default 50, max 200)
* cursor: Opaque cursor from a previous response's next_cursor field
*
* Returns:
* List of job statuses
* Paginated list of job statuses with next_cursor for subsequent pages
*/
get: operations["list_jobs_jobs_get"];
put?: never;
@@ -330,6 +341,27 @@ export interface paths {
export type webhooks = Record<string, never>;
export interface components {
schemas: {
/**
* AnalysisRecord
* @description A single stored analysis result.
*/
AnalysisRecord: {
/** Id */
id: number;
/** Company Name */
company_name?: string | null;
/** Analysis Type */
analysis_type?: string | null;
/** Model */
model?: string | null;
/** Response */
response?: string | null;
/**
* Timestamp
* Format: date-time
*/
timestamp?: string | null;
};
/**
* AnalyticsResponse
* @description Analytics response model.
@@ -425,6 +457,26 @@ export interface components {
*/
timestamp: string;
};
/**
* PaginatedAnalysisResponse
* @description Paginated response for analysis result listings.
*/
PaginatedAnalysisResponse: {
/** Items */
items: components["schemas"]["AnalysisRecord"][];
/** Next Cursor */
next_cursor?: string | null;
};
/**
* PaginatedJobsResponse
* @description Paginated response for job listings.
*/
PaginatedJobsResponse: {
/** Items */
items: components["schemas"]["JobStatus"][];
/** Next Cursor */
next_cursor?: string | null;
};
/**
* JobStatus
* @description Status of a background analysis job.
@@ -944,7 +996,10 @@ export interface operations {
query?: {
/** @description Filter by status: pending, running, completed, failed */
status?: string | null;
/** @description Maximum number of jobs to return (default 50, max 200) */
limit?: number;
/** @description Opaque cursor from a previous response's next_cursor field */
cursor?: string | null;
};
header?: never;
path?: never;
@@ -958,7 +1013,43 @@ export interface operations {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JobStatus"][];
"application/json": components["schemas"]["PaginatedJobsResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
list_batch_analyses_analyze_batch_get: {
parameters: {
query?: {
/** @description Filter results by company name */
company_name?: string | null;
/** @description Maximum number of results to return (default 50, max 200) */
limit?: number;
/** @description Opaque cursor from a previous response's next_cursor field */
cursor?: string | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["PaginatedAnalysisResponse"];
};
};
/** @description Validation Error */
+5
View File
@@ -30,3 +30,8 @@ export type HealthResponse = components['schemas']['HealthResponse'];
export type BatchAnalysisRequest = components['schemas']['BatchAnalysisRequest'];
export type ValidationError = components['schemas']['ValidationError'];
export type HTTPValidationError = components['schemas']['HTTPValidationError'];
// Pagination types: aliases for the record and paginated-list schemas found in
// the generated `components['schemas']` map.
export type AnalysisRecord = components['schemas']['AnalysisRecord'];
export type PaginatedAnalysisResponse = components['schemas']['PaginatedAnalysisResponse'];
export type PaginatedJobsResponse = components['schemas']['PaginatedJobsResponse'];
-373
View File
@@ -1,373 +0,0 @@
"""Tests for POST /export/batch endpoint (issue #1674).
Covers:
- Single company export (CSV + PDF)
- Multiple company export
- All-missing companies (every requested company is skipped)
- Unauthenticated / invalid-token requests
- Manifest content validation
- Invalid format rejection
"""
import io
import json
import zipfile
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the SPARC FastAPI app."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture(autouse=True)
def mock_db():
    """Patch ``get_db_client`` in the API and auth modules with a ``MagicMock``.

    The mock's ``get_conn()`` and ``conn.cursor()`` both work as context
    managers; the innermost cursor is exposed as ``db._mock_cursor`` so each
    test can configure ``fetchall``.
    """
    fake_db = MagicMock()

    # Auth: user always exists
    fake_db.get_user_by_id.return_value = {
        "id": 1,
        "email": "user@test.com",
        "role": "user",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }

    # Default cursor mock (overridden per-test via side_effect or return_value)
    cursor_stub = MagicMock()
    conn_stub = MagicMock()

    cursor_ctx = conn_stub.cursor.return_value
    cursor_ctx.__enter__ = MagicMock(return_value=cursor_stub)
    cursor_ctx.__exit__ = MagicMock(return_value=False)

    conn_ctx = fake_db.get_conn.return_value
    conn_ctx.__enter__ = MagicMock(return_value=conn_stub)
    conn_ctx.__exit__ = MagicMock(return_value=False)

    fake_db._mock_cursor = cursor_stub

    with patch("SPARC.api.get_db_client", return_value=fake_db), \
         patch("SPARC.auth.get_db_client", return_value=fake_db):
        yield fake_db
def _auth_header():
    """Return an Authorization header with a bearer token for user id 1."""
    bearer = create_access_token(1, "user@test.com", "user")
    headers = {"Authorization": f"Bearer {bearer}"}
    return headers
def _rows_for(company_name: str):
    """Build a one-element list holding a sample analysis row for *company_name*.

    The row is a tuple ordered as
    (company_name, analysis_type, model, response, timestamp).
    """
    sample_timestamp = datetime(2025, 6, 15, 10, 30, 0)
    sample_row = (
        company_name,
        "company_analysis",
        "anthropic/claude-3.5-sonnet",
        f"Strong patent portfolio for {company_name}.",
        sample_timestamp,
    )
    return [sample_row]
def _open_zip(content: bytes) -> zipfile.ZipFile:
    """Open raw ZIP *content* (e.g. a response body) as a ``ZipFile``."""
    archive_stream = io.BytesIO(content)
    return zipfile.ZipFile(archive_stream)
# ---------------------------------------------------------------------------
# Authentication
# ---------------------------------------------------------------------------
class TestBatchExportAuth:
    """Requests without a valid bearer token must be answered with 401."""

    def test_unauthenticated_returns_401(self, client):
        payload = {"companies": ["NVIDIA"], "format": "csv"}
        resp = client.post("/export/batch", json=payload)
        assert resp.status_code == 401

    def test_invalid_token_returns_401(self, client):
        payload = {"companies": ["NVIDIA"], "format": "csv"}
        bogus_headers = {"Authorization": "Bearer totally.invalid.token"}
        resp = client.post("/export/batch", json=payload, headers=bogus_headers)
        assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Single company
# ---------------------------------------------------------------------------
class TestBatchExportSingleCompany:
    """Single-company requests to POST /export/batch."""

    @staticmethod
    def _export_nvidia(client, mock_db, fmt):
        """Seed one NVIDIA row, then POST a one-company export in *fmt*."""
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        return client.post(
            "/export/batch",
            json={"companies": ["NVIDIA"], "format": fmt},
            headers=_auth_header(),
        )

    def test_single_company_csv_returns_zip(self, client, mock_db):
        resp = self._export_nvidia(client, mock_db, "csv")

        assert resp.status_code == 200
        assert resp.headers["content-type"] == "application/zip"
        disposition = resp.headers["content-disposition"]
        assert "attachment" in disposition
        assert "sparc-export-" in disposition
        assert disposition.endswith('.zip"')

    def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
        resp = self._export_nvidia(client, mock_db, "csv")

        archive = _open_zip(resp.content)
        csv_members = [member for member in archive.namelist() if member.endswith(".csv")]
        assert len(csv_members) == 1
        assert "nvidia" in csv_members[0]

    def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
        resp = self._export_nvidia(client, mock_db, "csv")

        archive = _open_zip(resp.content)
        csv_member = [member for member in archive.namelist() if member.endswith(".csv")][0]
        decoded = archive.read(csv_member).decode("utf-8")
        csv_lines = decoded.strip().split("\n")
        assert csv_lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
        assert "NVIDIA" in csv_lines[1]

    def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
        resp = self._export_nvidia(client, mock_db, "pdf")

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        pdf_members = [member for member in archive.namelist() if member.endswith(".pdf")]
        assert len(pdf_members) == 1
        # A real PDF starts with the "%PDF" magic bytes.
        magic = archive.read(pdf_members[0])[:4]
        assert magic == b"%PDF"
# ---------------------------------------------------------------------------
# Multiple companies
# ---------------------------------------------------------------------------
class TestBatchExportMultipleCompanies:
    """Requests to POST /export/batch that name more than one company."""

    @staticmethod
    def _export_csv(client, mock_db, companies, fetch_results):
        """Queue one ``fetchall`` result per company, then request a CSV export."""
        mock_db._mock_cursor.fetchall.side_effect = fetch_results
        return client.post(
            "/export/batch",
            json={"companies": companies, "format": "csv"},
            headers=_auth_header(),
        )

    def test_multiple_companies_each_gets_a_file(self, client, mock_db):
        names = ["NVIDIA", "Intel", "AMD"]
        resp = self._export_csv(client, mock_db, names, [_rows_for(name) for name in names])

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        csv_members = [member for member in archive.namelist() if member.endswith(".csv")]
        assert len(csv_members) == 3

    def test_multiple_companies_manifest_lists_all_exported(self, client, mock_db):
        names = ["NVIDIA", "Intel"]
        resp = self._export_csv(client, mock_db, names, [_rows_for(name) for name in names])

        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
        assert manifest["skipped"] == []
        assert manifest["format"] == "csv"

    def test_partial_missing_companies_skipped(self, client, mock_db):
        """Companies with no data are skipped; others are exported."""
        fetch_results = [
            _rows_for("NVIDIA"),
            [],  # no data for "UnknownCo"
        ]
        resp = self._export_csv(client, mock_db, ["NVIDIA", "UnknownCo"], fetch_results)

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["exported"] == ["NVIDIA"]
        assert manifest["skipped"] == ["UnknownCo"]
        csv_members = [member for member in archive.namelist() if member.endswith(".csv")]
        assert len(csv_members) == 1
# ---------------------------------------------------------------------------
# All-missing companies
# ---------------------------------------------------------------------------
class TestBatchExportAllMissing:
    """If no requested company has data, the endpoint still answers 200 with a
    ZIP that holds only the manifest; every company is listed as skipped."""

    @staticmethod
    def _export_without_data(client, mock_db, companies):
        """Make every DB lookup return no rows, then request a CSV export."""
        mock_db._mock_cursor.fetchall.return_value = []
        return client.post(
            "/export/batch",
            json={"companies": companies, "format": "csv"},
            headers=_auth_header(),
        )

    def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
        resp = self._export_without_data(client, mock_db, ["GhostCo", "PhantomInc"])

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        assert "manifest.json" in archive.namelist()
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["exported"] == []
        assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}

    def test_all_missing_zip_has_no_data_files(self, client, mock_db):
        resp = self._export_without_data(client, mock_db, ["GhostCo"])

        archive = _open_zip(resp.content)
        non_manifest = [member for member in archive.namelist() if member != "manifest.json"]
        assert non_manifest == []
# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------
class TestBatchExportManifest:
    """Every exported ZIP must carry a well-formed manifest.json."""

    @staticmethod
    def _exported_archive(client, mock_db, fmt):
        """Export NVIDIA in *fmt* and return the response body as a ZipFile."""
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        resp = client.post(
            "/export/batch",
            json={"companies": ["NVIDIA"], "format": fmt},
            headers=_auth_header(),
        )
        return _open_zip(resp.content)

    def test_manifest_always_present(self, client, mock_db):
        archive = self._exported_archive(client, mock_db, "csv")
        assert "manifest.json" in archive.namelist()

    def test_manifest_contains_required_keys(self, client, mock_db):
        archive = self._exported_archive(client, mock_db, "csv")
        manifest = json.loads(archive.read("manifest.json"))
        assert "export_date" in manifest
        assert "format" in manifest
        assert "exported" in manifest
        assert "skipped" in manifest

    def test_manifest_format_field_matches_request(self, client, mock_db):
        archive = self._exported_archive(client, mock_db, "pdf")
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["format"] == "pdf"
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
class TestBatchExportInputValidation:
    """Malformed request bodies are rejected with 422; omitted fields use defaults."""

    def test_invalid_format_returns_422(self, client):
        bad_body = {"companies": ["NVIDIA"], "format": "xlsx"}
        resp = client.post("/export/batch", json=bad_body, headers=_auth_header())
        assert resp.status_code == 422

    def test_empty_companies_list_returns_422(self, client):
        bad_body = {"companies": [], "format": "csv"}
        resp = client.post("/export/batch", json=bad_body, headers=_auth_header())
        assert resp.status_code == 422

    def test_default_format_is_csv(self, client, mock_db):
        """Omitting `format` should default to CSV."""
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        body_without_format = {"companies": ["NVIDIA"]}
        resp = client.post("/export/batch", json=body_without_format, headers=_auth_header())

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["format"] == "csv"
+171 -19
View File
@@ -1,12 +1,13 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
@@ -15,6 +16,27 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_auth_db():
    """Replace the auth database with a mock for every test in this module.

    ``get_db_client`` is patched in both ``SPARC.api`` and ``SPARC.auth`` so
    that each returns the same MagicMock, whose ``get_user_by_id`` hands back
    a fixed regular-user record. Yields that mock.
    """
    fake_user = {
        "id": 1,
        "email": "user@test.com",
        "role": "user",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    db = MagicMock()
    db.get_user_by_id.return_value = fake_user
    # Both modules hold their own reference to get_db_client, so each needs its own patch.
    with patch("SPARC.api.get_db_client", return_value=db):
        with patch("SPARC.auth.get_db_client", return_value=db):
            yield db
def _auth_header():
    """Build an ``Authorization: Bearer ...`` header dict for a regular user.

    The token is issued for user id 1, "user@test.com", role "user".
    """
    user_id, email, role = 1, "user@test.com", "user"
    jwt = create_access_token(user_id, email, role)
    return {"Authorization": f"Bearer {jwt}"}
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
@@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination:
]
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10")
response = client.get("/analyze/batch?limit=10", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 2
@@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = rows
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3")
response = client.get("/analyze/batch?limit=3", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 3
@@ -84,11 +106,14 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42")
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header())
db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
(call_kwargs[1].get("cursor") == "2025-01-01T00:00:00|42" if len(call_kwargs) > 1 else False)
cursor_val = (
call_kwargs.kwargs.get("cursor")
or (call_kwargs[1].get("cursor") if len(call_kwargs) > 1 else None)
)
assert cursor_val == "2025-01-01T00:00:00|42"
@patch("SPARC.api._get_job_db")
def test_default_limit_is_50(self, mock_get_db, client):
@@ -97,19 +122,19 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch")
client.get("/analyze/batch", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
    """Limit > 200 should be rejected with 422."""
    # Issue the request once, authenticated. The earlier unauthenticated call
    # was dead code: its response was overwritten before being inspected.
    response = client.get("/analyze/batch?limit=201", headers=_auth_header())
    assert response.status_code == 422
def test_limit_zero_rejected(self, client):
    """Limit < 1 should be rejected with 422."""
    # Issue the request once, authenticated. The earlier unauthenticated call
    # was dead code: its response was overwritten before being inspected.
    response = client.get("/analyze/batch?limit=0", headers=_auth_header())
    assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -119,10 +144,13 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel")
client.get("/analyze/batch?company_name=intel", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else [])
company_val = (
call_kwargs.kwargs.get("company_name")
or (call_kwargs[1].get("company_name") if len(call_kwargs) > 1 else None)
)
assert company_val == "intel"
@patch("SPARC.api._get_job_db")
def test_empty_result_set(self, mock_get_db, client):
@@ -131,15 +159,30 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
response = client.get("/analyze/batch")
response = client.get("/analyze/batch", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert data["items"] == []
assert data["next_cursor"] is None
@patch("SPARC.api._get_job_db")
def test_subsequent_page_uses_cursor(self, mock_get_db, client):
    """Passing a cursor should retrieve the next page of results.

    The mocked DB returns a single row for a limit of 10, so the response is
    treated as the last page and must carry a null ``next_cursor``.
    """
    db = Mock()
    db.list_analyses.return_value = [_make_analysis_row(99, minutes_ago=100)]
    mock_get_db.return_value = db

    cursor = "2025-06-01T12:00:00|50"
    response = client.get(f"/analyze/batch?limit=10&cursor={cursor}", headers=_auth_header())
    assert response.status_code == 200
    # The test is named for cursor usage, so confirm the cursor actually
    # reached the database layer rather than only inspecting the response.
    db.list_analyses.assert_called_once()
    assert db.list_analyses.call_args.kwargs.get("cursor") == cursor
    data = response.json()
    # Only one item returned -> last page -> no next cursor
    assert len(data["items"]) == 1
    assert data["next_cursor"] is None
class TestJobsPagination:
"""Test cursor-based pagination on GET /jobs."""
@patch("SPARC.api._get_job_db")
def test_default_limit_is_50(self, mock_get_db, client):
@@ -148,14 +191,19 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
client.get("/jobs")
client.get("/jobs", headers=_auth_header())
call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
    """Limit > 200 should be rejected with 422."""
    # Issue the request once, authenticated. The earlier unauthenticated call
    # was dead code: its response was overwritten before being inspected.
    response = client.get("/jobs?limit=201", headers=_auth_header())
    assert response.status_code == 422
def test_limit_zero_rejected(self, client):
    """A limit of 0 on GET /jobs must be rejected with 422."""
    status = client.get("/jobs?limit=0", headers=_auth_header()).status_code
    assert status == 422
@patch("SPARC.api._get_job_db")
@@ -165,5 +213,109 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
response = client.get("/jobs?limit=200")
response = client.get("/jobs?limit=200", headers=_auth_header())
assert response.status_code == 200
@patch("SPARC.api._get_job_db")
def test_first_page_returns_items_and_cursor(self, mock_get_db, client):
    """When the DB yields more rows than the limit, the page is full and has a next_cursor."""
    page_size = 3
    fake_db = Mock()
    # One row beyond the page size simulates more data being available.
    fake_db.list_jobs.return_value = [
        _make_job_row(f"job-{n}", minutes_ago=n) for n in range(page_size + 1)
    ]
    mock_get_db.return_value = fake_db

    resp = client.get(f"/jobs?limit={page_size}", headers=_auth_header())
    assert resp.status_code == 200
    body = resp.json()
    assert len(body["items"]) == page_size
    assert body["next_cursor"] is not None
@patch("SPARC.api._get_job_db")
def test_last_page_returns_no_cursor(self, mock_get_db, client):
    """With fewer rows than the limit, the response is the last page: next_cursor is null."""
    fake_db = Mock()
    fake_db.list_jobs.return_value = [
        _make_job_row("job-a", minutes_ago=5),
        _make_job_row("job-b", minutes_ago=10),
    ]
    mock_get_db.return_value = fake_db

    resp = client.get("/jobs?limit=10", headers=_auth_header())
    assert resp.status_code == 200
    body = resp.json()
    assert len(body["items"]) == 2
    assert body["next_cursor"] is None
@patch("SPARC.api._get_job_db")
def test_cursor_forwarded_to_db(self, mock_get_db, client):
    """The cursor query param should be forwarded to the database layer."""
    db = Mock()
    db.list_jobs.return_value = []
    mock_get_db.return_value = db

    client.get("/jobs?cursor=2025-01-01T00:00:00|job-99", headers=_auth_header())
    db.list_jobs.assert_called_once()
    # call_args[1] is the same mapping as call_args.kwargs, so the former
    # "fallback" lookup could never find anything new; read kwargs directly.
    cursor_val = db.list_jobs.call_args.kwargs.get("cursor")
    assert cursor_val == "2025-01-01T00:00:00|job-99"
@patch("SPARC.api._get_job_db")
def test_empty_result_set(self, mock_get_db, client):
    """No rows from the DB produces an empty ``items`` list and a null ``next_cursor``."""
    fake_db = Mock()
    fake_db.list_jobs.return_value = []
    mock_get_db.return_value = fake_db

    resp = client.get("/jobs", headers=_auth_header())
    assert resp.status_code == 200
    body = resp.json()
    assert body["items"] == []
    assert body["next_cursor"] is None
@patch("SPARC.api._get_job_db")
def test_status_filter_forwarded(self, mock_get_db, client):
    """The status filter should be forwarded to the database layer."""
    db = Mock()
    db.list_jobs.return_value = []
    mock_get_db.return_value = db

    client.get("/jobs?status=completed", headers=_auth_header())
    db.list_jobs.assert_called_once()
    # call_args[1] is the same mapping as call_args.kwargs, so the former
    # "fallback" lookup could never find anything new; read kwargs directly.
    status_val = db.list_jobs.call_args.kwargs.get("status")
    assert status_val == "completed"
@patch("SPARC.api._get_job_db")
def test_response_has_paginated_shape(self, mock_get_db, client):
    """GET /jobs responds with the paginated envelope: 'items' plus 'next_cursor'."""
    fake_db = Mock()
    fake_db.list_jobs.return_value = [_make_job_row("job-x")]
    mock_get_db.return_value = fake_db

    resp = client.get("/jobs?limit=10", headers=_auth_header())
    assert resp.status_code == 200
    envelope = resp.json()
    for field in ("items", "next_cursor"):
        assert field in envelope
@patch("SPARC.api._get_job_db")
def test_subsequent_page_uses_cursor(self, mock_get_db, client):
    """Passing cursor returns the next page; last page has null next_cursor.

    The mocked DB returns a single row for a limit of 10, so the response is
    treated as the last page.
    """
    db = Mock()
    db.list_jobs.return_value = [_make_job_row("job-last", minutes_ago=200)]
    mock_get_db.return_value = db

    cursor = "2025-06-01T12:00:00|job-50"
    response = client.get(f"/jobs?limit=10&cursor={cursor}", headers=_auth_header())
    assert response.status_code == 200
    # The test is named for cursor usage, so confirm the cursor actually
    # reached the database layer rather than only inspecting the response.
    db.list_jobs.assert_called_once()
    assert db.list_jobs.call_args.kwargs.get("cursor") == cursor
    data = response.json()
    assert len(data["items"]) == 1
    assert data["next_cursor"] is None