"""Tests for POST /export/batch endpoint (issue #1674). Covers: - Single company export (CSV + PDF) - Multiple company export - All-missing companies (every requested company is skipped) - Unauthenticated / invalid-token requests - Manifest content validation - Invalid format rejection """ import io import json import zipfile from datetime import datetime, timezone from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient from SPARC.api import app from SPARC.auth import create_access_token @pytest.fixture def client(): """Create a FastAPI test client.""" return TestClient(app) @pytest.fixture(autouse=True) def mock_db(): """Mock database client for all tests in this module.""" db = MagicMock() # Auth: user always exists db.get_user_by_id.return_value = { "id": 1, "email": "user@test.com", "role": "user", "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), } # Default cursor mock (overridden per-test via side_effect or return_value) mock_cursor = MagicMock() mock_conn = MagicMock() mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn) db.get_conn.return_value.__exit__ = MagicMock(return_value=False) db._mock_cursor = mock_cursor with patch("SPARC.api.get_db_client", return_value=db), \ patch("SPARC.auth.get_db_client", return_value=db): yield db def _auth_header(): token = create_access_token(1, "user@test.com", "user") return {"Authorization": f"Bearer {token}"} def _rows_for(company_name: str): """Return a single sample row for the given company.""" return [ ( company_name, "company_analysis", "anthropic/claude-3.5-sonnet", f"Strong patent portfolio for {company_name}.", datetime(2025, 6, 15, 10, 30, 0), ) ] def _open_zip(content: bytes) -> zipfile.ZipFile: """Helper: wrap response bytes as a ZipFile.""" return zipfile.ZipFile(io.BytesIO(content)) # 
def _post_batch(client, companies, fmt=None):
    """Send an authenticated POST request to ``/export/batch``.

    Args:
        client: The test client used to issue the request.
        companies: Value sent as the ``companies`` field of the JSON body.
        fmt: Value sent as the ``format`` field; the key is left out of the
            body entirely when this is ``None``.

    Returns:
        The response returned by ``client.post``.
    """
    payload = {"companies": companies}
    if fmt is not None:
        payload["format"] = fmt
    return client.post("/export/batch", json=payload, headers=_auth_header())


# ---------------------------------------------------------------------------
# Authentication
# ---------------------------------------------------------------------------


class TestBatchExportAuth:
    """Unauthenticated and invalid-token requests must be rejected."""

    def test_unauthenticated_returns_401(self, client):
        response = client.post(
            "/export/batch",
            json={"companies": ["NVIDIA"], "format": "csv"},
        )
        assert response.status_code == 401

    def test_invalid_token_returns_401(self, client):
        response = client.post(
            "/export/batch",
            json={"companies": ["NVIDIA"], "format": "csv"},
            headers={"Authorization": "Bearer totally.invalid.token"},
        )
        assert response.status_code == 401


# ---------------------------------------------------------------------------
# Single company
# ---------------------------------------------------------------------------


class TestBatchExportSingleCompany:
    """POST /export/batch with a single company name."""

    def test_single_company_csv_returns_zip(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "csv")

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/zip"
        disposition = response.headers["content-disposition"]
        assert "attachment" in disposition
        assert "sparc-export-" in disposition
        assert disposition.endswith('.zip"')

    def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "csv")

        # Check the status first so an endpoint failure is reported as such
        # instead of as a BadZipFile error while parsing the body.
        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
        assert len(csv_files) == 1
        assert "nvidia" in csv_files[0]

    def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            csv_name = [n for n in zf.namelist() if n.endswith(".csv")][0]
            csv_text = zf.read(csv_name).decode("utf-8")
        lines = csv_text.strip().split("\n")
        # Strip the header line so a trailing "\r" does not break the match.
        assert (
            lines[0].strip()
            == "company_name,analysis_type,model,analysis,timestamp"
        )
        assert "NVIDIA" in lines[1]

    def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "pdf")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            pdf_files = [n for n in zf.namelist() if n.endswith(".pdf")]
            assert len(pdf_files) == 1
            # Verify it is actually a PDF (starts with %PDF)
            pdf_bytes = zf.read(pdf_files[0])
        assert pdf_bytes[:4] == b"%PDF"


# ---------------------------------------------------------------------------
# Multiple companies
# ---------------------------------------------------------------------------


class TestBatchExportMultipleCompanies:
    """POST /export/batch with several companies."""

    def test_multiple_companies_each_gets_a_file(self, client, mock_db):
        companies = ["NVIDIA", "Intel", "AMD"]
        # One fetchall result per requested company, in request order.
        mock_db._mock_cursor.fetchall.side_effect = [
            _rows_for("NVIDIA"),
            _rows_for("Intel"),
            _rows_for("AMD"),
        ]

        response = _post_batch(client, companies, "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
        assert len(csv_files) == 3

    def test_multiple_companies_manifest_lists_all_exported(
        self, client, mock_db
    ):
        companies = ["NVIDIA", "Intel"]
        mock_db._mock_cursor.fetchall.side_effect = [
            _rows_for("NVIDIA"),
            _rows_for("Intel"),
        ]

        response = _post_batch(client, companies, "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            manifest = json.loads(zf.read("manifest.json"))
        assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
        assert manifest["skipped"] == []
        assert manifest["format"] == "csv"

    def test_partial_missing_companies_skipped(self, client, mock_db):
        """Companies with no data are skipped; others are exported."""
        mock_db._mock_cursor.fetchall.side_effect = [
            _rows_for("NVIDIA"),
            [],  # no data for "UnknownCo"
        ]

        response = _post_batch(client, ["NVIDIA", "UnknownCo"], "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            manifest = json.loads(zf.read("manifest.json"))
            csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
        assert manifest["exported"] == ["NVIDIA"]
        assert manifest["skipped"] == ["UnknownCo"]
        assert len(csv_files) == 1


# ---------------------------------------------------------------------------
# All-missing companies
# ---------------------------------------------------------------------------


class TestBatchExportAllMissing:
    """When every requested company has no data, the ZIP still returns 200
    with only a manifest (no per-company files, all listed in skipped)."""

    def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = []

        response = _post_batch(client, ["GhostCo", "PhantomInc"], "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            assert "manifest.json" in zf.namelist()
            manifest = json.loads(zf.read("manifest.json"))
        assert manifest["exported"] == []
        assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}

    def test_all_missing_zip_has_no_data_files(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = []

        response = _post_batch(client, ["GhostCo"], "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            data_files = [n for n in zf.namelist() if n != "manifest.json"]
        assert data_files == []


# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------


class TestBatchExportManifest:
    """The manifest.json inside every ZIP must be well-formed."""

    def test_manifest_always_present(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            assert "manifest.json" in zf.namelist()

    def test_manifest_contains_required_keys(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "csv")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            manifest = json.loads(zf.read("manifest.json"))
        assert "export_date" in manifest
        assert "format" in manifest
        assert "exported" in manifest
        assert "skipped" in manifest

    def test_manifest_format_field_matches_request(self, client, mock_db):
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        response = _post_batch(client, ["NVIDIA"], "pdf")

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            manifest = json.loads(zf.read("manifest.json"))
        assert manifest["format"] == "pdf"


# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------


class TestBatchExportInputValidation:
    """Invalid request bodies must return 422."""

    def test_invalid_format_returns_422(self, client):
        response = _post_batch(client, ["NVIDIA"], "xlsx")
        assert response.status_code == 422

    def test_empty_companies_list_returns_422(self, client):
        response = _post_batch(client, [], "csv")
        assert response.status_code == 422

    def test_default_format_is_csv(self, client, mock_db):
        """Omitting `format` should default to CSV."""
        mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")

        # No ``fmt`` argument: the request body carries only "companies".
        response = _post_batch(client, ["NVIDIA"])

        assert response.status_code == 200
        with _open_zip(response.content) as zf:
            manifest = json.loads(zf.read("manifest.json"))
        assert manifest["format"] == "csv"