6165d66760
The scheduler was refactored (PR #1665) to use the pooled get_db_client() from SPARC.auth instead of creating its own DatabaseClient. Update test mocks accordingly and remove the db.close() assertion since the pooled client is no longer closed by the scheduler. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
388 lines
14 KiB
Python
388 lines
14 KiB
Python
"""Tests for tracked company admin endpoints and scheduler integration.

Covers issue #1656:
- GET /admin/tracked (list tracked companies)
- POST /admin/tracked (add a tracked company)
- DELETE /admin/tracked/{company_name} (remove a tracked company)
- GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration

All tests mock the database layer. Endpoint tests authenticate with JWT
access tokens built by the _admin_header() / _user_header() helpers.
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch, call
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from SPARC.api import app
|
|
from SPARC.auth import create_access_token
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a FastAPI TestClient wrapping the SPARC application."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_db():
    """Replace get_db_client in SPARC.api and SPARC.auth with one shared MagicMock.

    Applied automatically to every test in this module. The mock's
    ``get_user_by_id`` returns an admin user (id 1) by default, so requests
    made with ``_admin_header()`` authenticate as an admin. Tests that need
    a regular user override that return value.

    Yields:
        The MagicMock standing in for the database client.
    """
    fake_db = MagicMock()

    admin_user = {
        "id": 1,
        "email": "admin@test.com",
        "role": "admin",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    fake_db.get_user_by_id.return_value = admin_user

    # Patch the name in both modules so either lookup path returns the mock.
    with patch("SPARC.api.get_db_client", return_value=fake_db):
        with patch("SPARC.auth.get_db_client", return_value=fake_db):
            yield fake_db
|
|
|
|
|
|
def _admin_header():
    """Return request headers authenticating as the admin user (id 1, role "admin")."""
    bearer = create_access_token(1, "admin@test.com", "admin")
    headers = {"Authorization": f"Bearer {bearer}"}
    return headers
|
|
|
|
|
|
def _user_header():
    """Return request headers authenticating as a regular user (id 2, role "user")."""
    bearer = create_access_token(2, "user@test.com", "user")
    headers = {"Authorization": f"Bearer {bearer}"}
    return headers
|
|
|
|
|
|
# ---------- GET /admin/tracked ----------
|
|
|
|
class TestListTrackedCompanies:
    """GET /admin/tracked"""

    def test_list_tracked_returns_companies(self, client, mock_db):
        """Admin can list tracked companies."""
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"},
            {"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"},
        ]

        response = client.get("/admin/tracked", headers=_admin_header())

        assert response.status_code == 200
        data = response.json()
        # Both rows are returned, in the order the database produced them.
        assert [row["company_name"] for row in data] == ["NVIDIA", "AMD"]
        mock_db.list_tracked_companies.assert_called_once()

    def test_list_tracked_empty(self, client, mock_db):
        """Returns empty list when no companies are tracked."""
        mock_db.list_tracked_companies.return_value = []

        response = client.get("/admin/tracked", headers=_admin_header())

        assert response.status_code == 200
        assert response.json() == []

    def test_list_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot access tracked companies list."""
        mock_db.get_user_by_id.return_value = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }

        response = client.get("/admin/tracked", headers=_user_header())

        assert response.status_code == 403
        # The authorization check must reject the request before any data is read.
        mock_db.list_tracked_companies.assert_not_called()

    def test_list_tracked_unauthenticated(self, client, mock_db):
        """Unauthenticated request returns 401 without touching the database."""
        response = client.get("/admin/tracked")

        assert response.status_code == 401
        mock_db.list_tracked_companies.assert_not_called()
|
|
|
|
|
|
# ---------- POST /admin/tracked ----------
|
|
|
|
class TestAddTrackedCompany:
    """POST /admin/tracked"""

    def test_add_tracked_company_success(self, client, mock_db):
        """Admin can add a company to tracking."""
        mock_db.add_tracked_company.return_value = {
            "company_name": "Intel",
            "last_patent_count": 0,
            "last_analyzed": None,
        }

        response = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_admin_header(),
        )

        assert response.status_code == 200
        data = response.json()
        assert data["company_name"] == "Intel"
        mock_db.add_tracked_company.assert_called_once_with("Intel")

    def test_add_duplicate_returns_409(self, client, mock_db):
        """Adding an already-tracked company returns 409."""
        # A None return from the DB layer signals the company is already tracked.
        mock_db.add_tracked_company.return_value = None

        response = client.post(
            "/admin/tracked",
            json={"company_name": "NVIDIA"},
            headers=_admin_header(),
        )

        assert response.status_code == 409
        assert "already tracked" in response.json()["detail"].lower()
        mock_db.add_tracked_company.assert_called_once_with("NVIDIA")

    def test_add_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot add tracked companies."""
        mock_db.get_user_by_id.return_value = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }

        response = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_user_header(),
        )

        assert response.status_code == 403
        # A rejected request must not have written anything.
        mock_db.add_tracked_company.assert_not_called()

    def test_add_tracked_empty_name_rejected(self, client, mock_db):
        """Empty company name is rejected by validation."""
        response = client.post(
            "/admin/tracked",
            json={"company_name": ""},
            headers=_admin_header(),
        )

        assert response.status_code == 422  # Pydantic validation error
        mock_db.add_tracked_company.assert_not_called()
|
|
|
|
|
|
# ---------- DELETE /admin/tracked/{company_name} ----------
|
|
|
|
class TestRemoveTrackedCompany:
    """DELETE /admin/tracked/{company_name}"""

    def test_remove_tracked_company_success(self, client, mock_db):
        """Admin can remove a tracked company."""
        mock_db.remove_tracked_company.return_value = True

        response = client.delete(
            "/admin/tracked/NVIDIA",
            headers=_admin_header(),
        )

        assert response.status_code == 200
        assert "Stopped tracking" in response.json()["message"]
        mock_db.remove_tracked_company.assert_called_once_with("NVIDIA")

    def test_remove_nonexistent_returns_404(self, client, mock_db):
        """Removing a non-tracked company returns 404."""
        # A False return from the DB layer signals nothing was removed.
        mock_db.remove_tracked_company.return_value = False

        response = client.delete(
            "/admin/tracked/UnknownCorp",
            headers=_admin_header(),
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()
        mock_db.remove_tracked_company.assert_called_once_with("UnknownCorp")

    def test_remove_tracked_requires_admin(self, client, mock_db):
        """Regular user cannot remove tracked companies."""
        mock_db.get_user_by_id.return_value = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }

        response = client.delete(
            "/admin/tracked/NVIDIA",
            headers=_user_header(),
        )

        assert response.status_code == 403
        # A rejected request must not have deleted anything.
        mock_db.remove_tracked_company.assert_not_called()
|
|
|
|
|
|
# ---------- GET /admin/alerts ----------
|
|
|
|
class TestListAlerts:
    """GET /admin/alerts"""

    def test_list_alerts_returns_data(self, client, mock_db):
        """Admin can list alerts."""
        mock_db.list_alerts.return_value = [
            {
                "id": 1,
                "company_name": "NVIDIA",
                "alert_type": "patent_count_change",
                "message": "Patent count increased by 25%",
                "created_at": "2025-06-15T10:00:00Z",
            },
        ]

        response = client.get("/admin/alerts", headers=_admin_header())

        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        assert data[0]["alert_type"] == "patent_count_change"
        mock_db.list_alerts.assert_called_once()

    def test_list_alerts_with_limit(self, client, mock_db):
        """Custom limit parameter is passed to the database."""
        mock_db.list_alerts.return_value = []

        response = client.get("/admin/alerts?limit=10", headers=_admin_header())

        assert response.status_code == 200
        mock_db.list_alerts.assert_called_once_with(limit=10)

    def test_list_alerts_requires_admin(self, client, mock_db):
        """Regular user cannot access alerts."""
        mock_db.get_user_by_id.return_value = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }

        response = client.get("/admin/alerts", headers=_user_header())

        assert response.status_code == 403
        # The authorization check must reject the request before alerts are read.
        mock_db.list_alerts.assert_not_called()
|
|
|
|
|
|
# ---------- Scheduler integration ----------
|
|
|
|
class TestSchedulerIntegration:
    """Tests for scheduler.run_scheduled_analysis()."""

    @staticmethod
    def _run_scheduler(db, analyzer):
        """Run run_scheduled_analysis() with the scheduler's collaborators mocked.

        Args:
            db: Mock returned by the patched ``SPARC.scheduler.get_db_client``.
            analyzer: Mock returned when the patched
                ``SPARC.scheduler.CompanyAnalyzer`` class is instantiated.

        Returns:
            The patched CompanyAnalyzer class mock, so callers can assert on
            whether and how it was instantiated.
        """
        with patch("SPARC.scheduler.get_db_client", return_value=db), \
                patch("SPARC.scheduler.CompanyAnalyzer", return_value=analyzer) as analyzer_cls:
            from SPARC.scheduler import run_scheduled_analysis
            run_scheduled_analysis()
        return analyzer_cls

    def test_no_tracked_companies_skips_analysis(self):
        """Scheduler does nothing when no companies are tracked."""
        mock_db = MagicMock()
        mock_db.list_tracked_companies.return_value = []

        analyzer_cls = self._run_scheduler(mock_db, MagicMock())

        analyzer_cls.assert_not_called()
        mock_db.update_tracked_company.assert_not_called()
        mock_db.store_alert.assert_not_called()
        # The client from get_db_client() is pooled/shared, so the scheduler
        # must not close it (it no longer does after the pooling refactor).
        mock_db.close.assert_not_called()

    def test_scheduler_analyzes_each_tracked_company(self):
        """Scheduler runs analysis for every tracked company."""
        mock_db = MagicMock()
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "NVIDIA", "last_patent_count": 100},
            {"company_name": "AMD", "last_patent_count": 50},
        ]
        mock_analyzer = MagicMock()
        mock_analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=True, patent_count=110),
            MagicMock(success=True, patent_count=55),
        ]

        self._run_scheduler(mock_db, mock_analyzer)

        assert mock_analyzer._analyze_company_safe.call_count == 2
        assert mock_db.update_tracked_company.call_count == 2
        mock_db.update_tracked_company.assert_has_calls(
            [call("NVIDIA", 110), call("AMD", 55)], any_order=True
        )
        # The pooled client must stay open for other users.
        mock_db.close.assert_not_called()

    def test_scheduler_triggers_alert_on_significant_change(self):
        """Scheduler stores an alert when patent count changes significantly."""
        mock_db = MagicMock()
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "Tesla", "last_patent_count": 100},
        ]
        mock_analyzer = MagicMock()
        mock_analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=130  # 30% increase
        )

        self._run_scheduler(mock_db, mock_analyzer)

        mock_db.store_alert.assert_called_once()
        alert = mock_db.store_alert.call_args.kwargs
        assert alert["company_name"] == "Tesla"
        assert alert["alert_type"] == "patent_count_change"
        assert alert["old_value"] == 100
        assert alert["new_value"] == 130

    def test_scheduler_no_alert_for_small_change(self):
        """Scheduler does not alert when change is below threshold."""
        mock_db = MagicMock()
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "Intel", "last_patent_count": 100},
        ]
        mock_analyzer = MagicMock()
        mock_analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=105  # 5% increase
        )

        self._run_scheduler(mock_db, mock_analyzer)

        mock_db.store_alert.assert_not_called()

    def test_scheduler_handles_analysis_failure(self):
        """Scheduler continues when one company fails analysis."""
        mock_db = MagicMock()
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "FailCo", "last_patent_count": 50},
            {"company_name": "SuccessCo", "last_patent_count": 30},
        ]
        mock_analyzer = MagicMock()
        mock_analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=False, error="API timeout"),
            MagicMock(success=True, patent_count=35),
        ]

        self._run_scheduler(mock_db, mock_analyzer)

        # Both companies are attempted, but only SuccessCo is updated.
        assert mock_analyzer._analyze_company_safe.call_count == 2
        mock_db.update_tracked_company.assert_called_once_with("SuccessCo", 35)

    def test_scheduler_handles_exception_in_analysis(self):
        """Scheduler continues even when analysis raises an exception."""
        mock_db = MagicMock()
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "CrashCo", "last_patent_count": 10},
            {"company_name": "OKCo", "last_patent_count": 20},
        ]
        mock_analyzer = MagicMock()
        mock_analyzer._analyze_company_safe.side_effect = [
            RuntimeError("unexpected error"),
            MagicMock(success=True, patent_count=22),
        ]

        self._run_scheduler(mock_db, mock_analyzer)

        # The exception for CrashCo must not stop OKCo from being processed.
        assert mock_analyzer._analyze_company_safe.call_count == 2
        mock_db.update_tracked_company.assert_called_once_with("OKCo", 22)
|