"""Tests for the SPARC FastAPI web service endpoints.

Every test talks to the app through ``TestClient``. The database client
looked up via ``get_db_client`` is replaced for every test by the autouse
``mock_db`` fixture. Tests that exercise analysis endpoints swap out the
module-level analyzer through ``mock_analyzer``.
"""

from datetime import datetime, timezone
from unittest.mock import Mock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from SPARC.api import app
from SPARC.auth import create_access_token
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult


@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the SPARC app."""
    return TestClient(app)


@pytest.fixture(autouse=True)
def mock_db():
    """Replace ``get_db_client`` in both ``SPARC.api`` and ``SPARC.auth``.

    The stub's ``get_user_by_id`` returns the same user record (id 1,
    role ``"user"``) for any lookup. Yields the ``MagicMock`` so a test can
    inspect or reconfigure it.
    """
    known_user = {
        "id": 1,
        "email": "user@test.com",
        "role": "user",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    fake_db = MagicMock()
    fake_db.get_user_by_id.return_value = known_user
    with patch("SPARC.api.get_db_client", return_value=fake_db):
        with patch("SPARC.auth.get_db_client", return_value=fake_db):
            yield fake_db


@pytest.fixture
def mock_analyzer(mocker):
    """Swap ``SPARC.api._analyzer`` for a bare ``Mock`` and hand it back."""
    fake_analyzer = Mock()
    mocker.patch("SPARC.api._analyzer", fake_analyzer)
    return fake_analyzer


def _auth_header(user_id=1, email="user@test.com", role="user"):
    """Build an ``Authorization: Bearer`` header carrying a fresh access token."""
    bearer = create_access_token(user_id, email, role)
    return {"Authorization": f"Bearer {bearer}"}


def _company_result(name, analysis, patent_count, *, success=True, **extra):
    """Build a ``CompanyAnalysisResult`` stamped with ``datetime.now()``.

    Any ``extra`` keywords (e.g. ``error``, ``model``) are forwarded as-is.
    """
    return CompanyAnalysisResult(
        company_name=name,
        analysis=analysis,
        patent_count=patent_count,
        success=success,
        timestamp=datetime.now(),
        **extra,
    )


def _stub_job_db(**canned_returns):
    """Return a ``MagicMock`` job store whose named methods return fixed values.

    Each keyword maps a method name (e.g. ``list_jobs``) to its return value.
    """
    store = MagicMock()
    for method_name, value in canned_returns.items():
        getattr(store, method_name).return_value = value
    return store


class TestHealthEndpoint:
    """GET /health."""

    def test_health_returns_ok(self, client):
        """The health probe reports status, version and a timestamp."""
        response = client.get("/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "healthy"
        assert payload["version"] == "1.0.0"
        assert "timestamp" in payload


class TestAnalyzeCompanyEndpoint:
    """GET /analyze/{company} for a single company."""

    def test_analyze_company_success(self, client, mock_analyzer):
        """A successful analysis is echoed back field by field."""
        mock_analyzer._analyze_company_safe.return_value = _company_result(
            "nvidia", "Strong AI patent portfolio", 5
        )

        response = client.get("/analyze/nvidia", headers=_auth_header())
        assert response.status_code == 200

        payload = response.json()
        assert payload["company_name"] == "nvidia"
        assert payload["analysis"] == "Strong AI patent portfolio"
        assert payload["patent_count"] == 5
        assert payload["success"] is True

    def test_analyze_company_failure(self, client, mock_analyzer):
        """A failed analysis still returns 200 but carries the error text."""
        mock_analyzer._analyze_company_safe.return_value = _company_result(
            "unknown", "", 0, success=False, error="No patents found"
        )

        response = client.get("/analyze/unknown", headers=_auth_header())
        assert response.status_code == 200

        payload = response.json()
        assert payload["success"] is False
        assert payload["error"] == "No patents found"


class TestBatchAnalysisEndpoint:
    """POST /analyze/batch (synchronous)."""

    def test_batch_analysis_success(self, client, mock_analyzer):
        """Batch totals and per-company results come back from the analyzer."""
        per_company = [
            _company_result("nvidia", "Strong portfolio", 5),
            _company_result("amd", "Growing portfolio", 3),
        ]
        mock_analyzer.analyze_companies.return_value = BatchAnalysisResult(
            results=per_company,
            total_companies=2,
            successful=2,
            failed=0,
            timestamp=datetime.now(),
        )

        response = client.post(
            "/analyze/batch",
            json={"companies": ["nvidia", "amd"], "max_workers": 2},
            headers=_auth_header(),
        )
        assert response.status_code == 200

        payload = response.json()
        assert payload["total_companies"] == 2
        assert payload["successful"] == 2
        assert payload["failed"] == 0
        assert len(payload["results"]) == 2

    def test_batch_analysis_validation(self, client):
        """Malformed batch requests are rejected with 422."""
        rejected_bodies = (
            {"companies": []},  # empty company list
            {"companies": [f"company{i}" for i in range(25)]},  # too many companies
            {"companies": ["nvidia"], "max_workers": 10},  # invalid max_workers
        )
        for body in rejected_bodies:
            response = client.post("/analyze/batch", json=body, headers=_auth_header())
            assert response.status_code == 422


class TestAsyncBatchEndpoint:
    """POST /analyze/batch/async."""

    @patch("SPARC.api._get_job_db")
    def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
        """Submitting an async batch creates a job recorded with an owner_id."""
        pending_job = {
            "job_id": "j1",
            "status": "pending",
            "progress": 0,
            "total_companies": 2,
            "completed_companies": 0,
            "result_json": None,
            "error": None,
            "owner_id": 1,
        }
        job_db = _stub_job_db(create_job=pending_job)
        mock_get_db.return_value = job_db

        response = client.post(
            "/analyze/batch/async",
            json={"companies": ["nvidia", "amd"]},
            headers=_auth_header(),
        )
        assert response.status_code == 200

        payload = response.json()
        assert "job_id" in payload
        assert payload["status"] == "pending"
        assert payload["total_companies"] == 2
        assert payload["progress"] == 0

        # The token's user id (1, from _auth_header) must reach create_job.
        job_db.create_job.assert_called_once()
        _, create_kwargs = job_db.create_job.call_args
        assert create_kwargs.get("owner_id") == 1


class TestJobEndpoints:
    """GET /jobs and GET /jobs/{job_id}."""

    @patch("SPARC.api._get_job_db")
    def test_get_job_not_found(self, mock_get_db, client):
        """An unknown job id yields 404."""
        mock_get_db.return_value = _stub_job_db(get_job=None)

        response = client.get("/jobs/nonexistent", headers=_auth_header())
        assert response.status_code == 404

    @patch("SPARC.api._get_job_db")
    def test_list_jobs(self, mock_get_db, client):
        """Listing jobs without a filter succeeds."""
        mock_get_db.return_value = _stub_job_db(list_jobs=[])

        response = client.get("/jobs", headers=_auth_header())
        assert response.status_code == 200

    @patch("SPARC.api._get_job_db")
    def test_list_jobs_with_filter(self, mock_get_db, client):
        """Listing jobs with a ``status`` query filter succeeds."""
        mock_get_db.return_value = _stub_job_db(list_jobs=[])

        response = client.get("/jobs?status=completed", headers=_auth_header())
        assert response.status_code == 200


class TestModelValidation:
    """Model identifiers outside the supported set are rejected."""

    def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
        """GET /analyze/{company} with an unsupported model returns 400."""
        response = client.get(
            "/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header()
        )
        assert response.status_code == 400
        assert "Unsupported model" in response.json()["detail"]

    def test_analyze_accepts_supported_model(self, client, mock_analyzer):
        """GET /analyze/{company} with a supported model returns 200."""
        mock_analyzer._analyze_company_safe.return_value = _company_result(
            "nvidia", "test", 1, model="anthropic/claude-3.5-sonnet"
        )

        response = client.get(
            "/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header()
        )
        assert response.status_code == 200

    def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
        """POST /analyze/batch with an unsupported model returns 400."""
        response = client.post(
            "/analyze/batch",
            json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
            headers=_auth_header(),
        )
        assert response.status_code == 400
        assert "Unsupported model" in response.json()["detail"]

    def test_list_models_returns_supported(self, client):
        """GET /models returns a non-empty model list plus a default."""
        response = client.get("/models")
        assert response.status_code == 200

        catalog = response.json()
        for key in ("models", "default"):
            assert key in catalog
        assert len(catalog["models"]) > 0
        assert all(
            field in entry
            for entry in catalog["models"]
            for field in ("id", "name", "provider")
        )