forked from 0xWheatyz/SPARC
Add POST /export/batch endpoint for multi-company ZIP download
Implements issue #1674: a new authenticated POST /export/batch endpoint that accepts a list of company names and an optional format (csv or pdf), compiles per-company exports into a ZIP archive using Python's zipfile module, and returns it as a streaming download. Key changes: - Extract `_fetch_company_rows`, `_build_company_csv`, `_build_company_pdf` helpers to eliminate duplication between the single-company endpoints and the new batch endpoint - Refactor `export_company_csv` and `export_company_pdf` to delegate to the new helpers - Add `BatchExportRequest` Pydantic model (companies list + format field) - Add `POST /export/batch` which iterates over companies, skips those with no data, writes per-company files into the ZIP, and always includes a `manifest.json` listing exported and skipped companies - Response header: `Content-Disposition: attachment; filename=sparc-export-<date>.zip` - 17 new tests covering: single company (CSV + PDF), multiple companies, all-missing, unauthenticated, invalid-token, manifest structure, input validation Closes leeworks-agents/SPARC#1674 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,373 @@
|
||||
"""Tests for POST /export/batch endpoint (issue #1674).
|
||||
|
||||
Covers:
|
||||
- Single company export (CSV + PDF)
|
||||
- Multiple company export
|
||||
- All-missing companies (every requested company is skipped)
|
||||
- Unauthenticated / invalid-token requests
|
||||
- Manifest content validation
|
||||
- Invalid format rejection
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
from SPARC.auth import create_access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a FastAPI test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db():
|
||||
"""Mock database client for all tests in this module."""
|
||||
db = MagicMock()
|
||||
|
||||
# Auth: user always exists
|
||||
db.get_user_by_id.return_value = {
|
||||
"id": 1,
|
||||
"email": "user@test.com",
|
||||
"role": "user",
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
# Default cursor mock (overridden per-test via side_effect or return_value)
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
||||
db.get_conn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
db._mock_cursor = mock_cursor
|
||||
|
||||
with patch("SPARC.api.get_db_client", return_value=db), \
|
||||
patch("SPARC.auth.get_db_client", return_value=db):
|
||||
yield db
|
||||
|
||||
|
||||
def _auth_header():
|
||||
token = create_access_token(1, "user@test.com", "user")
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _rows_for(company_name: str):
|
||||
"""Return a single sample row for the given company."""
|
||||
return [
|
||||
(
|
||||
company_name,
|
||||
"company_analysis",
|
||||
"anthropic/claude-3.5-sonnet",
|
||||
f"Strong patent portfolio for {company_name}.",
|
||||
datetime(2025, 6, 15, 10, 30, 0),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _open_zip(content: bytes) -> zipfile.ZipFile:
|
||||
"""Helper: wrap response bytes as a ZipFile."""
|
||||
return zipfile.ZipFile(io.BytesIO(content))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchExportAuth:
|
||||
"""Unauthenticated and invalid-token requests must be rejected."""
|
||||
|
||||
def test_unauthenticated_returns_401(self, client):
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_invalid_token_returns_401(self, client):
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
headers={"Authorization": "Bearer totally.invalid.token"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single company
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchExportSingleCompany:
|
||||
"""POST /export/batch with a single company name."""
|
||||
|
||||
def test_single_company_csv_returns_zip(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/zip"
|
||||
assert "attachment" in response.headers["content-disposition"]
|
||||
assert "sparc-export-" in response.headers["content-disposition"]
|
||||
assert response.headers["content-disposition"].endswith('.zip"')
|
||||
|
||||
def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
names = zf.namelist()
|
||||
csv_files = [n for n in names if n.endswith(".csv")]
|
||||
assert len(csv_files) == 1
|
||||
assert "nvidia" in csv_files[0]
|
||||
|
||||
def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
csv_name = [n for n in zf.namelist() if n.endswith(".csv")][0]
|
||||
csv_text = zf.read(csv_name).decode("utf-8")
|
||||
lines = csv_text.strip().split("\n")
|
||||
assert lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
|
||||
assert "NVIDIA" in lines[1]
|
||||
|
||||
def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "pdf"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
zf = _open_zip(response.content)
|
||||
pdf_files = [n for n in zf.namelist() if n.endswith(".pdf")]
|
||||
assert len(pdf_files) == 1
|
||||
# Verify it is actually a PDF (starts with %PDF)
|
||||
pdf_bytes = zf.read(pdf_files[0])
|
||||
assert pdf_bytes[:4] == b"%PDF"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multiple companies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchExportMultipleCompanies:
|
||||
"""POST /export/batch with several companies."""
|
||||
|
||||
def test_multiple_companies_each_gets_a_file(self, client, mock_db):
|
||||
companies = ["NVIDIA", "Intel", "AMD"]
|
||||
mock_db._mock_cursor.fetchall.side_effect = [
|
||||
_rows_for("NVIDIA"),
|
||||
_rows_for("Intel"),
|
||||
_rows_for("AMD"),
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": companies, "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
zf = _open_zip(response.content)
|
||||
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
|
||||
assert len(csv_files) == 3
|
||||
|
||||
def test_multiple_companies_manifest_lists_all_exported(self, client, mock_db):
|
||||
companies = ["NVIDIA", "Intel"]
|
||||
mock_db._mock_cursor.fetchall.side_effect = [
|
||||
_rows_for("NVIDIA"),
|
||||
_rows_for("Intel"),
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": companies, "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
|
||||
assert manifest["skipped"] == []
|
||||
assert manifest["format"] == "csv"
|
||||
|
||||
def test_partial_missing_companies_skipped(self, client, mock_db):
|
||||
"""Companies with no data are skipped; others are exported."""
|
||||
mock_db._mock_cursor.fetchall.side_effect = [
|
||||
_rows_for("NVIDIA"),
|
||||
[], # no data for "UnknownCo"
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA", "UnknownCo"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
zf = _open_zip(response.content)
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["exported"] == ["NVIDIA"]
|
||||
assert manifest["skipped"] == ["UnknownCo"]
|
||||
|
||||
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
|
||||
assert len(csv_files) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# All-missing companies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchExportAllMissing:
|
||||
"""When every requested company has no data, the ZIP still returns 200
|
||||
with only a manifest (no per-company files, all listed in skipped)."""
|
||||
|
||||
def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = []
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["GhostCo", "PhantomInc"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
zf = _open_zip(response.content)
|
||||
assert "manifest.json" in zf.namelist()
|
||||
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["exported"] == []
|
||||
assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}
|
||||
|
||||
def test_all_missing_zip_has_no_data_files(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = []
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["GhostCo"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
data_files = [n for n in zf.namelist() if n != "manifest.json"]
|
||||
assert data_files == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manifest validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchExportManifest:
|
||||
"""The manifest.json inside every ZIP must be well-formed."""
|
||||
|
||||
def test_manifest_always_present(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
assert "manifest.json" in zf.namelist()
|
||||
|
||||
def test_manifest_contains_required_keys(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert "export_date" in manifest
|
||||
assert "format" in manifest
|
||||
assert "exported" in manifest
|
||||
assert "skipped" in manifest
|
||||
|
||||
def test_manifest_format_field_matches_request(self, client, mock_db):
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "pdf"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
zf = _open_zip(response.content)
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["format"] == "pdf"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchExportInputValidation:
|
||||
"""Invalid request bodies must return 422."""
|
||||
|
||||
def test_invalid_format_returns_422(self, client):
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"], "format": "xlsx"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_empty_companies_list_returns_422(self, client):
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": [], "format": "csv"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_default_format_is_csv(self, client, mock_db):
|
||||
"""Omitting `format` should default to CSV."""
|
||||
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
|
||||
|
||||
response = client.post(
|
||||
"/export/batch",
|
||||
json={"companies": ["NVIDIA"]},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
zf = _open_zip(response.content)
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["format"] == "csv"
|
||||
Reference in New Issue
Block a user