From fc942b2aa44e1009fa4a1413946302ed96bc7bad Mon Sep 17 00:00:00 2001 From: agent-company Date: Mon, 20 Apr 2026 19:14:29 +0000 Subject: [PATCH 1/2] Add tests for tracked company admin endpoints and scheduler integration 20 test cases covering: - GET/POST/DELETE /admin/tracked endpoints with admin auth enforcement - GET /admin/alerts with limit parameter and auth - scheduler.run_scheduled_analysis() for multi-company analysis, alert triggering on significant patent count changes, graceful failure handling Closes leeworks-agents/SPARC#1656 Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_tracked_companies.py | 388 ++++++++++++++++++++++++++++++++ 1 file changed, 388 insertions(+) create mode 100644 tests/test_tracked_companies.py diff --git a/tests/test_tracked_companies.py b/tests/test_tracked_companies.py new file mode 100644 index 0000000..d1e96fa --- /dev/null +++ b/tests/test_tracked_companies.py @@ -0,0 +1,388 @@ +"""Tests for tracked company admin endpoints and scheduler integration. + +Covers issue #1656: +- GET /admin/tracked (list tracked companies) +- POST /admin/tracked (add a tracked company) +- DELETE /admin/tracked/{company_name} (remove a tracked company) +- GET /admin/alerts (list alerts) +- scheduler.run_scheduled_analysis() integration + +All tests mock the database layer and use JWT auth fixtures. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch, call + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.auth import create_access_token + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def mock_db(): + """Mock the database client used by admin and auth endpoints.""" + db = MagicMock() + + # Default admin user for auth + db.get_user_by_id.return_value = { + "id": 1, + "email": "admin@test.com", + "role": "admin", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +def _admin_header(): + """Create an Authorization header with a valid admin access token.""" + token = create_access_token(1, "admin@test.com", "admin") + return {"Authorization": f"Bearer {token}"} + + +def _user_header(): + """Create an Authorization header with a regular user access token.""" + token = create_access_token(2, "user@test.com", "user") + return {"Authorization": f"Bearer {token}"} + + +# ---------- GET /admin/tracked ---------- + +class TestListTrackedCompanies: + """GET /admin/tracked""" + + def test_list_tracked_returns_companies(self, client, mock_db): + """Admin can list tracked companies.""" + mock_db.list_tracked_companies.return_value = [ + {"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"}, + {"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"}, + ] + + response = client.get("/admin/tracked", headers=_admin_header()) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["company_name"] == "NVIDIA" + + def test_list_tracked_empty(self, client, mock_db): + """Returns empty list when no companies are tracked.""" + mock_db.list_tracked_companies.return_value = [] + + response = client.get("/admin/tracked", headers=_admin_header()) + + assert response.status_code == 200 + assert response.json() == [] + + def test_list_tracked_requires_admin(self, client, mock_db): + """Regular user cannot access tracked companies list.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.get("/admin/tracked", headers=_user_header()) + + assert response.status_code == 403 + + def test_list_tracked_unauthenticated(self, client): + """Unauthenticated request returns 401.""" + response = client.get("/admin/tracked") + assert response.status_code == 401 + + +# ---------- POST /admin/tracked ---------- + +class TestAddTrackedCompany: + """POST /admin/tracked""" + + def test_add_tracked_company_success(self, client, mock_db): + """Admin can add a company to tracking.""" + mock_db.add_tracked_company.return_value = { + "company_name": "Intel", + "last_patent_count": 0, + "last_analyzed": None, + } + + response = client.post( + "/admin/tracked", + json={"company_name": "Intel"}, + headers=_admin_header(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["company_name"] == "Intel" + mock_db.add_tracked_company.assert_called_once_with("Intel") + + def test_add_duplicate_returns_409(self, client, mock_db): + """Adding an already-tracked company returns 409.""" + mock_db.add_tracked_company.return_value = None + + response = client.post( + "/admin/tracked", + json={"company_name": "NVIDIA"}, + headers=_admin_header(), + ) + + assert response.status_code == 409 + assert "already tracked" in response.json()["detail"].lower() + + def test_add_tracked_requires_admin(self, client, mock_db): + """Regular user cannot add tracked companies.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.post( + "/admin/tracked", + json={"company_name": "Intel"}, + headers=_user_header(), + ) + + assert response.status_code == 403 + + def test_add_tracked_empty_name_rejected(self, client): + """Empty company name is rejected by validation.""" + response = client.post( + "/admin/tracked", + json={"company_name": ""}, + headers=_admin_header(), + ) + + assert response.status_code == 422 # Pydantic validation error + + +# ---------- DELETE /admin/tracked/{company_name} ---------- + +class TestRemoveTrackedCompany: + """DELETE /admin/tracked/{company_name}""" + + def test_remove_tracked_company_success(self, client, mock_db): + """Admin can remove a tracked company.""" + mock_db.remove_tracked_company.return_value = True + + response = client.delete( + "/admin/tracked/NVIDIA", + headers=_admin_header(), + ) + + assert response.status_code == 200 + assert "Stopped tracking" in response.json()["message"] + mock_db.remove_tracked_company.assert_called_once_with("NVIDIA") + + def test_remove_nonexistent_returns_404(self, client, mock_db): + """Removing a non-tracked company returns 404.""" + mock_db.remove_tracked_company.return_value = False + + response = client.delete( + "/admin/tracked/UnknownCorp", + headers=_admin_header(), + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_remove_tracked_requires_admin(self, client, mock_db): + """Regular user cannot remove tracked companies.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.delete( + "/admin/tracked/NVIDIA", + headers=_user_header(), + ) + + assert response.status_code == 403 + + +# ---------- GET /admin/alerts ---------- + +class TestListAlerts: + """GET /admin/alerts""" + + def test_list_alerts_returns_data(self, client, mock_db): + """Admin can list alerts.""" + mock_db.list_alerts.return_value = [ + { + "id": 1, + "company_name": "NVIDIA", + "alert_type": "patent_count_change", + "message": "Patent count increased by 25%", + "created_at": "2025-06-15T10:00:00Z", + }, + ] + + response = client.get("/admin/alerts", headers=_admin_header()) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["alert_type"] == "patent_count_change" + + def test_list_alerts_with_limit(self, client, mock_db): + """Custom limit parameter is passed to the database.""" + mock_db.list_alerts.return_value = [] + + response = client.get("/admin/alerts?limit=10", headers=_admin_header()) + + assert response.status_code == 200 + mock_db.list_alerts.assert_called_once_with(limit=10) + + def test_list_alerts_requires_admin(self, client, mock_db): + """Regular user cannot access alerts.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + response = client.get("/admin/alerts", headers=_user_header()) + + assert response.status_code == 403 + + +# ---------- Scheduler integration ---------- + +class TestSchedulerIntegration: + """Tests for scheduler.run_scheduled_analysis().""" + + def test_no_tracked_companies_skips_analysis(self): + """Scheduler does nothing when no companies are tracked.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [] + + with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer") as mock_analyzer_cls: + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + mock_analyzer_cls.assert_not_called() + + def test_scheduler_analyzes_each_tracked_company(self): + """Scheduler runs analysis for every tracked company.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "NVIDIA", "last_patent_count": 100}, + {"company_name": "AMD", "last_patent_count": 50}, + ] + + mock_result_nvidia = MagicMock(success=True, patent_count=110) + mock_result_amd = MagicMock(success=True, patent_count=55) + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.side_effect = [mock_result_nvidia, mock_result_amd] + + with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + assert mock_analyzer._analyze_company_safe.call_count == 2 + mock_db.update_tracked_company.assert_any_call("NVIDIA", 110) + mock_db.update_tracked_company.assert_any_call("AMD", 55) + + def test_scheduler_triggers_alert_on_significant_change(self): + """Scheduler stores an alert when patent count changes significantly.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "Tesla", "last_patent_count": 100}, + ] + + mock_result = MagicMock(success=True, patent_count=130) # 30% increase + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.return_value = mock_result + + with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + mock_db.store_alert.assert_called_once() + alert_kwargs = mock_db.store_alert.call_args + assert alert_kwargs[1]["company_name"] == "Tesla" + assert alert_kwargs[1]["alert_type"] == "patent_count_change" + assert alert_kwargs[1]["old_value"] == 100 + assert alert_kwargs[1]["new_value"] == 130 + + def test_scheduler_no_alert_for_small_change(self): + """Scheduler does not alert when change is below threshold.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "Intel", "last_patent_count": 100}, + ] + + mock_result = MagicMock(success=True, patent_count=105) # 5% increase + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.return_value = mock_result + + with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + mock_db.store_alert.assert_not_called() + + def test_scheduler_handles_analysis_failure(self): + """Scheduler continues when one company fails analysis.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "FailCo", "last_patent_count": 50}, + {"company_name": "SuccessCo", "last_patent_count": 30}, + ] + + mock_fail_result = MagicMock(success=False, error="API timeout") + mock_ok_result = MagicMock(success=True, patent_count=35) + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.side_effect = [mock_fail_result, mock_ok_result] + + with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + # FailCo should not get updated, SuccessCo should + mock_db.update_tracked_company.assert_called_once_with("SuccessCo", 35) + + def test_scheduler_handles_exception_in_analysis(self): + """Scheduler continues even when analysis raises an exception.""" + mock_db = MagicMock() + mock_db.list_tracked_companies.return_value = [ + {"company_name": "CrashCo", "last_patent_count": 10}, + {"company_name": "OKCo", "last_patent_count": 20}, + ] + + mock_ok_result = MagicMock(success=True, patent_count=22) + mock_analyzer = MagicMock() + mock_analyzer._analyze_company_safe.side_effect = [ + RuntimeError("unexpected error"), + mock_ok_result, + ] + + with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): + from SPARC.scheduler import run_scheduled_analysis + run_scheduled_analysis() + + # OKCo should still be processed + mock_db.update_tracked_company.assert_called_once_with("OKCo", 22) + mock_db.close.assert_called_once() -- 2.52.0 From 6165d66760cc69d16da25ca372c6de77f8c1c589 Mon Sep 17 00:00:00 2001 From: agent-company Date: Mon, 20 Apr 2026 23:05:42 +0000 Subject: [PATCH 2/2] Fix scheduler tests to use get_db_client after scheduler refactor The scheduler was refactored (PR #1665) to use the pooled get_db_client() from SPARC.auth instead of creating its own DatabaseClient. Update test mocks accordingly and remove the db.close() assertion since the pooled client is no longer closed by the scheduler. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_tracked_companies.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_tracked_companies.py b/tests/test_tracked_companies.py index d1e96fa..df25134 100644 --- a/tests/test_tracked_companies.py +++ b/tests/test_tracked_companies.py @@ -272,7 +272,7 @@ class TestSchedulerIntegration: mock_db = MagicMock() mock_db.list_tracked_companies.return_value = [] - with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ patch("SPARC.scheduler.CompanyAnalyzer") as mock_analyzer_cls: from SPARC.scheduler import run_scheduled_analysis run_scheduled_analysis() @@ -292,7 +292,7 @@ class TestSchedulerIntegration: mock_analyzer = MagicMock() mock_analyzer._analyze_company_safe.side_effect = [mock_result_nvidia, mock_result_amd] - with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): from SPARC.scheduler import run_scheduled_analysis run_scheduled_analysis() @@ -312,7 +312,7 @@ class TestSchedulerIntegration: mock_analyzer = MagicMock() mock_analyzer._analyze_company_safe.return_value = mock_result - with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): from SPARC.scheduler import run_scheduled_analysis run_scheduled_analysis() @@ -335,7 +335,7 @@ class TestSchedulerIntegration: mock_analyzer = MagicMock() mock_analyzer._analyze_company_safe.return_value = mock_result - with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): from SPARC.scheduler import run_scheduled_analysis run_scheduled_analysis() @@ -355,7 +355,7 @@ class TestSchedulerIntegration: mock_analyzer = MagicMock() mock_analyzer._analyze_company_safe.side_effect = [mock_fail_result, mock_ok_result] - with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): from SPARC.scheduler import run_scheduled_analysis run_scheduled_analysis() @@ -378,11 +378,10 @@ class TestSchedulerIntegration: mock_ok_result, ] - with patch("SPARC.scheduler.DatabaseClient", return_value=mock_db), \ + with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \ patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer): from SPARC.scheduler import run_scheduled_analysis run_scheduled_analysis() # OKCo should still be processed mock_db.update_tracked_company.assert_called_once_with("OKCo", 22) - mock_db.close.assert_called_once() -- 2.52.0