Add multi-tenant support with owner_id isolation

- Add owner_id (FK to users) column to llm_messages, jobs, and
  tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
  so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
  return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
  backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
  per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
  registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
  assertions
- Add comprehensive cross-tenant isolation test suite
  (tests/test_multi_tenant.py)

Closes leeworks-agents/SPARC#1677

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
agent-company
2026-05-19 16:04:58 +00:00
parent 3dfa651f2d
commit e37859dabc
8 changed files with 964 additions and 164 deletions
+74 -18
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints."""
from datetime import datetime
from unittest.mock import Mock
from datetime import datetime, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -16,6 +17,22 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer."""
@@ -24,6 +41,12 @@ def mock_analyzer(mocker):
return mock
def _auth_header(user_id=1, email="user@test.com", role="user"):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_id, email, role)
return {"Authorization": f"Bearer {token}"}
class TestHealthEndpoint:
"""Test health check endpoint."""
@@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia")
response = client.get("/analyze/nvidia", headers=_auth_header())
assert response.status_code == 200
data = response.json()
@@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/unknown")
response = client.get("/analyze/unknown", headers=_auth_header())
assert response.status_code == 200
data = response.json()
@@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "amd"], "max_workers": 2},
headers=_auth_header(),
)
assert response.status_code == 200
@@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint:
def test_batch_analysis_validation(self, client):
"""Test batch analysis request validation."""
# Empty companies list
response = client.post("/analyze/batch", json={"companies": []})
response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header())
assert response.status_code == 422
# Too many companies
response = client.post(
"/analyze/batch",
json={"companies": [f"company{i}" for i in range(25)]},
headers=_auth_header(),
)
assert response.status_code == 422
@@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "max_workers": 10},
headers=_auth_header(),
)
assert response.status_code == 422
@@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint:
class TestAsyncBatchEndpoint:
"""Test async batch analysis endpoint."""
def test_async_batch_creates_job(self, client, mock_analyzer):
"""Test async endpoint creates a job."""
@patch("SPARC.api._get_job_db")
def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
"""Test async endpoint creates a job with owner_id."""
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j1",
"status": "pending",
"progress": 0,
"total_companies": 2,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": 1,
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia", "amd"]},
headers=_auth_header(),
)
assert response.status_code == 200
@@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint:
assert data["status"] == "pending"
assert data["total_companies"] == 2
assert data["progress"] == 0
# Verify owner_id was passed
job_db.create_job.assert_called_once()
assert job_db.create_job.call_args.kwargs.get("owner_id") == 1
class TestJobEndpoints:
"""Test job management endpoints."""
def test_get_job_not_found(self, client):
@patch("SPARC.api._get_job_db")
def test_get_job_not_found(self, mock_get_db, client):
"""Test getting nonexistent job."""
response = client.get("/jobs/nonexistent")
job_db = MagicMock()
job_db.get_job.return_value = None
mock_get_db.return_value = job_db
response = client.get("/jobs/nonexistent", headers=_auth_header())
assert response.status_code == 404
def test_list_jobs(self, client, mocker):
@patch("SPARC.api._get_job_db")
def test_list_jobs(self, mock_get_db, client):
"""Test listing jobs."""
# Clear existing jobs
mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs")
response = client.get("/jobs", headers=_auth_header())
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_jobs_with_filter(self, client, mocker):
@patch("SPARC.api._get_job_db")
def test_list_jobs_with_filter(self, mock_get_db, client):
"""Test listing jobs with status filter."""
response = client.get("/jobs?status=completed")
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs?status=completed", headers=_auth_header())
assert response.status_code == 200
@@ -189,7 +244,7 @@ class TestModelValidation:
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with unsupported model returns 400."""
response = client.get("/analyze/nvidia?model=fake/nonexistent-model")
response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header())
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
@@ -205,7 +260,7 @@ class TestModelValidation:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet")
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header())
assert response.status_code == 200
def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
@@ -213,6 +268,7 @@ class TestModelValidation:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
headers=_auth_header(),
)
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
+1
View File
@@ -5,6 +5,7 @@ Covers issue #1655:
- GET /export/{company_name}/pdf (PDF export)
All tests mock the database layer and use JWT auth fixtures from test_auth patterns.
Export queries are now scoped to the current user's owner_id.
"""
from datetime import datetime, timezone
+281
View File
@@ -0,0 +1,281 @@
"""Cross-tenant isolation tests for multi-tenant support.
Verifies that:
- User A cannot read, update, or delete User B's analyses, tracked companies, or jobs
- Admin users can access all data via admin endpoints
- owner_id is correctly set on new resources
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user(user_id, email, role="user"):
return {
"id": user_id,
"email": email,
"role": role,
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
USER_A = _make_user(10, "alice@test.com")
USER_B = _make_user(20, "bob@test.com")
ADMIN = _make_user(1, "admin@test.com", role="admin")
def _header_for(user):
token = create_access_token(user["id"], user["email"], user["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db():
"""Mock DB returning the correct user based on user_id."""
db = MagicMock()
def _get_user_by_id(uid):
for u in [USER_A, USER_B, ADMIN]:
if u["id"] == uid:
return u
return None
db.get_user_by_id.side_effect = _get_user_by_id
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
# ==================== Tracked Companies Isolation ====================
class TestTrackedCompanyIsolation:
"""User A's tracked companies are invisible to User B."""
def test_user_a_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User A's companies."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": USER_A["id"]},
]
response = client.get("/tracked", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"])
def test_user_b_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User B's companies."""
mock_db.list_tracked_companies.return_value = []
response = client.get("/tracked", headers=_header_for(USER_B))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"])
def test_user_a_add_sets_owner(self, client, mock_db):
"""POST /tracked sets owner_id to User A."""
mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10}
response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"])
def test_user_b_cannot_remove_user_a_company(self, client, mock_db):
"""DELETE /tracked/{name} filters by owner, so B can't remove A's company."""
mock_db.remove_tracked_company.return_value = False # not found for B
response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B))
assert response.status_code == 404
mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"])
# ==================== Job Isolation ====================
class TestJobIsolation:
"""User A's jobs are invisible to User B."""
def test_user_a_get_own_job(self, client, mock_db):
"""GET /jobs/{id} scoped to User A returns the job."""
mock_db.get_job.return_value = None # mock via _get_job_db
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = {
"job_id": "j1",
"status": "completed",
"progress": 100,
"total_companies": 1,
"completed_companies": 1,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_A))
assert response.status_code == 200
job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"])
def test_user_b_cannot_see_user_a_job(self, client, mock_db):
"""GET /jobs/{id} returns 404 when User B tries to access User A's job."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = None # not found for B's owner_id
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_B))
assert response.status_code == 404
job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"])
def test_list_jobs_scoped_to_user(self, client, mock_db):
"""GET /jobs filters by owner_id."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
def test_async_job_created_with_owner(self, client, mock_db):
"""POST /analyze/batch/async creates job with current user's owner_id."""
mock_analyzer = MagicMock()
with patch("SPARC.api._analyzer", mock_analyzer), \
patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j2",
"status": "pending",
"progress": 0,
"total_companies": 1,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia"]},
headers=_header_for(USER_A),
)
assert response.status_code == 200
create_kwargs = job_db.create_job.call_args
assert create_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Analysis Listing Isolation ====================
class TestAnalysisListIsolation:
"""GET /analyze/batch scoped to current user."""
def test_list_analyses_scoped_to_user(self, client, mock_db):
"""GET /analyze/batch passes owner_id to db.list_analyses."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/analyze/batch", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Admin Cross-Tenant Access ====================
class TestAdminCrossTenantAccess:
"""Admin endpoints return data from all tenants (no owner_id filter)."""
def test_admin_list_tracked_all_tenants(self, client, mock_db):
"""GET /admin/tracked returns all companies (no owner_id filter)."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": 10},
{"company_name": "BobCo", "owner_id": 20},
]
response = client.get("/admin/tracked", headers=_header_for(ADMIN))
assert response.status_code == 200
# Should be called without owner_id filter
mock_db.list_tracked_companies.assert_called_with()
def test_admin_list_analyses_all_tenants(self, client, mock_db):
"""GET /admin/analyses returns all analyses (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/analyses", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
# No owner_id should be passed
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_list_jobs_all_tenants(self, client, mock_db):
"""GET /admin/jobs returns all jobs (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/jobs", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_remove_tracked_any_owner(self, client, mock_db):
"""DELETE /admin/tracked/{name} removes without owner filter."""
mock_db.remove_tracked_company.return_value = True
response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN))
assert response.status_code == 200
# Called without owner_id
mock_db.remove_tracked_company.assert_called_with("SomeCo")
def test_regular_user_cannot_access_admin_analyses(self, client, mock_db):
"""Regular user gets 403 on /admin/analyses."""
response = client.get("/admin/analyses", headers=_header_for(USER_A))
assert response.status_code == 403
def test_regular_user_cannot_access_admin_jobs(self, client, mock_db):
"""Regular user gets 403 on /admin/jobs."""
response = client.get("/admin/jobs", headers=_header_for(USER_A))
assert response.status_code == 403
# ==================== Analytics Isolation ====================
class TestAnalyticsIsolation:
"""GET /analytics scoped to current user."""
def test_analytics_scoped_to_user(self, client, mock_db):
"""GET /analytics passes owner_id to db.get_analytics."""
mock_db.get_analytics.return_value = {
"total_messages": 5,
"by_company": [],
"by_type": [],
"period_days": 30,
}
response = client.get("/analytics", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"])
+35 -13
View File
@@ -1,12 +1,13 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
@@ -15,6 +16,27 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
@@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination:
]
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10")
response = client.get("/analyze/batch?limit=10", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 2
@@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = rows
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3")
response = client.get("/analyze/batch?limit=3", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 3
@@ -84,7 +106,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42")
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header())
db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
@@ -97,19 +119,19 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch")
client.get("/analyze/batch", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/analyze/batch?limit=201")
response = client.get("/analyze/batch?limit=201", headers=_auth_header())
assert response.status_code == 422
def test_limit_zero_rejected(self, client):
"""Limit < 1 should be rejected with 422."""
response = client.get("/analyze/batch?limit=0")
response = client.get("/analyze/batch?limit=0", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -119,7 +141,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel")
client.get("/analyze/batch?company_name=intel", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else [])
@@ -131,7 +153,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
response = client.get("/analyze/batch")
response = client.get("/analyze/batch", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert data["items"] == []
@@ -148,14 +170,14 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
client.get("/jobs")
client.get("/jobs", headers=_auth_header())
call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/jobs?limit=201")
response = client.get("/jobs?limit=201", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -165,5 +187,5 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
response = client.get("/jobs?limit=200")
response = client.get("/jobs?limit=200", headers=_auth_header())
assert response.status_code == 200
+71 -10
View File
@@ -1,17 +1,18 @@
"""Tests for tracked company admin endpoints and scheduler integration.
"""Tests for tracked company endpoints and scheduler integration.
Covers issue #1656:
- GET /admin/tracked (list tracked companies)
- POST /admin/tracked (add a tracked company)
- DELETE /admin/tracked/{company_name} (remove a tracked company)
Covers:
- GET /tracked (user-scoped list)
- POST /tracked (user-scoped add)
- DELETE /tracked/{company_name} (user-scoped remove)
- GET /admin/tracked (admin: all companies)
- POST /admin/tracked (admin: add)
- DELETE /admin/tracked/{company_name} (admin: remove any)
- GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration
All tests mock the database layer and use JWT auth fixtures.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
@@ -125,7 +126,7 @@ class TestAddTrackedCompany:
assert response.status_code == 200
data = response.json()
assert data["company_name"] == "Intel"
mock_db.add_tracked_company.assert_called_once_with("Intel")
mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1)
def test_add_duplicate_returns_409(self, client, mock_db):
"""Adding an already-tracked company returns 409."""
@@ -141,7 +142,7 @@ class TestAddTrackedCompany:
assert "already tracked" in response.json()["detail"].lower()
def test_add_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot add tracked companies."""
"""Regular user cannot add tracked companies via admin endpoint."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
@@ -215,6 +216,66 @@ class TestRemoveTrackedCompany:
assert response.status_code == 403
# ---------- User-scoped tracked companies ----------
class TestUserScopedTrackedCompanies:
"""Tests for /tracked user-scoped endpoints."""
def test_user_list_tracked(self, client, mock_db):
"""Regular user can list their own tracked companies."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.list_tracked_companies.return_value = [
{"company_name": "AMD", "owner_id": 2},
]
response = client.get("/tracked", headers=_user_header())
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=2)
def test_user_add_tracked(self, client, mock_db):
"""Regular user can add a company to their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.add_tracked_company.return_value = {
"company_name": "Intel",
"owner_id": 2,
}
response = client.post(
"/tracked",
json={"company_name": "Intel"},
headers=_user_header(),
)
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2)
def test_user_remove_tracked(self, client, mock_db):
"""Regular user can remove a company from their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.remove_tracked_company.return_value = True
response = client.delete("/tracked/Intel", headers=_user_header())
assert response.status_code == 200
mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2)
# ---------- GET /admin/alerts ----------
class TestListAlerts: