forked from 0xWheatyz/SPARC
e37859dabc
- Add owner_id (FK to users) column to llm_messages, jobs, and
tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by the authenticated user's owner_id
so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes are now
registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
assertions
- Add comprehensive cross-tenant isolation test suite
(tests/test_multi_tenant.py)
Closes leeworks-agents/SPARC#1677
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
449 lines
16 KiB
Python
449 lines
16 KiB
Python
"""Tests for tracked company endpoints and scheduler integration.
|
|
|
|
Covers:
|
|
- GET /tracked (user-scoped list)
|
|
- POST /tracked (user-scoped add)
|
|
- DELETE /tracked/{company_name} (user-scoped remove)
|
|
- GET /admin/tracked (admin: all companies)
|
|
- POST /admin/tracked (admin: add)
|
|
- DELETE /admin/tracked/{company_name} (admin: remove any)
|
|
- GET /admin/alerts (list alerts)
|
|
- scheduler.run_scheduled_analysis() integration
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from SPARC.api import app
|
|
from SPARC.auth import create_access_token
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a fresh ``TestClient`` wrapping the SPARC FastAPI ``app``."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_db():
    """Swap the database client used by the API and auth modules for a mock.

    One shared ``MagicMock`` is returned from both
    ``SPARC.api.get_db_client`` and ``SPARC.auth.get_db_client`` for the
    duration of every test (the fixture is autouse). Its ``get_user_by_id``
    answers with an admin account (id 1) by default; tests that need to act
    as a regular user overwrite that return value.
    """
    default_admin = {
        "id": 1,
        "email": "admin@test.com",
        "role": "admin",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    fake_db = MagicMock()
    fake_db.get_user_by_id.return_value = default_admin

    with patch("SPARC.api.get_db_client", return_value=fake_db):
        with patch("SPARC.auth.get_db_client", return_value=fake_db):
            yield fake_db
|
|
|
|
|
|
def _admin_header():
    """Return request headers carrying a bearer token minted for admin user 1."""
    bearer = create_access_token(1, "admin@test.com", "admin")
    return {"Authorization": "Bearer " + bearer}
|
|
|
|
|
|
def _user_header():
    """Return request headers carrying a bearer token minted for regular user 2."""
    bearer = create_access_token(2, "user@test.com", "user")
    return {"Authorization": "Bearer " + bearer}
|
|
|
|
|
|
# ---------- GET /admin/tracked ----------
|
|
|
|
class TestListTrackedCompanies:
    """Tests for GET /admin/tracked."""

    def test_list_tracked_returns_companies(self, client, mock_db):
        """An admin gets back the tracked companies reported by the DB."""
        companies = [
            {"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"},
            {"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"},
        ]
        mock_db.list_tracked_companies.return_value = companies

        resp = client.get("/admin/tracked", headers=_admin_header())

        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 2
        assert body[0]["company_name"] == "NVIDIA"

    def test_list_tracked_empty(self, client, mock_db):
        """With nothing tracked, the endpoint answers with an empty JSON list."""
        mock_db.list_tracked_companies.return_value = []

        resp = client.get("/admin/tracked", headers=_admin_header())

        assert resp.status_code == 200
        assert resp.json() == []

    def test_list_tracked_requires_admin(self, client, mock_db):
        """A caller whose account has role "user" is refused with 403."""
        regular_user = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }
        mock_db.get_user_by_id.return_value = regular_user

        resp = client.get("/admin/tracked", headers=_user_header())

        assert resp.status_code == 403

    def test_list_tracked_unauthenticated(self, client):
        """A request without an Authorization header is rejected with 401."""
        resp = client.get("/admin/tracked")
        assert resp.status_code == 401
|
|
|
|
|
|
# ---------- POST /admin/tracked ----------
|
|
|
|
class TestAddTrackedCompany:
    """Tests for POST /admin/tracked."""

    def test_add_tracked_company_success(self, client, mock_db):
        """The admin endpoint stores the company under the admin's owner_id (1)."""
        created_row = {
            "company_name": "Intel",
            "last_patent_count": 0,
            "last_analyzed": None,
        }
        mock_db.add_tracked_company.return_value = created_row

        resp = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_admin_header(),
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["company_name"] == "Intel"
        mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1)

    def test_add_duplicate_returns_409(self, client, mock_db):
        """A ``None`` result from add_tracked_company surfaces as 409 Conflict."""
        mock_db.add_tracked_company.return_value = None

        resp = client.post(
            "/admin/tracked",
            json={"company_name": "NVIDIA"},
            headers=_admin_header(),
        )

        assert resp.status_code == 409
        detail = resp.json()["detail"]
        assert "already tracked" in detail.lower()

    def test_add_tracked_requires_admin(self, client, mock_db):
        """A caller whose account has role "user" cannot use the admin add endpoint."""
        regular_user = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }
        mock_db.get_user_by_id.return_value = regular_user

        resp = client.post(
            "/admin/tracked",
            json={"company_name": "Intel"},
            headers=_user_header(),
        )

        assert resp.status_code == 403

    def test_add_tracked_empty_name_rejected(self, client):
        """An empty company_name fails request validation."""
        resp = client.post(
            "/admin/tracked",
            json={"company_name": ""},
            headers=_admin_header(),
        )

        # 422 is the status FastAPI returns for Pydantic body-validation errors.
        assert resp.status_code == 422
|
|
|
|
|
|
# ---------- DELETE /admin/tracked/{company_name} ----------
|
|
|
|
class TestRemoveTrackedCompany:
    """Tests for DELETE /admin/tracked/{company_name}."""

    def test_remove_tracked_company_success(self, client, mock_db):
        """Admin removal calls remove_tracked_company with the name only (no owner filter)."""
        mock_db.remove_tracked_company.return_value = True

        resp = client.delete("/admin/tracked/NVIDIA", headers=_admin_header())

        assert resp.status_code == 200
        message = resp.json()["message"]
        assert "Stopped tracking" in message
        mock_db.remove_tracked_company.assert_called_once_with("NVIDIA")

    def test_remove_nonexistent_returns_404(self, client, mock_db):
        """A ``False`` result from remove_tracked_company surfaces as 404."""
        mock_db.remove_tracked_company.return_value = False

        resp = client.delete("/admin/tracked/UnknownCorp", headers=_admin_header())

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()

    def test_remove_tracked_requires_admin(self, client, mock_db):
        """A caller whose account has role "user" cannot use the admin delete endpoint."""
        regular_user = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }
        mock_db.get_user_by_id.return_value = regular_user

        resp = client.delete("/admin/tracked/NVIDIA", headers=_user_header())

        assert resp.status_code == 403
|
|
|
|
|
|
# ---------- User-scoped tracked companies ----------
|
|
|
|
class TestUserScopedTrackedCompanies:
    """Tests for the user-scoped /tracked endpoints.

    Every request is made as the regular user (id 2). The assertions check
    that each database call carries ``owner_id=2``, i.e. the endpoint scopes
    the operation to the caller's own rows.
    """

    @staticmethod
    def _act_as_regular_user(db):
        """Make ``db.get_user_by_id`` return the regular user (id 2, role "user").

        A fresh dict is built on every call so no test can leak mutations
        into another.
        """
        db.get_user_by_id.return_value = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }

    def test_user_list_tracked(self, client, mock_db):
        """Regular user can list their own tracked companies."""
        self._act_as_regular_user(mock_db)
        mock_db.list_tracked_companies.return_value = [
            {"company_name": "AMD", "owner_id": 2},
        ]

        response = client.get("/tracked", headers=_user_header())

        assert response.status_code == 200
        # Require exactly one, owner-scoped query: ``assert_called_with`` only
        # inspects the last call, so an additional unscoped lookup (which could
        # expose other tenants' rows) would slip through.
        mock_db.list_tracked_companies.assert_called_once_with(owner_id=2)

    def test_user_add_tracked(self, client, mock_db):
        """Regular user can add a company to their own tracked list."""
        self._act_as_regular_user(mock_db)
        mock_db.add_tracked_company.return_value = {
            "company_name": "Intel",
            "owner_id": 2,
        }

        response = client.post(
            "/tracked",
            json={"company_name": "Intel"},
            headers=_user_header(),
        )

        assert response.status_code == 200
        mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2)

    def test_user_remove_tracked(self, client, mock_db):
        """Regular user can remove a company from their own tracked list."""
        self._act_as_regular_user(mock_db)
        mock_db.remove_tracked_company.return_value = True

        response = client.delete("/tracked/Intel", headers=_user_header())

        assert response.status_code == 200
        mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2)
|
|
|
|
|
|
# ---------- GET /admin/alerts ----------
|
|
|
|
class TestListAlerts:
    """Tests for GET /admin/alerts."""

    def test_list_alerts_returns_data(self, client, mock_db):
        """An admin gets back the alerts reported by the DB."""
        stored_alert = {
            "id": 1,
            "company_name": "NVIDIA",
            "alert_type": "patent_count_change",
            "message": "Patent count increased by 25%",
            "created_at": "2025-06-15T10:00:00Z",
        }
        mock_db.list_alerts.return_value = [stored_alert]

        resp = client.get("/admin/alerts", headers=_admin_header())

        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["alert_type"] == "patent_count_change"

    def test_list_alerts_with_limit(self, client, mock_db):
        """The ``limit`` query parameter is forwarded to list_alerts as a keyword."""
        mock_db.list_alerts.return_value = []

        resp = client.get("/admin/alerts?limit=10", headers=_admin_header())

        assert resp.status_code == 200
        mock_db.list_alerts.assert_called_once_with(limit=10)

    def test_list_alerts_requires_admin(self, client, mock_db):
        """A caller whose account has role "user" is refused with 403."""
        regular_user = {
            "id": 2,
            "email": "user@test.com",
            "role": "user",
            "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
        }
        mock_db.get_user_by_id.return_value = regular_user

        resp = client.get("/admin/alerts", headers=_user_header())

        assert resp.status_code == 403
|
|
|
|
|
|
# ---------- Scheduler integration ----------
|
|
|
|
class TestSchedulerIntegration:
    """Tests for scheduler.run_scheduled_analysis()."""

    @staticmethod
    def _db_tracking(*companies):
        """Return a MagicMock DB whose ``list_tracked_companies()`` yields *companies*."""
        db = MagicMock()
        db.list_tracked_companies.return_value = list(companies)
        return db

    @staticmethod
    def _run_scheduler(db, analyzer=None):
        """Run ``run_scheduled_analysis()`` with its collaborators patched.

        ``SPARC.scheduler.get_db_client`` is patched to return *db*, and
        ``SPARC.scheduler.CompanyAnalyzer`` is replaced by a class mock that
        instantiates to *analyzer* when one is given (otherwise the class
        mock's default return value is used).

        Returns:
            The patched ``CompanyAnalyzer`` class mock, so callers can check
            whether an analyzer was constructed at all.
        """
        analyzer_cls_kwargs = {} if analyzer is None else {"return_value": analyzer}
        with patch("SPARC.scheduler.get_db_client", return_value=db), \
                patch("SPARC.scheduler.CompanyAnalyzer", **analyzer_cls_kwargs) as analyzer_cls:
            from SPARC.scheduler import run_scheduled_analysis

            run_scheduled_analysis()
        return analyzer_cls

    def test_no_tracked_companies_skips_analysis(self):
        """Scheduler does nothing when no companies are tracked."""
        analyzer_cls = self._run_scheduler(self._db_tracking())

        analyzer_cls.assert_not_called()

    def test_scheduler_analyzes_each_tracked_company(self):
        """Scheduler runs analysis for every tracked company."""
        db = self._db_tracking(
            {"company_name": "NVIDIA", "last_patent_count": 100},
            {"company_name": "AMD", "last_patent_count": 50},
        )
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=True, patent_count=110),
            MagicMock(success=True, patent_count=55),
        ]

        self._run_scheduler(db, analyzer)

        assert analyzer._analyze_company_safe.call_count == 2
        db.update_tracked_company.assert_any_call("NVIDIA", 110)
        db.update_tracked_company.assert_any_call("AMD", 55)

    def test_scheduler_triggers_alert_on_significant_change(self):
        """Scheduler stores an alert when patent count changes significantly."""
        db = self._db_tracking({"company_name": "Tesla", "last_patent_count": 100})
        analyzer = MagicMock()
        analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=130,  # 30% increase
        )

        self._run_scheduler(db, analyzer)

        db.store_alert.assert_called_once()
        # ``call_args.kwargs`` names what was previously reached via ``[1]``.
        alert_kwargs = db.store_alert.call_args.kwargs
        assert alert_kwargs["company_name"] == "Tesla"
        assert alert_kwargs["alert_type"] == "patent_count_change"
        assert alert_kwargs["old_value"] == 100
        assert alert_kwargs["new_value"] == 130

    def test_scheduler_no_alert_for_small_change(self):
        """Scheduler does not alert when change is below threshold."""
        db = self._db_tracking({"company_name": "Intel", "last_patent_count": 100})
        analyzer = MagicMock()
        analyzer._analyze_company_safe.return_value = MagicMock(
            success=True, patent_count=105,  # 5% increase
        )

        self._run_scheduler(db, analyzer)

        db.store_alert.assert_not_called()

    def test_scheduler_handles_analysis_failure(self):
        """Scheduler continues when one company fails analysis."""
        db = self._db_tracking(
            {"company_name": "FailCo", "last_patent_count": 50},
            {"company_name": "SuccessCo", "last_patent_count": 30},
        )
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            MagicMock(success=False, error="API timeout"),
            MagicMock(success=True, patent_count=35),
        ]

        self._run_scheduler(db, analyzer)

        # FailCo should not get updated, SuccessCo should.
        db.update_tracked_company.assert_called_once_with("SuccessCo", 35)

    def test_scheduler_handles_exception_in_analysis(self):
        """Scheduler continues even when analysis raises an exception."""
        db = self._db_tracking(
            {"company_name": "CrashCo", "last_patent_count": 10},
            {"company_name": "OKCo", "last_patent_count": 20},
        )
        analyzer = MagicMock()
        analyzer._analyze_company_safe.side_effect = [
            RuntimeError("unexpected error"),
            MagicMock(success=True, patent_count=22),
        ]

        self._run_scheduler(db, analyzer)

        # OKCo should still be processed.
        db.update_tracked_company.assert_called_once_with("OKCo", 22)
|