Add multi-tenant support with owner_id isolation

- Add owner_id (FK to users) column to llm_messages, jobs, and
  tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
  so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
  return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
  backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
  per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
  registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
  assertions
- Add comprehensive cross-tenant isolation test suite
  (tests/test_multi_tenant.py)

Closes leeworks-agents/SPARC#1677

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
agent-company
2026-05-19 16:04:58 +00:00
parent 3dfa651f2d
commit e37859dabc
8 changed files with 964 additions and 164 deletions
+35 -13
View File
@@ -1,12 +1,13 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
@@ -15,6 +16,27 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
@@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination:
]
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10")
response = client.get("/analyze/batch?limit=10", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 2
@@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = rows
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3")
response = client.get("/analyze/batch?limit=3", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 3
@@ -84,7 +106,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42")
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header())
db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
@@ -97,19 +119,19 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch")
client.get("/analyze/batch", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/analyze/batch?limit=201")
response = client.get("/analyze/batch?limit=201", headers=_auth_header())
assert response.status_code == 422
def test_limit_zero_rejected(self, client):
"""Limit < 1 should be rejected with 422."""
response = client.get("/analyze/batch?limit=0")
response = client.get("/analyze/batch?limit=0", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -119,7 +141,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel")
client.get("/analyze/batch?company_name=intel", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else [])
@@ -131,7 +153,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
response = client.get("/analyze/batch")
response = client.get("/analyze/batch", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert data["items"] == []
@@ -148,14 +170,14 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
client.get("/jobs")
client.get("/jobs", headers=_auth_header())
call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/jobs?limit=201")
response = client.get("/jobs?limit=201", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -165,5 +187,5 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
response = client.get("/jobs?limit=200")
response = client.get("/jobs?limit=200", headers=_auth_header())
assert response.status_code == 200