"""Cross-tenant isolation tests for multi-tenant support. Verifies that: - User A cannot read, update, or delete User B's analyses, tracked companies, or jobs - Admin users can access all data via admin endpoints - owner_id is correctly set on new resources """ from datetime import datetime, timezone from unittest.mock import MagicMock, Mock, patch import pytest from fastapi.testclient import TestClient from SPARC.api import app from SPARC.auth import create_access_token @pytest.fixture def client(): """Create test client.""" return TestClient(app) def _make_user(user_id, email, role="user"): return { "id": user_id, "email": email, "role": role, "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), } USER_A = _make_user(10, "alice@test.com") USER_B = _make_user(20, "bob@test.com") ADMIN = _make_user(1, "admin@test.com", role="admin") def _header_for(user): token = create_access_token(user["id"], user["email"], user["role"]) return {"Authorization": f"Bearer {token}"} @pytest.fixture(autouse=True) def mock_db(): """Mock DB returning the correct user based on user_id.""" db = MagicMock() def _get_user_by_id(uid): for u in [USER_A, USER_B, ADMIN]: if u["id"] == uid: return u return None db.get_user_by_id.side_effect = _get_user_by_id with patch("SPARC.api.get_db_client", return_value=db), \ patch("SPARC.auth.get_db_client", return_value=db): yield db # ==================== Tracked Companies Isolation ==================== class TestTrackedCompanyIsolation: """User A's tracked companies are invisible to User B.""" def test_user_a_list_scoped_to_own(self, client, mock_db): """GET /tracked returns only User A's companies.""" mock_db.list_tracked_companies.return_value = [ {"company_name": "AliceCo", "owner_id": USER_A["id"]}, ] response = client.get("/tracked", headers=_header_for(USER_A)) assert response.status_code == 200 mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"]) def test_user_b_list_scoped_to_own(self, client, mock_db): """GET /tracked returns only User B's companies.""" mock_db.list_tracked_companies.return_value = [] response = client.get("/tracked", headers=_header_for(USER_B)) assert response.status_code == 200 mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"]) def test_user_a_add_sets_owner(self, client, mock_db): """POST /tracked sets owner_id to User A.""" mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10} response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A)) assert response.status_code == 200 mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"]) def test_user_b_cannot_remove_user_a_company(self, client, mock_db): """DELETE /tracked/{name} filters by owner, so B can't remove A's company.""" mock_db.remove_tracked_company.return_value = False # not found for B response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B)) assert response.status_code == 404 mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"]) # ==================== Job Isolation ==================== class TestJobIsolation: """User A's jobs are invisible to User B.""" def test_user_a_get_own_job(self, client, mock_db): """GET /jobs/{id} scoped to User A returns the job.""" mock_db.get_job.return_value = None # mock via _get_job_db with patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.get_job.return_value = { "job_id": "j1", "status": "completed", "progress": 100, "total_companies": 1, "completed_companies": 1, "result_json": None, "error": None, "owner_id": USER_A["id"], } mock_get_db.return_value = job_db response = client.get("/jobs/j1", headers=_header_for(USER_A)) assert response.status_code == 200 job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"]) def test_user_b_cannot_see_user_a_job(self, client, mock_db): """GET /jobs/{id} returns 404 when User B tries to access User A's job.""" with patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.get_job.return_value = None # not found for B's owner_id mock_get_db.return_value = job_db response = client.get("/jobs/j1", headers=_header_for(USER_B)) assert response.status_code == 404 job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"]) def test_list_jobs_scoped_to_user(self, client, mock_db): """GET /jobs filters by owner_id.""" with patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.list_jobs.return_value = [] mock_get_db.return_value = job_db response = client.get("/jobs", headers=_header_for(USER_A)) assert response.status_code == 200 call_kwargs = job_db.list_jobs.call_args assert call_kwargs.kwargs.get("owner_id") == USER_A["id"] def test_async_job_created_with_owner(self, client, mock_db): """POST /analyze/batch/async creates job with current user's owner_id.""" mock_analyzer = MagicMock() with patch("SPARC.api._analyzer", mock_analyzer), \ patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.create_job.return_value = { "job_id": "j2", "status": "pending", "progress": 0, "total_companies": 1, "completed_companies": 0, "result_json": None, "error": None, "owner_id": USER_A["id"], } mock_get_db.return_value = job_db response = client.post( "/analyze/batch/async", json={"companies": ["nvidia"]}, headers=_header_for(USER_A), ) assert response.status_code == 200 create_kwargs = job_db.create_job.call_args assert create_kwargs.kwargs.get("owner_id") == USER_A["id"] # ==================== Analysis Listing Isolation ==================== class TestAnalysisListIsolation: """GET /analyze/batch scoped to current user.""" def test_list_analyses_scoped_to_user(self, client, mock_db): """GET /analyze/batch passes owner_id to db.list_analyses.""" with patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.list_analyses.return_value = [] mock_get_db.return_value = job_db response = client.get("/analyze/batch", headers=_header_for(USER_A)) assert response.status_code == 200 call_kwargs = job_db.list_analyses.call_args assert call_kwargs.kwargs.get("owner_id") == USER_A["id"] # ==================== Admin Cross-Tenant Access ==================== class TestAdminCrossTenantAccess: """Admin endpoints return data from all tenants (no owner_id filter).""" def test_admin_list_tracked_all_tenants(self, client, mock_db): """GET /admin/tracked returns all companies (no owner_id filter).""" mock_db.list_tracked_companies.return_value = [ {"company_name": "AliceCo", "owner_id": 10}, {"company_name": "BobCo", "owner_id": 20}, ] response = client.get("/admin/tracked", headers=_header_for(ADMIN)) assert response.status_code == 200 # Should be called without owner_id filter mock_db.list_tracked_companies.assert_called_with() def test_admin_list_analyses_all_tenants(self, client, mock_db): """GET /admin/analyses returns all analyses (no owner_id filter).""" with patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.list_analyses.return_value = [] mock_get_db.return_value = job_db response = client.get("/admin/analyses", headers=_header_for(ADMIN)) assert response.status_code == 200 call_kwargs = job_db.list_analyses.call_args # No owner_id should be passed assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None def test_admin_list_jobs_all_tenants(self, client, mock_db): """GET /admin/jobs returns all jobs (no owner_id filter).""" with patch("SPARC.api._get_job_db") as mock_get_db: job_db = MagicMock() job_db.list_jobs.return_value = [] mock_get_db.return_value = job_db response = client.get("/admin/jobs", headers=_header_for(ADMIN)) assert response.status_code == 200 call_kwargs = job_db.list_jobs.call_args assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None def test_admin_remove_tracked_any_owner(self, client, mock_db): """DELETE /admin/tracked/{name} removes without owner filter.""" mock_db.remove_tracked_company.return_value = True response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN)) assert response.status_code == 200 # Called without owner_id mock_db.remove_tracked_company.assert_called_with("SomeCo") def test_regular_user_cannot_access_admin_analyses(self, client, mock_db): """Regular user gets 403 on /admin/analyses.""" response = client.get("/admin/analyses", headers=_header_for(USER_A)) assert response.status_code == 403 def test_regular_user_cannot_access_admin_jobs(self, client, mock_db): """Regular user gets 403 on /admin/jobs.""" response = client.get("/admin/jobs", headers=_header_for(USER_A)) assert response.status_code == 403 # ==================== Analytics Isolation ==================== class TestAnalyticsIsolation: """GET /analytics scoped to current user.""" def test_analytics_scoped_to_user(self, client, mock_db): """GET /analytics passes owner_id to db.get_analytics.""" mock_db.get_analytics.return_value = { "total_messages": 5, "by_company": [], "by_type": [], "period_days": 30, } response = client.get("/analytics", headers=_header_for(USER_A)) assert response.status_code == 200 mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"])