forked from 0xWheatyz/SPARC
Add multi-tenant support with owner_id isolation
- Add owner_id (FK to users) column to llm_messages, jobs, and
tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
assertions
- Add comprehensive cross-tenant isolation test suite
(tests/test_multi_tenant.py)
Closes leeworks-agents/SPARC#1677
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+74
-18
@@ -1,12 +1,13 @@
|
||||
"""Tests for FastAPI web service endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
from SPARC.auth import create_access_token
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
|
||||
@@ -16,6 +17,22 @@ def client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db():
|
||||
"""Mock the database client used by auth endpoints."""
|
||||
db = MagicMock()
|
||||
db.get_user_by_id.return_value = {
|
||||
"id": 1,
|
||||
"email": "user@test.com",
|
||||
"role": "user",
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
with patch("SPARC.api.get_db_client", return_value=db), \
|
||||
patch("SPARC.auth.get_db_client", return_value=db):
|
||||
yield db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_analyzer(mocker):
|
||||
"""Mock the global analyzer."""
|
||||
@@ -24,6 +41,12 @@ def mock_analyzer(mocker):
|
||||
return mock
|
||||
|
||||
|
||||
def _auth_header(user_id=1, email="user@test.com", role="user"):
|
||||
"""Create an Authorization header with a valid access token."""
|
||||
token = create_access_token(user_id, email, role)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Test health check endpoint."""
|
||||
|
||||
@@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint:
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/nvidia")
|
||||
response = client.get("/analyze/nvidia", headers=_auth_header())
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint:
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/unknown")
|
||||
response = client.get("/analyze/unknown", headers=_auth_header())
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint:
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia", "amd"], "max_workers": 2},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint:
|
||||
def test_batch_analysis_validation(self, client):
|
||||
"""Test batch analysis request validation."""
|
||||
# Empty companies list
|
||||
response = client.post("/analyze/batch", json={"companies": []})
|
||||
response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header())
|
||||
assert response.status_code == 422
|
||||
|
||||
# Too many companies
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": [f"company{i}" for i in range(25)]},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint:
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia"], "max_workers": 10},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint:
|
||||
class TestAsyncBatchEndpoint:
|
||||
"""Test async batch analysis endpoint."""
|
||||
|
||||
def test_async_batch_creates_job(self, client, mock_analyzer):
|
||||
"""Test async endpoint creates a job."""
|
||||
@patch("SPARC.api._get_job_db")
|
||||
def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
|
||||
"""Test async endpoint creates a job with owner_id."""
|
||||
job_db = MagicMock()
|
||||
job_db.create_job.return_value = {
|
||||
"job_id": "j1",
|
||||
"status": "pending",
|
||||
"progress": 0,
|
||||
"total_companies": 2,
|
||||
"completed_companies": 0,
|
||||
"result_json": None,
|
||||
"error": None,
|
||||
"owner_id": 1,
|
||||
}
|
||||
mock_get_db.return_value = job_db
|
||||
|
||||
response = client.post(
|
||||
"/analyze/batch/async",
|
||||
json={"companies": ["nvidia", "amd"]},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint:
|
||||
assert data["status"] == "pending"
|
||||
assert data["total_companies"] == 2
|
||||
assert data["progress"] == 0
|
||||
# Verify owner_id was passed
|
||||
job_db.create_job.assert_called_once()
|
||||
assert job_db.create_job.call_args.kwargs.get("owner_id") == 1
|
||||
|
||||
|
||||
class TestJobEndpoints:
|
||||
"""Test job management endpoints."""
|
||||
|
||||
def test_get_job_not_found(self, client):
|
||||
@patch("SPARC.api._get_job_db")
|
||||
def test_get_job_not_found(self, mock_get_db, client):
|
||||
"""Test getting nonexistent job."""
|
||||
response = client.get("/jobs/nonexistent")
|
||||
job_db = MagicMock()
|
||||
job_db.get_job.return_value = None
|
||||
mock_get_db.return_value = job_db
|
||||
|
||||
response = client.get("/jobs/nonexistent", headers=_auth_header())
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_list_jobs(self, client, mocker):
|
||||
@patch("SPARC.api._get_job_db")
|
||||
def test_list_jobs(self, mock_get_db, client):
|
||||
"""Test listing jobs."""
|
||||
# Clear existing jobs
|
||||
mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
|
||||
job_db = MagicMock()
|
||||
job_db.list_jobs.return_value = []
|
||||
mock_get_db.return_value = job_db
|
||||
|
||||
response = client.get("/jobs")
|
||||
response = client.get("/jobs", headers=_auth_header())
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
def test_list_jobs_with_filter(self, client, mocker):
|
||||
@patch("SPARC.api._get_job_db")
|
||||
def test_list_jobs_with_filter(self, mock_get_db, client):
|
||||
"""Test listing jobs with status filter."""
|
||||
response = client.get("/jobs?status=completed")
|
||||
job_db = MagicMock()
|
||||
job_db.list_jobs.return_value = []
|
||||
mock_get_db.return_value = job_db
|
||||
|
||||
response = client.get("/jobs?status=completed", headers=_auth_header())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -189,7 +244,7 @@ class TestModelValidation:
|
||||
|
||||
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
|
||||
"""GET /analyze/{company} with unsupported model returns 400."""
|
||||
response = client.get("/analyze/nvidia?model=fake/nonexistent-model")
|
||||
response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header())
|
||||
assert response.status_code == 400
|
||||
assert "Unsupported model" in response.json()["detail"]
|
||||
|
||||
@@ -205,7 +260,7 @@ class TestModelValidation:
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet")
|
||||
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header())
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
|
||||
@@ -213,6 +268,7 @@ class TestModelValidation:
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
|
||||
headers=_auth_header(),
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Unsupported model" in response.json()["detail"]
|
||||
|
||||
Reference in New Issue
Block a user