"""Tests for tracked company admin endpoints and scheduler integration.

Covers issue #1656:
- GET /admin/tracked (list tracked companies)
- POST /admin/tracked (add a tracked company)
- DELETE /admin/tracked/{company_name} (remove a tracked company)
- GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration

All tests mock the database layer and use JWT auth fixtures.
"""

from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call

import pytest
from fastapi.testclient import TestClient

from SPARC.api import app
from SPARC.auth import create_access_token


# Fixed creation timestamp shared by every mocked user record.
_CREATED_AT = datetime(2025, 1, 1, tzinfo=timezone.utc)


def _user_record(user_id, email, role):
    """Build a fresh user dict of the shape returned by the mocked ``get_user_by_id``."""
    return {
        "id": user_id,
        "email": email,
        "role": role,
        "created_at": _CREATED_AT,
    }


@pytest.fixture
def client():
    """Create test client."""
    return TestClient(app)


@pytest.fixture(autouse=True)
def mock_db():
    """Mock the database client used by admin and auth endpoints.

    By default ``get_user_by_id`` resolves to an admin user (id 1).
    Tests that need a non-admin caller switch it with ``_as_regular_user``.
    """
    db = MagicMock()
    db.get_user_by_id.return_value = _user_record(1, "admin@test.com", "admin")
    with patch("SPARC.api.get_db_client", return_value=db), \
            patch("SPARC.auth.get_db_client", return_value=db):
        yield db


def _as_regular_user(db):
    """Make the mocked auth lookup on *db* resolve to a non-admin user (id 2)."""
    db.get_user_by_id.return_value = _user_record(2, "user@test.com", "user")


def _admin_header():
    """Create an Authorization header with a valid admin access token."""
    token = create_access_token(1, "admin@test.com", "admin")
    return {"Authorization": f"Bearer {token}"}


def _user_header():
    """Create an Authorization header with a regular user access token."""
    token = create_access_token(2, "user@test.com", "user")
    return {"Authorization": f"Bearer {token}"}


# ---------- GET /admin/tracked ----------


class TestListTrackedCompanies:
    """GET /admin/tracked"""

    def test_list_tracked_returns_companies(self, client, mock_db):
        """Admin can list tracked companies."""
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"},
            {"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"},
        ]

        response = client.get("/admin/tracked", headers=_admin_header())

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 2
        assert data[0]["company_name"] == "NVIDIA"

    def test_list_tracked_empty(self, client, mock_db):
        """Returns empty list when no companies are tracked."""
        mock_db.list_tracked_companies.return_value = []

        response = client.get("/admin/tracked", headers=_admin_header())

        assert response.status_code == 200
        assert response.json() == []

    def test_list_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot access tracked companies list."""
        _as_regular_user(mock_db)

        response = client.get("/admin/tracked", headers=_user_header())

        assert response.status_code == 403
        # The request must be rejected before the handler touches the DB.
        mock_db.list_tracked_companies.assert_not_called()

    def test_list_tracked_unauthenticated(self, client, mock_db):
        """Unauthenticated request returns 401."""
        response = client.get("/admin/tracked")

        assert response.status_code == 401
        mock_db.list_tracked_companies.assert_not_called()


# ---------- POST /admin/tracked ----------


class TestAddTrackedCompany:
    """POST /admin/tracked"""

    def test_add_tracked_company_success(self, client, mock_db):
        """Admin can add a company to tracking."""
        mock_db.add_tracked_company.return_value = {
            "company_name": "Intel",
            "last_patent_count": 0,
            "last_analyzed": None,
        }

        response = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_admin_header(),
        )

        assert response.status_code == 200
        data = response.json()
        assert data["company_name"] == "Intel"
        mock_db.add_tracked_company.assert_called_once_with("Intel")

    def test_add_duplicate_returns_409(self, client, mock_db):
        """Adding an already-tracked company returns 409."""
        # A None return from the DB layer is how this test simulates a duplicate.
        mock_db.add_tracked_company.return_value = None

        response = client.post(
            "/admin/tracked",
            json={"company_name": "NVIDIA"},
            headers=_admin_header(),
        )

        assert response.status_code == 409
        assert "already tracked" in response.json()["detail"].lower()

    def test_add_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot add tracked companies."""
        _as_regular_user(mock_db)

        response = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_user_header(),
        )

        assert response.status_code == 403
        # A forbidden request must not reach the insert.
        mock_db.add_tracked_company.assert_not_called()

    def test_add_tracked_empty_name_rejected(self, client, mock_db):
        """Empty company name is rejected by validation."""
        response = client.post(
            "/admin/tracked",
            json={"company_name": ""},
            headers=_admin_header(),
        )

        assert response.status_code == 422  # Pydantic validation error
        mock_db.add_tracked_company.assert_not_called()


# ---------- DELETE /admin/tracked/{company_name} ----------


class TestRemoveTrackedCompany:
    """DELETE /admin/tracked/{company_name}"""

    def test_remove_tracked_company_success(self, client, mock_db):
        """Admin can remove a tracked company."""
        mock_db.remove_tracked_company.return_value = True

        response = client.delete(
            "/admin/tracked/NVIDIA",
            headers=_admin_header(),
        )

        assert response.status_code == 200
        assert "Stopped tracking" in response.json()["message"]
        mock_db.remove_tracked_company.assert_called_once_with("NVIDIA")

    def test_remove_nonexistent_returns_404(self, client, mock_db):
        """Removing a non-tracked company returns 404."""
        mock_db.remove_tracked_company.return_value = False

        response = client.delete(
            "/admin/tracked/UnknownCorp",
            headers=_admin_header(),
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_remove_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot remove tracked companies."""
        _as_regular_user(mock_db)

        response = client.delete(
            "/admin/tracked/NVIDIA",
            headers=_user_header(),
        )

        assert response.status_code == 403
        # A forbidden request must not reach the delete.
        mock_db.remove_tracked_company.assert_not_called()


# ---------- GET /admin/alerts ----------


class TestListAlerts:
    """GET /admin/alerts"""

    def test_list_alerts_returns_data(self, client, mock_db):
        """Admin can list alerts."""
        mock_db.list_alerts.return_value = [
            {
                "id": 1,
                "company_name": "NVIDIA",
                "alert_type": "patent_count_change",
                "message": "Patent count increased by 25%",
                "created_at": "2025-06-15T10:00:00Z",
            },
        ]

        response = client.get("/admin/alerts", headers=_admin_header())

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        assert data[0]["alert_type"] == "patent_count_change"

    def test_list_alerts_with_limit(self, client, mock_db):
        """Custom limit parameter is passed to the database."""
        mock_db.list_alerts.return_value = []

        response = client.get("/admin/alerts?limit=10", headers=_admin_header())

        assert response.status_code == 200
        mock_db.list_alerts.assert_called_once_with(limit=10)

    def test_list_alerts_requires_admin(self, client, mock_db):
        """Regular user cannot access alerts."""
        _as_regular_user(mock_db)

        response = client.get("/admin/alerts", headers=_user_header())

        assert response.status_code == 403
        mock_db.list_alerts.assert_not_called()


# ---------- Scheduler integration ----------


def _scheduler_db(tracked):
    """Return a mock DatabaseClient whose ``list_tracked_companies`` yields *tracked*."""
    db = MagicMock()
    db.list_tracked_companies.return_value = tracked
    return db


def _run_scheduler(db, analyzer=None):
    """Run ``SPARC.scheduler.run_scheduled_analysis()`` with its collaborators mocked.

    ``SPARC.scheduler.DatabaseClient`` is patched to return *db*.
    ``SPARC.scheduler.CompanyAnalyzer`` is patched with a class mock; when
    *analyzer* is given, constructing the class returns it.

    Returns the CompanyAnalyzer class mock so callers can assert on
    whether or how it was constructed.
    """
    with patch("SPARC.scheduler.DatabaseClient", return_value=db), \
            patch("SPARC.scheduler.CompanyAnalyzer") as analyzer_cls:
        if analyzer is not None:
            analyzer_cls.return_value = analyzer
        from SPARC.scheduler import run_scheduled_analysis
        run_scheduled_analysis()
    return analyzer_cls


class TestSchedulerIntegration:
    """Tests for scheduler.run_scheduled_analysis()."""

    def test_no_tracked_companies_skips_analysis(self):
        """Scheduler does nothing when no companies are tracked."""
        db = _scheduler_db([])

        analyzer_cls = _run_scheduler(db)

        analyzer_cls.assert_not_called()

    def test_scheduler_analyzes_each_tracked_company(self):
        """Scheduler runs analysis for every tracked company."""
        db = _scheduler_db([
            {"company_name": "NVIDIA", "last_patent_count": 100},
            {"company_name": "AMD", "last_patent_count": 50},
        ])
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=True, patent_count=110),
            MagicMock(success=True, patent_count=55),
        ]

        _run_scheduler(db, analyzer)

        assert analyzer._analyze_company_safe.call_count == 2
        # Exactly one update per company, with no stray extra writes.
        db.update_tracked_company.assert_has_calls(
            [call("NVIDIA", 110), call("AMD", 55)], any_order=True
        )
        assert db.update_tracked_company.call_count == 2

    def test_scheduler_triggers_alert_on_significant_change(self):
        """Scheduler stores an alert when patent count changes significantly."""
        db = _scheduler_db([{"company_name": "Tesla", "last_patent_count": 100}])
        analyzer = MagicMock()
        analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=130  # 30% increase
        )

        _run_scheduler(db, analyzer)

        db.store_alert.assert_called_once()
        alert_kwargs = db.store_alert.call_args.kwargs
        assert alert_kwargs["company_name"] == "Tesla"
        assert alert_kwargs["alert_type"] == "patent_count_change"
        assert alert_kwargs["old_value"] == 100
        assert alert_kwargs["new_value"] == 130

    def test_scheduler_no_alert_for_small_change(self):
        """Scheduler does not alert when change is below threshold."""
        db = _scheduler_db([{"company_name": "Intel", "last_patent_count": 100}])
        analyzer = MagicMock()
        analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=105  # 5% increase
        )

        _run_scheduler(db, analyzer)

        db.store_alert.assert_not_called()

    def test_scheduler_handles_analysis_failure(self):
        """Scheduler continues when one company fails analysis."""
        db = _scheduler_db([
            {"company_name": "FailCo", "last_patent_count": 50},
            {"company_name": "SuccessCo", "last_patent_count": 30},
        ])
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=False, error="API timeout"),
            MagicMock(success=True, patent_count=35),
        ]

        _run_scheduler(db, analyzer)

        # FailCo should not get updated, SuccessCo should.
        db.update_tracked_company.assert_called_once_with("SuccessCo", 35)

    def test_scheduler_handles_exception_in_analysis(self):
        """Scheduler continues even when analysis raises an exception."""
        db = _scheduler_db([
            {"company_name": "CrashCo", "last_patent_count": 10},
            {"company_name": "OKCo", "last_patent_count": 20},
        ])
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            RuntimeError("unexpected error"),
            MagicMock(success=True, patent_count=22),
        ]

        _run_scheduler(db, analyzer)

        # OKCo should still be processed, and the DB connection released.
        db.update_tracked_company.assert_called_once_with("OKCo", 22)
        db.close.assert_called_once()