"""Tests for tracked company endpoints and scheduler integration.

Covers:
- GET /tracked (user-scoped list)
- POST /tracked (user-scoped add)
- DELETE /tracked/{company_name} (user-scoped remove)
- GET /admin/tracked (admin: all companies)
- POST /admin/tracked (admin: add)
- DELETE /admin/tracked/{company_name} (admin: remove any)
- GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration
"""

from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from SPARC.api import app
from SPARC.auth import create_access_token

# Creation timestamp shared by every stubbed user record. None of these tests
# assert on it; it only has to be a timezone-aware datetime.
_CREATED_AT = datetime(2025, 1, 1, tzinfo=timezone.utc)


def _admin_user():
    """Return a fresh stored-user record for the admin (id 1)."""
    return {
        "id": 1,
        "email": "admin@test.com",
        "role": "admin",
        "created_at": _CREATED_AT,
    }


def _regular_user():
    """Return a fresh stored-user record for the non-admin user (id 2)."""
    return {
        "id": 2,
        "email": "user@test.com",
        "role": "user",
        "created_at": _CREATED_AT,
    }


def _as_regular_user(db):
    """Make the mocked ``get_user_by_id`` return the regular user.

    The autouse ``mock_db`` fixture defaults this lookup to the admin user, so
    any test sending ``_user_header()`` must call this first.
    """
    db.get_user_by_id.return_value = _regular_user()


@pytest.fixture
def client():
    """Create test client."""
    return TestClient(app)


@pytest.fixture(autouse=True)
def mock_db():
    """Mock the database client used by admin and auth endpoints.

    ``get_user_by_id`` returns the admin user by default; tests exercising a
    regular user override it with ``_as_regular_user``.
    """
    db = MagicMock()
    db.get_user_by_id.return_value = _admin_user()
    with patch("SPARC.api.get_db_client", return_value=db), \
            patch("SPARC.auth.get_db_client", return_value=db):
        yield db


def _bearer(user_id, email, role):
    """Build an Authorization header carrying a freshly issued access token."""
    token = create_access_token(user_id, email, role)
    return {"Authorization": f"Bearer {token}"}


def _admin_header():
    """Create an Authorization header with a valid admin access token."""
    return _bearer(1, "admin@test.com", "admin")


def _user_header():
    """Create an Authorization header with a regular user access token."""
    return _bearer(2, "user@test.com", "user")


# ---------- GET /admin/tracked ----------


class TestListTrackedCompanies:
    """GET /admin/tracked"""

    def test_list_tracked_returns_companies(self, client, mock_db):
        """Admin can list tracked companies."""
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"},
            {"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"},
        ]

        response = client.get("/admin/tracked", headers=_admin_header())

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 2
        assert data[0]["company_name"] == "NVIDIA"

    def test_list_tracked_empty(self, client, mock_db):
        """Returns empty list when no companies are tracked."""
        mock_db.list_tracked_companies.return_value = []

        response = client.get("/admin/tracked", headers=_admin_header())

        assert response.status_code == 200
        assert response.json() == []

    def test_list_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot access tracked companies list."""
        _as_regular_user(mock_db)

        response = client.get("/admin/tracked", headers=_user_header())

        assert response.status_code == 403

    def test_list_tracked_unauthenticated(self, client):
        """Unauthenticated request returns 401."""
        response = client.get("/admin/tracked")

        assert response.status_code == 401


# ---------- POST /admin/tracked ----------


class TestAddTrackedCompany:
    """POST /admin/tracked"""

    def test_add_tracked_company_success(self, client, mock_db):
        """Admin can add a company to tracking."""
        mock_db.add_tracked_company.return_value = {
            "company_name": "Intel",
            "last_patent_count": 0,
            "last_analyzed": None,
        }

        response = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_admin_header(),
        )

        assert response.status_code == 200
        data = response.json()
        assert data["company_name"] == "Intel"
        mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1)

    def test_add_duplicate_returns_409(self, client, mock_db):
        """Adding an already-tracked company returns 409."""
        # A None return from the DB layer signals "already tracked".
        mock_db.add_tracked_company.return_value = None

        response = client.post(
            "/admin/tracked",
            json={"company_name": "NVIDIA"},
            headers=_admin_header(),
        )

        assert response.status_code == 409
        assert "already tracked" in response.json()["detail"].lower()

    def test_add_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot add tracked companies via admin endpoint."""
        _as_regular_user(mock_db)

        response = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_user_header(),
        )

        assert response.status_code == 403

    def test_add_tracked_empty_name_rejected(self, client):
        """Empty company name is rejected by validation."""
        response = client.post(
            "/admin/tracked",
            json={"company_name": ""},
            headers=_admin_header(),
        )

        assert response.status_code == 422  # Pydantic validation error


# ---------- DELETE /admin/tracked/{company_name} ----------


class TestRemoveTrackedCompany:
    """DELETE /admin/tracked/{company_name}"""

    def test_remove_tracked_company_success(self, client, mock_db):
        """Admin can remove a tracked company."""
        mock_db.remove_tracked_company.return_value = True

        response = client.delete("/admin/tracked/NVIDIA", headers=_admin_header())

        assert response.status_code == 200
        assert "Stopped tracking" in response.json()["message"]
        # The admin endpoint removes regardless of owner, so no owner_id is passed.
        mock_db.remove_tracked_company.assert_called_once_with("NVIDIA")

    def test_remove_nonexistent_returns_404(self, client, mock_db):
        """Removing a non-tracked company returns 404."""
        mock_db.remove_tracked_company.return_value = False

        response = client.delete("/admin/tracked/UnknownCorp", headers=_admin_header())

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_remove_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot remove tracked companies."""
        _as_regular_user(mock_db)

        response = client.delete("/admin/tracked/NVIDIA", headers=_user_header())

        assert response.status_code == 403


# ---------- User-scoped tracked companies ----------


class TestUserScopedTrackedCompanies:
    """Tests for /tracked user-scoped endpoints."""

    def test_user_list_tracked(self, client, mock_db):
        """Regular user can list their own tracked companies."""
        _as_regular_user(mock_db)
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "AMD", "owner_id": 2},
        ]

        response = client.get("/tracked", headers=_user_header())

        assert response.status_code == 200
        mock_db.list_tracked_companies.assert_called_with(owner_id=2)

    def test_user_add_tracked(self, client, mock_db):
        """Regular user can add a company to their own tracked list."""
        _as_regular_user(mock_db)
        mock_db.add_tracked_company.return_value = {
            "company_name": "Intel",
            "owner_id": 2,
        }

        response = client.post(
            "/tracked",
            json={"company_name": "Intel"},
            headers=_user_header(),
        )

        assert response.status_code == 200
        mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2)

    def test_user_remove_tracked(self, client, mock_db):
        """Regular user can remove a company from their own tracked list."""
        _as_regular_user(mock_db)
        mock_db.remove_tracked_company.return_value = True

        response = client.delete("/tracked/Intel", headers=_user_header())

        assert response.status_code == 200
        mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2)


# ---------- GET /admin/alerts ----------


class TestListAlerts:
    """GET /admin/alerts"""

    def test_list_alerts_returns_data(self, client, mock_db):
        """Admin can list alerts."""
        mock_db.list_alerts.return_value = [
            {
                "id": 1,
                "company_name": "NVIDIA",
                "alert_type": "patent_count_change",
                "message": "Patent count increased by 25%",
                "created_at": "2025-06-15T10:00:00Z",
            },
        ]

        response = client.get("/admin/alerts", headers=_admin_header())

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        assert data[0]["alert_type"] == "patent_count_change"

    def test_list_alerts_with_limit(self, client, mock_db):
        """Custom limit parameter is passed to the database."""
        mock_db.list_alerts.return_value = []

        response = client.get("/admin/alerts?limit=10", headers=_admin_header())

        assert response.status_code == 200
        mock_db.list_alerts.assert_called_once_with(limit=10)

    def test_list_alerts_requires_admin(self, client, mock_db):
        """Regular user cannot access alerts."""
        _as_regular_user(mock_db)

        response = client.get("/admin/alerts", headers=_user_header())

        assert response.status_code == 403


# ---------- Scheduler integration ----------


def _tracking_db(*companies):
    """Return a mock DB client whose ``list_tracked_companies`` yields *companies*."""
    db = MagicMock()
    db.list_tracked_companies.return_value = list(companies)
    return db


def _run_scheduler(db, analyzer):
    """Run ``run_scheduled_analysis`` with its DB client and analyzer patched.

    ``SPARC.scheduler.get_db_client`` returns *db* and instantiating
    ``SPARC.scheduler.CompanyAnalyzer`` returns *analyzer*.

    Returns:
        The mock that replaced the ``CompanyAnalyzer`` class, so callers can
        assert whether (and how) it was instantiated.
    """
    with patch("SPARC.scheduler.get_db_client", return_value=db), \
            patch("SPARC.scheduler.CompanyAnalyzer", return_value=analyzer) as analyzer_cls:
        from SPARC.scheduler import run_scheduled_analysis

        run_scheduled_analysis()
    return analyzer_cls


class TestSchedulerIntegration:
    """Tests for scheduler.run_scheduled_analysis()."""

    def test_no_tracked_companies_skips_analysis(self):
        """Scheduler does nothing when no companies are tracked."""
        db = _tracking_db()

        analyzer_cls = _run_scheduler(db, MagicMock())

        analyzer_cls.assert_not_called()

    def test_scheduler_analyzes_each_tracked_company(self):
        """Scheduler runs analysis for every tracked company."""
        db = _tracking_db(
            {"company_name": "NVIDIA", "last_patent_count": 100},
            {"company_name": "AMD", "last_patent_count": 50},
        )
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=True, patent_count=110),
            MagicMock(success=True, patent_count=55),
        ]

        _run_scheduler(db, analyzer)

        assert analyzer._analyze_company_safe.call_count == 2
        db.update_tracked_company.assert_any_call("NVIDIA", 110)
        db.update_tracked_company.assert_any_call("AMD", 55)

    def test_scheduler_triggers_alert_on_significant_change(self):
        """Scheduler stores an alert when patent count changes significantly."""
        db = _tracking_db({"company_name": "Tesla", "last_patent_count": 100})
        analyzer = MagicMock()
        analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=130,  # 30% increase
        )

        _run_scheduler(db, analyzer)

        db.store_alert.assert_called_once()
        alert = db.store_alert.call_args.kwargs
        assert alert["company_name"] == "Tesla"
        assert alert["alert_type"] == "patent_count_change"
        assert alert["old_value"] == 100
        assert alert["new_value"] == 130

    def test_scheduler_no_alert_for_small_change(self):
        """Scheduler does not alert when change is below threshold."""
        db = _tracking_db({"company_name": "Intel", "last_patent_count": 100})
        analyzer = MagicMock()
        analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=105,  # 5% increase
        )

        _run_scheduler(db, analyzer)

        db.store_alert.assert_not_called()

    def test_scheduler_handles_analysis_failure(self):
        """Scheduler continues when one company fails analysis."""
        db = _tracking_db(
            {"company_name": "FailCo", "last_patent_count": 50},
            {"company_name": "SuccessCo", "last_patent_count": 30},
        )
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=False, error="API timeout"),
            MagicMock(success=True, patent_count=35),
        ]

        _run_scheduler(db, analyzer)

        # FailCo should not get updated, SuccessCo should
        db.update_tracked_company.assert_called_once_with("SuccessCo", 35)

    def test_scheduler_handles_exception_in_analysis(self):
        """Scheduler continues even when analysis raises an exception."""
        db = _tracking_db(
            {"company_name": "CrashCo", "last_patent_count": 10},
            {"company_name": "OKCo", "last_patent_count": 20},
        )
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            RuntimeError("unexpected error"),
            MagicMock(success=True, patent_count=22),
        ]

        _run_scheduler(db, analyzer)

        # OKCo should still be processed
        db.update_tracked_company.assert_called_once_with("OKCo", 22)