diff --git a/tests/test_tracked_companies.py b/tests/test_tracked_companies.py new file mode 100644 index 0000000..df25134 --- /dev/null +++ b/tests/test_tracked_companies.py @@ -0,0 +1,387 @@ +"""Tests for tracked company admin endpoints and scheduler integration. + +Covers issue #1656: +- GET /admin/tracked (list tracked companies) +- POST /admin/tracked (add a tracked company) +- DELETE /admin/tracked/{company_name} (remove a tracked company) +- GET /admin/alerts (list alerts) +- scheduler.run_scheduled_analysis() integration + +All tests mock the database layer and use JWT auth fixtures. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch, call + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.auth import create_access_token + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def mock_db(): + """Mock the database client used by admin and auth endpoints.""" + db = MagicMock() + + # Default admin user for auth + db.get_user_by_id.return_value = { + "id": 1, + "email": "admin@test.com", + "role": "admin", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +def _admin_header(): + """Create an Authorization header with a valid admin access token.""" + token = create_access_token(1, "admin@test.com", "admin") + return {"Authorization": f"Bearer {token}"} + + +def _user_header(): + """Create an Authorization header with a regular user access token.""" + token = create_access_token(2, "user@test.com", "user") + return {"Authorization": f"Bearer {token}"} + + +# ---------- GET /admin/tracked ---------- + +class TestListTrackedCompanies: + """GET /admin/tracked""" + + def test_list_tracked_returns_companies(self, client, mock_db): + """Admin can list tracked companies.""" + mock_db.list_tracked_companies.return_value = [ + {"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"}, + {"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"}, + ] + + response = client.get("/admin/tracked", headers=_admin_header()) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["company_name"] == "NVIDIA" + + def test_list_tracked_empty(self, client, mock_db): + """Returns empty list when no companies are tracked.""" + mock_db.list_tracked_companies.return_value = [] + + response = client.get("/admin/tracked", headers=_admin_header()) + + assert response.status_code == 200 + assert response.json() == [] + + def test_list_tracked_requires_admin(self, client, mock_db): + """Regular user cannot access tracked companies list.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.get("/admin/tracked", headers=_user_header()) + + assert response.status_code == 403 + + def test_list_tracked_unauthenticated(self, client): + """Unauthenticated request returns 401.""" + response = client.get("/admin/tracked") + assert response.status_code == 401 + + +# ---------- POST /admin/tracked ---------- + +class TestAddTrackedCompany: + """POST /admin/tracked""" + + def test_add_tracked_company_success(self, client, mock_db): + """Admin can add a company to tracking.""" + mock_db.add_tracked_company.return_value = { + "company_name": "Intel", + "last_patent_count": 0, + "last_analyzed": None, + } + + response = client.post( + "/admin/tracked", + json={"company_name": "Intel"}, + headers=_admin_header(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["company_name"] == "Intel" + mock_db.add_tracked_company.assert_called_once_with("Intel") + + def test_add_duplicate_returns_409(self, client, mock_db): + """Adding an already-tracked company returns 409.""" + mock_db.add_tracked_company.return_value = None + + response = client.post( + "/admin/tracked", + json={"company_name": "NVIDIA"}, + headers=_admin_header(), + ) + + assert response.status_code == 409 + assert "already tracked" in response.json()["detail"].lower() + + def test_add_tracked_requires_admin(self, client, mock_db): + """Regular user cannot add tracked companies.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.post( + "/admin/tracked", + json={"company_name": "Intel"}, + headers=_user_header(), + ) + + assert response.status_code == 403 + + def test_add_tracked_empty_name_rejected(self, client): + """Empty company name is rejected by validation.""" + response = client.post( + "/admin/tracked", + json={"company_name": ""}, + headers=_admin_header(), + ) + + assert response.status_code == 422 # Pydantic validation error + + +# ---------- DELETE /admin/tracked/{company_name} ---------- + +class TestRemoveTrackedCompany: + """DELETE /admin/tracked/{company_name}""" + + def test_remove_tracked_company_success(self, client, mock_db): + """Admin can remove a tracked company.""" + mock_db.remove_tracked_company.return_value = True + + response = client.delete( + "/admin/tracked/NVIDIA", + headers=_admin_header(), + ) + + assert response.status_code == 200 + assert "Stopped tracking" in response.json()["message"] + mock_db.remove_tracked_company.assert_called_once_with("NVIDIA") + + def test_remove_nonexistent_returns_404(self, client, mock_db): + """Removing a non-tracked company returns 404.""" + mock_db.remove_tracked_company.return_value = False + + response = client.delete( + "/admin/tracked/UnknownCorp", + headers=_admin_header(), + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_remove_tracked_requires_admin(self, client, mock_db): + """Regular user cannot remove tracked companies.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.delete( + "/admin/tracked/NVIDIA", + headers=_user_header(), + ) + + assert response.status_code == 403 + + +# ---------- GET /admin/alerts ---------- + +class TestListAlerts: + """GET /admin/alerts""" + + def test_list_alerts_returns_data(self, client, mock_db): + """Admin can list alerts.""" + mock_db.list_alerts.return_value = [ + { + "id": 1, + "company_name": "NVIDIA", + "alert_type": "patent_count_change", + "message": "Patent count increased by 25%", + "created_at": "2025-06-15T10:00:00Z", + }, + ] + + response = client.get("/admin/alerts", headers=_admin_header()) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["alert_type"] == "patent_count_change" + + def test_list_alerts_with_limit(self, client, mock_db): + """Custom limit parameter is passed to the database.""" + mock_db.list_alerts.return_value = [] + + response = client.get("/admin/alerts?limit=10", headers=_admin_header()) + + assert response.status_code == 200 + mock_db.list_alerts.assert_called_once_with(limit=10) + + def test_list_alerts_requires_admin(self, client, mock_db): + """Regular user cannot access alerts.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.get("/admin/alerts", headers=_user_header()) + + assert response.status_code == 403 + + +# ---------- Scheduler integration ---------- + +class TestSchedulerIntegration: + """Tests for scheduler.run_scheduled_analysis().""" + + def test_no_tracked_companies_skips_analysis(self): + """Scheduler does nothing when no companies are tracked.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [] + + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer") as mock_analyzer_cls: + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + mock_analyzer_cls.assert_not_called() + + def test_scheduler_analyzes_each_tracked_company(self): + """Scheduler runs analysis for every tracked company.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "NVIDIA", "last_patent_count": 100}, + {"company_name": "AMD", "last_patent_count": 50}, + ] + + mock_result_nvidia = MagicMock(success=True, patent_count=110) + mock_result_amd = MagicMock(success=True, patent_count=55) + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.side_effect = [mock_result_nvidia, mock_result_amd] + + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + assert mock_analyzer._analyze_company_safe.call_count == 2 + mock_db.update_tracked_company.assert_any_call("NVIDIA", 110) + mock_db.update_tracked_company.assert_any_call("AMD", 55) + + def test_scheduler_triggers_alert_on_significant_change(self): + """Scheduler stores an alert when patent count changes significantly.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "Tesla", "last_patent_count": 100}, + ] + + mock_result = MagicMock(success=True, patent_count=130) # 30% increase + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.return_value = mock_result + + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + mock_db.store_alert.assert_called_once() + alert_kwargs = mock_db.store_alert.call_args + assert alert_kwargs[1]["company_name"] == "Tesla" + assert alert_kwargs[1]["alert_type"] == "patent_count_change" + assert alert_kwargs[1]["old_value"] == 100 + assert alert_kwargs[1]["new_value"] == 130 + + def test_scheduler_no_alert_for_small_change(self): + """Scheduler does not alert when change is below threshold.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "Intel", "last_patent_count": 100}, + ] + + mock_result = MagicMock(success=True, patent_count=105) # 5% increase + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.return_value = mock_result + + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + mock_db.store_alert.assert_not_called() + + def test_scheduler_handles_analysis_failure(self): + """Scheduler continues when one company fails analysis.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "FailCo", "last_patent_count": 50}, + {"company_name": "SuccessCo", "last_patent_count": 30}, + ] + + mock_fail_result = MagicMock(success=False, error="API timeout") + mock_ok_result = MagicMock(success=True, patent_count=35) + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.side_effect = [mock_fail_result, mock_ok_result] + + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + # FailCo should not get updated, SuccessCo should + mock_db.update_tracked_company.assert_called_once_with("SuccessCo", 35) + + def test_scheduler_handles_exception_in_analysis(self): + """Scheduler continues even when analysis raises an exception.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "CrashCo", "last_patent_count": 10}, + {"company_name": "OKCo", "last_patent_count": 20}, + ] + + mock_ok_result = MagicMock(success=True, patent_count=22) + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.side_effect = [ + RuntimeError("unexpected error"), + mock_ok_result, + ] + + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + # OKCo should still be processed + mock_db.update_tracked_company.assert_called_once_with("OKCo", 22)