Add multi-tenant support with owner_id isolation

- Add owner_id (FK to users) column to llm_messages, jobs, and
  tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
  so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
  return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
  backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
  per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
  registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
  assertions
- Add comprehensive cross-tenant isolation test suite
  (tests/test_multi_tenant.py)

Closes leeworks-agents/SPARC#1677

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
agent-company
2026-05-19 16:04:58 +00:00
parent 3dfa651f2d
commit e37859dabc
8 changed files with 964 additions and 164 deletions
+74 -18
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints."""
from datetime import datetime
from unittest.mock import Mock
from datetime import datetime, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -16,6 +17,22 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
    """Yield a MagicMock that stands in for the database client.

    The fixture is autouse, so it applies to every test in this module.
    ``get_user_by_id`` is stubbed to return one fixed user record, and
    ``get_db_client`` is patched in both ``SPARC.api`` and ``SPARC.auth``
    to return the same mock while the test runs.
    """
    stub_user = {
        "id": 1,
        "email": "user@test.com",
        "role": "user",
        "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
    }
    db = MagicMock()
    db.get_user_by_id.return_value = stub_user
    # Both modules look up get_db_client by their own name, so each
    # reference is patched separately; the patches unwind in reverse order.
    with patch("SPARC.api.get_db_client", return_value=db):
        with patch("SPARC.auth.get_db_client", return_value=db):
            yield db
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer."""
@@ -24,6 +41,12 @@ def mock_analyzer(mocker):
return mock
def _auth_header(user_id=1, email="user@test.com", role="user"):
    """Return request headers carrying a bearer token for the given user.

    The token is produced by ``create_access_token(user_id, email, role)``.
    The default id, email and role are the same values as the user record
    stubbed by the ``mock_db`` fixture.
    """
    return {
        "Authorization": f"Bearer {create_access_token(user_id, email, role)}",
    }
class TestHealthEndpoint:
"""Test health check endpoint."""
@@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia")
response = client.get("/analyze/nvidia", headers=_auth_header())
assert response.status_code == 200
data = response.json()
@@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/unknown")
response = client.get("/analyze/unknown", headers=_auth_header())
assert response.status_code == 200
data = response.json()
@@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "amd"], "max_workers": 2},
headers=_auth_header(),
)
assert response.status_code == 200
@@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint:
def test_batch_analysis_validation(self, client):
"""Test batch analysis request validation."""
# Empty companies list
response = client.post("/analyze/batch", json={"companies": []})
response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header())
assert response.status_code == 422
# Too many companies
response = client.post(
"/analyze/batch",
json={"companies": [f"company{i}" for i in range(25)]},
headers=_auth_header(),
)
assert response.status_code == 422
@@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "max_workers": 10},
headers=_auth_header(),
)
assert response.status_code == 422
@@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint:
class TestAsyncBatchEndpoint:
"""Test async batch analysis endpoint."""
def test_async_batch_creates_job(self, client, mock_analyzer):
"""Test async endpoint creates a job."""
@patch("SPARC.api._get_job_db")
def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
"""Test async endpoint creates a job with owner_id."""
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j1",
"status": "pending",
"progress": 0,
"total_companies": 2,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": 1,
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia", "amd"]},
headers=_auth_header(),
)
assert response.status_code == 200
@@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint:
assert data["status"] == "pending"
assert data["total_companies"] == 2
assert data["progress"] == 0
# Verify owner_id was passed
job_db.create_job.assert_called_once()
assert job_db.create_job.call_args.kwargs.get("owner_id") == 1
class TestJobEndpoints:
"""Test job management endpoints."""
def test_get_job_not_found(self, client):
@patch("SPARC.api._get_job_db")
def test_get_job_not_found(self, mock_get_db, client):
"""Test getting nonexistent job."""
response = client.get("/jobs/nonexistent")
job_db = MagicMock()
job_db.get_job.return_value = None
mock_get_db.return_value = job_db
response = client.get("/jobs/nonexistent", headers=_auth_header())
assert response.status_code == 404
def test_list_jobs(self, client, mocker):
@patch("SPARC.api._get_job_db")
def test_list_jobs(self, mock_get_db, client):
"""Test listing jobs."""
# Clear existing jobs
mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs")
response = client.get("/jobs", headers=_auth_header())
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_jobs_with_filter(self, client, mocker):
@patch("SPARC.api._get_job_db")
def test_list_jobs_with_filter(self, mock_get_db, client):
"""Test listing jobs with status filter."""
response = client.get("/jobs?status=completed")
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs?status=completed", headers=_auth_header())
assert response.status_code == 200
@@ -189,7 +244,7 @@ class TestModelValidation:
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with unsupported model returns 400."""
response = client.get("/analyze/nvidia?model=fake/nonexistent-model")
response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header())
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
@@ -205,7 +260,7 @@ class TestModelValidation:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet")
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header())
assert response.status_code == 200
def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
@@ -213,6 +268,7 @@ class TestModelValidation:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
headers=_auth_header(),
)
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]