forked from 0xWheatyz/SPARC
8f40109272
Implements issue #1674: a new authenticated POST /export/batch endpoint that accepts a list of company names and an optional format (csv or pdf), compiles per-company exports into a ZIP archive using Python's zipfile module, and returns it as a streaming download. Key changes: - Extract `_fetch_company_rows`, `_build_company_csv`, `_build_company_pdf` helpers to eliminate duplication between the single-company endpoints and the new batch endpoint - Refactor `export_company_csv` and `export_company_pdf` to delegate to the new helpers - Add `BatchExportRequest` Pydantic model (companies list + format field) - Add `POST /export/batch` which iterates over companies, skips those with no data, writes per-company files into the ZIP, and always includes a `manifest.json` listing exported and skipped companies - Response header: `Content-Disposition: attachment; filename=sparc-export-<date>.zip` - 17 new tests covering: single company (CSV + PDF), multiple companies, all-missing, unauthenticated, invalid-token, manifest structure, input validation Closes leeworks-agents/SPARC#1674 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
374 lines
12 KiB
Python
374 lines
12 KiB
Python
"""Tests for POST /export/batch endpoint (issue #1674).
|
|
|
|
Covers:
|
|
- Single company export (CSV + PDF)
|
|
- Multiple company export
|
|
- All-missing companies (every requested company is skipped)
|
|
- Unauthenticated / invalid-token requests
|
|
- Manifest content validation
|
|
- Invalid format rejection
|
|
"""
|
|
|
|
import io
|
|
import json
|
|
import zipfile
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from SPARC.api import app
|
|
from SPARC.auth import create_access_token
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the imported FastAPI ``app``."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_db():
    """Install a fake database client for every test in this module.

    ``get_db_client`` is patched in both ``SPARC.api`` and ``SPARC.auth`` so
    that each returns the same ``MagicMock``.  The shared cursor mock is also
    attached to that object as ``_mock_cursor`` so a test can program
    ``fetchall`` through ``return_value`` or ``side_effect``.

    Yields:
        The ``MagicMock`` standing in for the database client.
    """
    fake_db = MagicMock()

    # Auth: the user lookup always succeeds with this fixed record.
    known_user = {
        "id": 1,
        "email": "user@test.com",
        "role": "user",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    fake_db.get_user_by_id.return_value = known_user

    # Default cursor/connection mocks; tests override the cursor's results.
    cursor = MagicMock()
    connection = MagicMock()

    # Entering the context manager returned by ``cursor()`` yields ``cursor``.
    cursor_ctx = connection.cursor.return_value
    cursor_ctx.__enter__ = MagicMock(return_value=cursor)
    cursor_ctx.__exit__ = MagicMock(return_value=False)

    # Entering the context manager returned by ``get_conn()`` yields
    # ``connection``.
    conn_ctx = fake_db.get_conn.return_value
    conn_ctx.__enter__ = MagicMock(return_value=connection)
    conn_ctx.__exit__ = MagicMock(return_value=False)

    fake_db._mock_cursor = cursor

    with patch("SPARC.api.get_db_client", return_value=fake_db):
        with patch("SPARC.auth.get_db_client", return_value=fake_db):
            yield fake_db
|
|
|
|
|
|
def _auth_header():
    """Return an ``Authorization`` header dict carrying a bearer token.

    The token comes from ``create_access_token`` called with the id, email
    and role that ``mock_db`` returns from ``get_user_by_id``.
    """
    access_token = create_access_token(1, "user@test.com", "user")
    headers = {}
    headers["Authorization"] = f"Bearer {access_token}"
    return headers
|
|
|
|
|
|
def _rows_for(company_name: str):
    """Build a one-element list holding a sample row for *company_name*.

    The row is a 5-tuple: the company name, an analysis type, a model
    identifier, an analysis sentence mentioning the company, and a fixed
    naive timestamp.
    """
    sample_row = (
        company_name,
        "company_analysis",
        "anthropic/claude-3.5-sonnet",
        f"Strong patent portfolio for {company_name}.",
        datetime(year=2025, month=6, day=15, hour=10, minute=30, second=0),
    )
    return [sample_row]
|
|
|
|
|
|
def _open_zip(content: bytes) -> zipfile.ZipFile:
    """Open raw response bytes as an in-memory, readable ``ZipFile``."""
    buffer = io.BytesIO(content)
    archive = zipfile.ZipFile(buffer)
    return archive
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Authentication
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchExportAuth:
    """Requests without valid credentials must be answered with 401."""

    def test_unauthenticated_returns_401(self, client):
        # No Authorization header is sent at all.
        payload = {"companies": ["NVIDIA"], "format": "csv"}
        resp = client.post("/export/batch", json=payload)
        assert resp.status_code == 401

    def test_invalid_token_returns_401(self, client):
        # A header is sent, but its bearer token is not a valid one.
        payload = {"companies": ["NVIDIA"], "format": "csv"}
        bad_auth = {"Authorization": "Bearer totally.invalid.token"}
        resp = client.post("/export/batch", json=payload, headers=bad_auth)
        assert resp.status_code == 401
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Single company
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchExportSingleCompany:
    """Exercise POST /export/batch when exactly one company is requested."""

    def test_single_company_csv_returns_zip(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        assert resp.status_code == 200
        assert resp.headers["content-type"] == "application/zip"
        # The download must be an attachment named sparc-export-….zip
        disposition = resp.headers["content-disposition"]
        assert "attachment" in disposition
        assert "sparc-export-" in disposition
        assert disposition.endswith('.zip"')

    def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        members = archive.namelist()
        csv_members = [member for member in members if member.endswith(".csv")]
        assert len(csv_members) == 1
        assert "nvidia" in csv_members[0]

    def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        csv_member = [m for m in archive.namelist() if m.endswith(".csv")][0]
        decoded = archive.read(csv_member).decode("utf-8")
        records = decoded.strip().split("\n")
        # First line is the header row, second line is the sample record.
        expected_header = "company_name,analysis_type,model,analysis,timestamp"
        assert records[0].strip() == expected_header
        assert "NVIDIA" in records[1]

    def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "pdf"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        pdf_members = [m for m in archive.namelist() if m.endswith(".pdf")]
        assert len(pdf_members) == 1
        # The member must really be a PDF: check the "%PDF" magic bytes.
        payload_bytes = archive.read(pdf_members[0])
        magic = payload_bytes[:4]
        assert magic == b"%PDF"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Multiple companies
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchExportMultipleCompanies:
    """Exercise POST /export/batch when several companies are requested."""

    def test_multiple_companies_each_gets_a_file(self, client, mock_db):
        companies = ["NVIDIA", "Intel", "AMD"]
        # One fetchall() result per requested company, in request order.
        mock_db._mock_cursor.fetchall.side_effect = [
            _rows_for(name) for name in companies
        ]
        payload = {"companies": companies, "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        csv_members = [m for m in archive.namelist() if m.endswith(".csv")]
        assert len(csv_members) == 3

    def test_multiple_companies_manifest_lists_all_exported(self, client, mock_db):
        companies = ["NVIDIA", "Intel"]
        mock_db._mock_cursor.fetchall.side_effect = [
            _rows_for(name) for name in companies
        ]
        payload = {"companies": companies, "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        # Order of the exported list is not asserted, only its membership.
        assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
        assert manifest["skipped"] == []
        assert manifest["format"] == "csv"

    def test_partial_missing_companies_skipped(self, client, mock_db):
        """A company with no rows is skipped while the others are exported."""
        nvidia_rows = _rows_for("NVIDIA")
        no_rows = []  # no data for "UnknownCo"
        mock_db._mock_cursor.fetchall.side_effect = [nvidia_rows, no_rows]
        payload = {"companies": ["NVIDIA", "UnknownCo"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["exported"] == ["NVIDIA"]
        assert manifest["skipped"] == ["UnknownCo"]

        csv_members = [m for m in archive.namelist() if m.endswith(".csv")]
        assert len(csv_members) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# All-missing companies
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchExportAllMissing:
    """Every requested company lacks data.

    The endpoint must still answer 200 with a ZIP that holds only the
    manifest: no per-company files, and every company listed as skipped.
    """

    def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = []
        payload = {"companies": ["GhostCo", "PhantomInc"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        assert "manifest.json" in archive.namelist()

        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["exported"] == []
        assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}

    def test_all_missing_zip_has_no_data_files(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = []
        payload = {"companies": ["GhostCo"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        # Anything other than the manifest would be a per-company export.
        extras = [m for m in archive.namelist() if m != "manifest.json"]
        assert extras == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Manifest validation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchExportManifest:
    """Check that the ZIP's manifest.json is present and well-formed."""

    def test_manifest_always_present(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        assert "manifest.json" in archive.namelist()

    def test_manifest_contains_required_keys(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "csv"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        # Checked in a fixed order so the first missing key is the one reported.
        for required_key in ("export_date", "format", "exported", "skipped"):
            assert required_key in manifest

    def test_manifest_format_field_matches_request(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"], "format": "pdf"}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["format"] == "pdf"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Input validation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchExportInputValidation:
    """Malformed request bodies must yield 422; a missing format means CSV."""

    def test_invalid_format_returns_422(self, client):
        # "xlsx" is not one of the supported export formats.
        payload = {"companies": ["NVIDIA"], "format": "xlsx"}
        resp = client.post("/export/batch", json=payload, headers=_auth_header())
        assert resp.status_code == 422

    def test_empty_companies_list_returns_422(self, client):
        payload = {"companies": [], "format": "csv"}
        resp = client.post("/export/batch", json=payload, headers=_auth_header())
        assert resp.status_code == 422

    def test_default_format_is_csv(self, client, mock_db):
        """Leaving out `format` should fall back to CSV."""
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
        payload = {"companies": ["NVIDIA"]}

        resp = client.post("/export/batch", json=payload, headers=_auth_header())

        assert resp.status_code == 200
        archive = _open_zip(resp.content)
        manifest = json.loads(archive.read("manifest.json"))
        assert manifest["format"] == "csv"
|