Compare commits

..

2 Commits

Author SHA1 Message Date
agent-company e37859dabc Add multi-tenant support with owner_id isolation
- Add owner_id (FK to users) column to llm_messages, jobs, and
  tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
  so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
  return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
  backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
  per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
  registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
  assertions
- Add comprehensive cross-tenant isolation test suite
  (tests/test_multi_tenant.py)

Closes leeworks-agents/SPARC#1677

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 16:04:58 +00:00
agent-company 3dfa651f2d Add rate limiting dashboard to admin panel
- Enhance GET /admin/rate-limits with per-IP breakdown, 24h throttled
  count, and hourly time-series of rejected requests
- Add _rejected_log deque for time-series tracking of throttled requests
- Add AdminRateLimits React page with auto-refresh (configurable 15s/30s/1m),
  summary cards, throttled-over-time bar chart, endpoint table, per-IP table
- Add TypeScript types (RateLimitStatsResponse) and adminApi.getRateLimits()
- Wire up /admin/rate-limits route and nav link (admin-only)
- Expand unit tests to 10 cases: auth, empty state, per-IP breakdown,
  throttled_24h count, time-series structure, response shape contract

Closes leeworks-agents/SPARC#1686

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:39:45 +00:00
13 changed files with 1363 additions and 171 deletions
+251 -98
View File
@@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios.
from __future__ import annotations from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -248,6 +249,9 @@ app.state.limiter = limiter
# In-memory rate limit statistics # In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {} _rate_limit_stats: dict[str, dict] = {}
# Time-series log of rejected requests (capped to last 24 h worth of entries).
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None: def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint.""" """Record a request against a rate-limited endpoint."""
@@ -262,6 +266,11 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) ->
_rate_limit_stats[key]["total_requests"] += 1 _rate_limit_stats[key]["total_requests"] += 1
if rejected: if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1 _rate_limit_stats[key]["rejected_requests"] += 1
_rejected_log.append({
"endpoint": endpoint,
"ip": ip,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {}) ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats: if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0} ip_stats[ip] = {"total": 0, "rejected": 0}
@@ -465,11 +474,46 @@ class TrackCompanyRequest(BaseModel):
company_name: CompanyName = Field(...) company_name: CompanyName = Field(...)
@app.get("/tracked", tags=["Tracked Companies"])
async def list_my_tracked_companies(
current_user: UserResponse = Depends(get_current_user),
):
"""List tracked companies for the current user."""
db = get_db_client()
return db.list_tracked_companies(owner_id=current_user.id)
@app.post("/tracked", tags=["Tracked Companies"])
async def add_my_tracked_company(
request: TrackCompanyRequest,
current_user: UserResponse = Depends(get_current_user),
):
"""Add a company to the current user's tracked list."""
db = get_db_client()
result = db.add_tracked_company(request.company_name, owner_id=current_user.id)
if not result:
raise HTTPException(status_code=409, detail="Company already tracked")
return result
@app.delete("/tracked/{company_name}", tags=["Tracked Companies"])
async def remove_my_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
current_user: UserResponse = Depends(get_current_user),
):
"""Remove a company from the current user's tracked list."""
db = get_db_client()
removed = db.remove_tracked_company(company_name, owner_id=current_user.id)
if not removed:
raise HTTPException(status_code=404, detail="Company not found in tracking list")
return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/tracked", tags=["Admin"]) @app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies( async def list_tracked_companies(
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""List all tracked companies (admin only).""" """List all tracked companies across all users (admin only)."""
db = get_db_client() db = get_db_client()
return db.list_tracked_companies() return db.list_tracked_companies()
@@ -477,11 +521,11 @@ async def list_tracked_companies(
@app.post("/admin/tracked", tags=["Admin"]) @app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company( async def add_tracked_company(
request: TrackCompanyRequest, request: TrackCompanyRequest,
_: UserResponse = Depends(get_current_admin), current_admin: UserResponse = Depends(get_current_admin),
): ):
"""Add a company to the tracked list (admin only).""" """Add a company to the tracked list (admin only, owned by admin)."""
db = get_db_client() db = get_db_client()
result = db.add_tracked_company(request.company_name) result = db.add_tracked_company(request.company_name, owner_id=current_admin.id)
if not result: if not result:
raise HTTPException(status_code=409, detail="Company already tracked") raise HTTPException(status_code=409, detail="Company already tracked")
return result return result
@@ -492,7 +536,7 @@ async def remove_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""Remove a company from the tracked list (admin only).""" """Remove a company from the tracked list (admin only, any owner)."""
db = get_db_client() db = get_db_client()
removed = db.remove_tracked_company(company_name) removed = db.remove_tracked_company(company_name)
if not removed: if not removed:
@@ -507,10 +551,12 @@ async def get_rate_limit_stats(
"""Get rate limit status and usage statistics (admin only). """Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics Returns current rate limit configuration and request statistics
for all rate-limited endpoints. for all rate-limited endpoints, including per-IP breakdown and
a time-series of throttled (rejected) requests in the last 24 hours.
Returns: Returns:
List of rate limit stats per endpoint with total/rejected counts Rate limit stats per endpoint, per-IP breakdown, and throttled
request history bucketed by hour.
""" """
rate_limits_config = { rate_limits_config = {
"/auth/register": {"limit": "5/minute"}, "/auth/register": {"limit": "5/minute"},
@@ -520,14 +566,45 @@ async def get_rate_limit_stats(
results = [] results = []
for endpoint, conf in rate_limits_config.items(): for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {}) stats = _rate_limit_stats.get(endpoint, {})
by_ip_raw = stats.get("by_ip", {})
by_ip = [
{"ip": ip, "total": counts["total"], "rejected": counts["rejected"]}
for ip, counts in by_ip_raw.items()
]
results.append({ results.append({
"endpoint": endpoint, "endpoint": endpoint,
"limit": conf["limit"], "limit": conf["limit"],
"total_requests": stats.get("total_requests", 0), "total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0), "rejected_requests": stats.get("rejected_requests", 0),
"by_ip": by_ip,
}) })
return {"rate_limits": results} # Build hourly buckets of throttled requests for the last 24 hours
now = datetime.now(timezone.utc)
cutoff = now - timedelta(hours=24)
hourly_buckets: dict[str, int] = {}
throttled_24h = 0
for entry in _rejected_log:
ts_str = entry["timestamp"]
try:
ts = datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
continue
if ts >= cutoff:
throttled_24h += 1
bucket = ts.strftime("%Y-%m-%dT%H:00:00Z")
hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1
throttled_over_time = [
{"timestamp": k, "count": v}
for k, v in sorted(hourly_buckets.items())
]
return {
"rate_limits": results,
"throttled_24h": throttled_24h,
"throttled_over_time": throttled_over_time,
}
@app.get("/admin/alerts", tags=["Admin"]) @app.get("/admin/alerts", tags=["Admin"])
@@ -540,17 +617,86 @@ async def list_alerts(
return db.list_alerts(limit=limit) return db.list_alerts(limit=limit)
# ============== Admin-Scoped Data Endpoints ==============
@app.get("/admin/analyses", response_model=PaginatedAnalysisResponse, tags=["Admin"])
async def admin_list_analyses(
company_name: Annotated[
str | None,
Query(description="Filter results by company name"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_admin),
):
"""List all analysis results across all users (admin only)."""
db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
has_next = len(rows) > limit
if has_next:
rows = rows[:limit]
items = [AnalysisRecord(**row) for row in rows]
next_cursor = None
if has_next and rows:
last = rows[-1]
ts = last["timestamp"]
ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
next_cursor = f"{ts_str}|{last['id']}"
return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor)
@app.get("/admin/jobs", response_model=PaginatedJobsResponse, tags=["Admin"])
async def admin_list_jobs(
status: Annotated[
str | None,
Query(description="Filter by status: pending, running, completed, failed"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_admin),
):
"""List all jobs across all users (admin only)."""
db = _get_job_db()
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
has_next = len(job_rows) > limit
if has_next:
job_rows = job_rows[:limit]
items = [_job_row_to_status(row) for row in job_rows]
next_cursor = None
if has_next and job_rows:
last = job_rows[-1]
created = last["created_at"]
ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
next_cursor = f"{ts}|{last['job_id']}"
return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
# ============== Analytics Endpoint ============== # ============== Analytics Endpoint ==============
@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"]) @app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics( async def get_analytics(
days: int = Query(default=30, ge=1, le=365), days: int = Query(default=30, ge=1, le=365),
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Get analytics data (authenticated users only).""" """Get analytics data scoped to the current user."""
db = get_db_client() db = get_db_client()
analytics = db.get_analytics(days=days) analytics = db.get_analytics(days=days, owner_id=current_user.id)
return AnalyticsResponse( return AnalyticsResponse(
total_messages=analytics["total_messages"], total_messages=analytics["total_messages"],
@@ -603,9 +749,9 @@ async def list_models():
@app.get("/analytics/trends", tags=["Analytics"]) @app.get("/analytics/trends", tags=["Analytics"])
async def get_analytics_trends( async def get_analytics_trends(
days: int = Query(default=90, ge=7, le=365), days: int = Query(default=90, ge=7, le=365),
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Get trend data for patent analysis over time. """Get trend data for patent analysis over time (scoped to current user).
Returns two datasets: Returns two datasets:
- ``by_month``: analysis count per company per month - ``by_month``: analysis count per company per month
@@ -619,11 +765,14 @@ async def get_analytics_trends(
""" """
db = get_db_client() db = get_db_client()
owner_filter = " AND owner_id = %s" if current_user else ""
owner_params = (current_user.id,) if current_user else ()
with db.get_conn() as conn: with db.get_conn() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
# Analyses per company per month # Analyses per company per month
cur.execute( cur.execute(
""" f"""
SELECT SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month, TO_CHAR(timestamp, 'YYYY-MM') AS month,
company_name, company_name,
@@ -632,16 +781,17 @@ async def get_analytics_trends(
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE AND is_cached = FALSE
AND company_name IS NOT NULL AND company_name IS NOT NULL
{owner_filter}
GROUP BY month, company_name GROUP BY month, company_name
ORDER BY month ORDER BY month
""", """,
(days,), (days, *owner_params),
) )
by_month_rows = cur.fetchall() by_month_rows = cur.fetchall()
# Analysis type distribution per month # Analysis type distribution per month
cur.execute( cur.execute(
""" f"""
SELECT SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month, TO_CHAR(timestamp, 'YYYY-MM') AS month,
analysis_type, analysis_type,
@@ -649,10 +799,11 @@ async def get_analytics_trends(
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE AND is_cached = FALSE
{owner_filter}
GROUP BY month, analysis_type GROUP BY month, analysis_type
ORDER BY month ORDER BY month
""", """,
(days,), (days, *owner_params),
) )
by_type_rows = cur.fetchall() by_type_rows = cur.fetchall()
@@ -678,9 +829,9 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"]) @app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv( async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a CSV file. """Export analysis results for a company as a CSV file (scoped to current user).
Returns all stored analysis records for the given company, including Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp. analysis type, model used, response text, and timestamp.
@@ -695,7 +846,7 @@ async def export_company_csv(
import io import io
db = get_db_client() db = get_db_client()
# Query all non-cached analysis results for this company # Query all non-cached analysis results for this company owned by current user
with db.get_conn() as conn: with db.get_conn() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute( cur.execute(
@@ -703,9 +854,10 @@ async def export_company_csv(
SELECT company_name, analysis_type, model, response, timestamp SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
AND owner_id = %s
ORDER BY timestamp DESC ORDER BY timestamp DESC
""", """,
(company_name,), (company_name, current_user.id),
) )
rows = cur.fetchall() rows = cur.fetchall()
@@ -730,9 +882,9 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"]) @app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf( async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a formatted PDF report. """Export analysis results for a company as a formatted PDF report (scoped to current user).
Returns all stored analysis records for the given company, including Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted analysis type, model used, response text, and timestamp, formatted
@@ -766,9 +918,10 @@ async def export_company_pdf(
SELECT company_name, analysis_type, model, response, timestamp SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
AND owner_id = %s
ORDER BY timestamp DESC ORDER BY timestamp DESC
""", """,
(company_name,), (company_name, current_user.id),
) )
rows = cur.fetchall() rows = cur.fetchall()
@@ -897,68 +1050,6 @@ async def health_check():
) )
@app.get(
"/analyze/{company_name}",
response_model=CompanyAnalysisResponse,
tags=["Analysis"],
)
async def analyze_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single company's patent portfolio.
This endpoint retrieves recent patents for the specified company,
parses them, and uses AI to generate a comprehensive analysis.
Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns:
Analysis results including patent count, AI insights, and success status
"""
_validate_model(model)
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name, model=model)
return _convert_result(result)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
If the patent PDF is not already cached locally, the system will attempt
to download it automatically from a previously cached link. If no link
is available, a 404 error is returned.
Args:
patent_id: Patent publication ID (e.g. "US-11234567-B2")
company_name: Company name for analysis context
Returns:
Analysis text for the patent
"""
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
try:
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@app.get( @app.get(
"/analyze/batch", "/analyze/batch",
response_model=PaginatedAnalysisResponse, response_model=PaginatedAnalysisResponse,
@@ -974,9 +1065,9 @@ async def list_analysis_results(
str | None, str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"), Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None, ] = None,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""List stored analysis results with cursor-based pagination. """List stored analysis results with cursor-based pagination (scoped to current user).
Returns past analysis results ordered by timestamp descending. Use Returns past analysis results ordered by timestamp descending. Use
``limit`` to control page size (default 50, max 200). The response ``limit`` to control page size (default 50, max 200). The response
@@ -993,7 +1084,7 @@ async def list_analysis_results(
Paginated list of analysis results Paginated list of analysis results
""" """
db = _get_job_db() db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, owner_id=current_user.id)
has_next = len(rows) > limit has_next = len(rows) > limit
if has_next: if has_next:
@@ -1043,6 +1134,68 @@ async def analyze_companies_batch(
return _convert_batch_result(result) return _convert_batch_result(result)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
If the patent PDF is not already cached locally, the system will attempt
to download it automatically from a previously cached link. If no link
is available, a 404 error is returned.
Args:
patent_id: Patent publication ID (e.g. "US-11234567-B2")
company_name: Company name for analysis context
Returns:
Analysis text for the patent
"""
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
try:
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@app.get(
"/analyze/{company_name}",
response_model=CompanyAnalysisResponse,
tags=["Analysis"],
)
async def analyze_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single company's patent portfolio.
This endpoint retrieves recent patents for the specified company,
parses them, and uses AI to generate a comprehensive analysis.
Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns:
Analysis results including patent count, AI insights, and success status
"""
_validate_model(model)
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name, model=model)
return _convert_result(result)
def _get_job_db() -> "DatabaseClient": def _get_job_db() -> "DatabaseClient":
"""Get a DatabaseClient for job persistence.""" """Get a DatabaseClient for job persistence."""
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
@@ -1129,7 +1282,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
async def analyze_companies_async( async def analyze_companies_async(
request: BatchAnalysisRequest, request: BatchAnalysisRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Start an asynchronous batch analysis job. """Start an asynchronous batch analysis job.
@@ -1149,7 +1302,7 @@ async def analyze_companies_async(
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}" job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
db = _get_job_db() db = _get_job_db()
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) job_row = db.create_job(job_id=job_id, total_companies=len(request.companies), owner_id=current_user.id)
background_tasks.add_task( background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers, request.model _run_batch_job, job_id, request.companies, request.max_workers, request.model
@@ -1161,9 +1314,9 @@ async def analyze_companies_async(
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status( async def get_job_status(
job_id: str, job_id: str,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Get the status of a background analysis job. """Get the status of a background analysis job (scoped to current user).
Args: Args:
job_id: The job ID returned from the async batch endpoint job_id: The job ID returned from the async batch endpoint
@@ -1172,7 +1325,7 @@ async def get_job_status(
Current job status including progress and results when complete Current job status including progress and results when complete
""" """
db = _get_job_db() db = _get_job_db()
job_row = db.get_job(job_id) job_row = db.get_job(job_id, owner_id=current_user.id)
if not job_row: if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found") raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@@ -1191,9 +1344,9 @@ async def list_jobs(
str | None, str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"), Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None, ] = None,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""List analysis jobs with cursor-based pagination. """List analysis jobs with cursor-based pagination (scoped to current user).
Pass ``limit`` to control page size. The response includes a ``next_cursor`` Pass ``limit`` to control page size. The response includes a ``next_cursor``
field; pass it back as the ``cursor`` query parameter to fetch the next page. field; pass it back as the ``cursor`` query parameter to fetch the next page.
@@ -1212,7 +1365,7 @@ async def list_jobs(
""" """
db = _get_job_db() db = _get_job_db()
# Fetch one extra to determine if there is a next page # Fetch one extra to determine if there is a next page
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor) job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor, owner_id=current_user.id)
has_next = len(job_rows) > limit has_next = len(job_rows) > limit
if has_next: if has_next:
+165 -29
View File
@@ -196,7 +196,7 @@ class DatabaseClient:
cursor.execute(""" cursor.execute("""
CREATE TABLE IF NOT EXISTS tracked_companies ( CREATE TABLE IF NOT EXISTS tracked_companies (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
company_name VARCHAR(255) UNIQUE NOT NULL, company_name VARCHAR(255) NOT NULL,
last_patent_count INTEGER DEFAULT 0, last_patent_count INTEGER DEFAULT 0,
last_analysis_at TIMESTAMP, last_analysis_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -221,6 +221,68 @@ class DatabaseClient:
ON alerts(company_name) ON alerts(company_name)
""") """)
# ---- Multi-tenant: add owner_id columns if missing ----
cursor.execute("""
DO $$
BEGIN
-- llm_messages.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'llm_messages' AND column_name = 'owner_id'
) THEN
ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
-- jobs.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'jobs' AND column_name = 'owner_id'
) THEN
ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
-- tracked_companies.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'tracked_companies' AND column_name = 'owner_id'
) THEN
ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
END $$;
""")
# Indexes for owner_id filtering
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_messages_owner
ON llm_messages(owner_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_owner
ON jobs(owner_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner
ON tracked_companies(owner_id)
""")
# Drop the old unique constraint on company_name alone (if it exists)
# and replace with a per-owner unique constraint so different users
# can track the same company independently.
cursor.execute("""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_constraint
WHERE conname = 'tracked_companies_company_name_key'
) THEN
ALTER TABLE tracked_companies
DROP CONSTRAINT tracked_companies_company_name_key;
END IF;
END $$;
""")
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner
ON tracked_companies(LOWER(company_name), owner_id)
""")
self.conn.commit() self.conn.commit()
@staticmethod @staticmethod
@@ -289,6 +351,7 @@ class DatabaseClient:
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
token_usage: Optional[Dict] = None, token_usage: Optional[Dict] = None,
is_cached: bool = False, is_cached: bool = False,
owner_id: Optional[int] = None,
) -> int: ) -> int:
"""Store an LLM message exchange in the database. """Store an LLM message exchange in the database.
@@ -301,6 +364,7 @@ class DatabaseClient:
metadata: Additional metadata as dict metadata: Additional metadata as dict
token_usage: Token usage information token_usage: Token usage information
is_cached: Whether this response was served from cache is_cached: Whether this response was served from cache
owner_id: ID of the user who owns this record
Returns: Returns:
The ID of the inserted record The ID of the inserted record
@@ -312,8 +376,8 @@ class DatabaseClient:
cursor.execute( cursor.execute(
""" """
INSERT INTO llm_messages INSERT INTO llm_messages
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached, owner_id)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id RETURNING id
""", """,
( (
@@ -326,6 +390,7 @@ class DatabaseClient:
json.dumps(metadata) if metadata else None, json.dumps(metadata) if metadata else None,
json.dumps(token_usage) if token_usage else None, json.dumps(token_usage) if token_usage else None,
is_cached, is_cached,
owner_id,
), ),
) )
@@ -340,6 +405,7 @@ class DatabaseClient:
analysis_type: Optional[str] = None, analysis_type: Optional[str] = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
owner_id: Optional[int] = None,
) -> List[Dict]: ) -> List[Dict]:
"""Retrieve messages from the database. """Retrieve messages from the database.
@@ -348,6 +414,7 @@ class DatabaseClient:
analysis_type: Filter by analysis type analysis_type: Filter by analysis type
limit: Maximum number of records to return limit: Maximum number of records to return
offset: Number of records to skip offset: Number of records to skip
owner_id: Filter by owner (None returns all, for admin use)
Returns: Returns:
List of message dictionaries List of message dictionaries
@@ -355,6 +422,10 @@ class DatabaseClient:
query = "SELECT * FROM llm_messages WHERE 1=1" query = "SELECT * FROM llm_messages WHERE 1=1"
params = [] params = []
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
if company_name: if company_name:
query += " AND company_name = %s" query += " AND company_name = %s"
params.append(company_name) params.append(company_name)
@@ -376,6 +447,7 @@ class DatabaseClient:
company_name: Optional[str] = None, company_name: Optional[str] = None,
limit: int = 50, limit: int = 50,
cursor: Optional[str] = None, cursor: Optional[str] = None,
owner_id: Optional[int] = None,
) -> List[Dict]: ) -> List[Dict]:
"""List analysis results with cursor-based pagination. """List analysis results with cursor-based pagination.
@@ -383,6 +455,7 @@ class DatabaseClient:
company_name: Optional filter by company name. company_name: Optional filter by company name.
limit: Maximum number of records to return. limit: Maximum number of records to return.
cursor: Opaque cursor (``timestamp|id``) from a previous response. cursor: Opaque cursor (``timestamp|id``) from a previous response.
owner_id: Filter by owner (None returns all, for admin use).
Returns: Returns:
List of analysis dicts ordered by timestamp descending. List of analysis dicts ordered by timestamp descending.
@@ -390,6 +463,10 @@ class DatabaseClient:
conditions: list[str] = ["is_cached = FALSE"] conditions: list[str] = ["is_cached = FALSE"]
params: list = [] params: list = []
if owner_id is not None:
conditions.append("owner_id = %s")
params.append(owner_id)
if company_name: if company_name:
conditions.append("LOWER(company_name) = LOWER(%s)") conditions.append("LOWER(company_name) = LOWER(%s)")
params.append(company_name) params.append(company_name)
@@ -413,52 +490,62 @@ class DatabaseClient:
cur.execute(query, params) cur.execute(query, params)
return [dict(row) for row in cur.fetchall()] return [dict(row) for row in cur.fetchall()]
def get_analytics(self, days: int = 30) -> Dict: def get_analytics(self, days: int = 30, owner_id: Optional[int] = None) -> Dict:
"""Get analytics on message usage. """Get analytics on message usage.
Args: Args:
days: Number of days to look back days: Number of days to look back
owner_id: Filter by owner (None returns all, for admin use)
Returns: Returns:
Dictionary with analytics data Dictionary with analytics data
""" """
owner_filter = ""
owner_params: list = []
if owner_id is not None:
owner_filter = " AND owner_id = %s"
owner_params = [owner_id]
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages # Total messages
cursor.execute( cursor.execute(
""" f"""
SELECT COUNT(*) as total_messages SELECT COUNT(*) as total_messages
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
""", """,
(days,), (days, *owner_params),
) )
total = cursor.fetchone()["total_messages"] total = cursor.fetchone()["total_messages"]
# Messages by company # Messages by company
cursor.execute( cursor.execute(
""" f"""
SELECT company_name, COUNT(*) as count SELECT company_name, COUNT(*) as count
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
GROUP BY company_name GROUP BY company_name
ORDER BY count DESC ORDER BY count DESC
LIMIT 10 LIMIT 10
""", """,
(days,), (days, *owner_params),
) )
by_company = cursor.fetchall() by_company = cursor.fetchall()
# Messages by type # Messages by type
cursor.execute( cursor.execute(
""" f"""
SELECT analysis_type, COUNT(*) as count SELECT analysis_type, COUNT(*) as count
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
GROUP BY analysis_type GROUP BY analysis_type
ORDER BY count DESC ORDER BY count DESC
""", """,
(days,), (days, *owner_params),
) )
by_type = cursor.fetchall() by_type = cursor.fetchall()
@@ -556,12 +643,14 @@ class DatabaseClient:
self, self,
job_id: str, job_id: str,
total_companies: int, total_companies: int,
owner_id: Optional[int] = None,
) -> Dict: ) -> Dict:
"""Create a new job record. """Create a new job record.
Args: Args:
job_id: Unique job identifier job_id: Unique job identifier
total_companies: Number of companies in the batch total_companies: Number of companies in the batch
owner_id: ID of the user who owns this job
Returns: Returns:
Job dict Job dict
@@ -570,11 +659,11 @@ class DatabaseClient:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies) INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies, owner_id)
VALUES (%s, 'pending', 0, %s, 0) VALUES (%s, 'pending', 0, %s, 0, %s)
RETURNING * RETURNING *
""", """,
(job_id, total_companies), (job_id, total_companies, owner_id),
) )
job = cursor.fetchone() job = cursor.fetchone()
conn.commit() conn.commit()
@@ -627,11 +716,22 @@ class DatabaseClient:
conn.commit() conn.commit()
return dict(job) if job else None return dict(job) if job else None
def get_job(self, job_id: str) -> Optional[Dict]: def get_job(self, job_id: str, owner_id: Optional[int] = None) -> Optional[Dict]:
"""Get a job by ID.""" """Get a job by ID.
Args:
job_id: Job identifier.
owner_id: When provided, only return the job if it belongs to this owner.
"""
query = "SELECT * FROM jobs WHERE job_id = %s"
params: list = [job_id]
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,)) cursor.execute(query, params)
job = cursor.fetchone() job = cursor.fetchone()
return dict(job) if job else None return dict(job) if job else None
@@ -640,6 +740,7 @@ class DatabaseClient:
status: Optional[str] = None, status: Optional[str] = None,
limit: int = 10, limit: int = 10,
cursor: Optional[str] = None, cursor: Optional[str] = None,
owner_id: Optional[int] = None,
) -> List[Dict]: ) -> List[Dict]:
"""List jobs with optional status filter and cursor-based pagination. """List jobs with optional status filter and cursor-based pagination.
@@ -649,6 +750,7 @@ class DatabaseClient:
cursor: Opaque cursor (``created_at|job_id``) from a previous cursor: Opaque cursor (``created_at|job_id``) from a previous
response. When provided, only jobs older than the cursor are response. When provided, only jobs older than the cursor are
returned. returned.
owner_id: Filter by owner (None returns all, for admin use).
Returns: Returns:
List of job dicts ordered by created_at descending. List of job dicts ordered by created_at descending.
@@ -656,6 +758,10 @@ class DatabaseClient:
conditions: list[str] = [] conditions: list[str] = []
params: list = [] params: list = []
if owner_id is not None:
conditions.append("owner_id = %s")
params.append(owner_id)
if status: if status:
conditions.append("status = %s") conditions.append("status = %s")
params.append(status) params.append(status)
@@ -902,14 +1008,21 @@ class DatabaseClient:
# Tracked Companies Methods # Tracked Companies Methods
def add_tracked_company(self, company_name: str) -> Optional[Dict]: def add_tracked_company(
"""Add a company to the tracking list.""" self, company_name: str, owner_id: Optional[int] = None
) -> Optional[Dict]:
"""Add a company to the tracking list.
Args:
company_name: Company name to track.
owner_id: ID of the user who owns this tracked company.
"""
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
try: try:
cursor.execute( cursor.execute(
"INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *", "INSERT INTO tracked_companies (company_name, owner_id) VALUES (%s, %s) RETURNING *",
(company_name,), (company_name, owner_id),
) )
row = cursor.fetchone() row = cursor.fetchone()
conn.commit() conn.commit()
@@ -918,22 +1031,45 @@ class DatabaseClient:
conn.rollback() conn.rollback()
return None return None
def remove_tracked_company(self, company_name: str) -> bool: def remove_tracked_company(
"""Remove a company from the tracking list.""" self, company_name: str, owner_id: Optional[int] = None
) -> bool:
"""Remove a company from the tracking list.
Args:
company_name: Company name to remove.
owner_id: When provided, only remove if owned by this user.
"""
query = "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)"
params: list = [company_name]
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute( cursor.execute(query, params)
"DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
(company_name,),
)
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def list_tracked_companies(self) -> List[Dict]: def list_tracked_companies(
"""List all tracked companies.""" self, owner_id: Optional[int] = None
) -> List[Dict]:
"""List tracked companies.
Args:
owner_id: Filter by owner (None returns all, for admin/scheduler use).
"""
query = "SELECT * FROM tracked_companies"
params: list = []
if owner_id is not None:
query += " WHERE owner_id = %s"
params.append(owner_id)
query += " ORDER BY company_name"
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name") cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
def update_tracked_company( def update_tracked_company(
+9
View File
@@ -11,6 +11,7 @@ import { Batch } from './pages/Batch';
import { AnalyticsPage } from './pages/Analytics'; import { AnalyticsPage } from './pages/Analytics';
import { About } from './pages/About'; import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers'; import { AdminUsers } from './pages/AdminUsers';
import { AdminRateLimits } from './pages/AdminRateLimits';
import { Compare } from './pages/Compare'; import { Compare } from './pages/Compare';
const queryClient = new QueryClient({ const queryClient = new QueryClient({
@@ -56,6 +57,14 @@ function App() {
</ProtectedRoute> </ProtectedRoute>
} }
/> />
<Route
path="/admin/rate-limits"
element={
<ProtectedRoute requireAdmin>
<AdminRateLimits />
</ProtectedRoute>
}
/>
</Route> </Route>
{/* Default redirect */} {/* Default redirect */}
+31
View File
@@ -201,6 +201,32 @@ export const analyticsApi = {
}, },
}; };
// Rate limit types
export interface RateLimitIpEntry {
ip: string;
total: number;
rejected: number;
}
export interface RateLimitEndpointStats {
endpoint: string;
limit: string;
total_requests: number;
rejected_requests: number;
by_ip: RateLimitIpEntry[];
}
export interface ThrottledBucket {
timestamp: string;
count: number;
}
export interface RateLimitStatsResponse {
rate_limits: RateLimitEndpointStats[];
throttled_24h: number;
throttled_over_time: ThrottledBucket[];
}
// Admin API // Admin API
export const adminApi = { export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => { listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
@@ -216,6 +242,11 @@ export const adminApi = {
deleteUser: async (userId: number): Promise<void> => { deleteUser: async (userId: number): Promise<void> => {
await api.delete(`/admin/users/${userId}`); await api.delete(`/admin/users/${userId}`);
}, },
getRateLimits: async (): Promise<RateLimitStatsResponse> => {
const response = await api.get<RateLimitStatsResponse>('/admin/rate-limits');
return response.data;
},
}; };
export default api; export default api;
+2 -1
View File
@@ -1,7 +1,7 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom'; import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext'; import { useAuth } from '../context/AuthContext';
import { useTheme } from '../context/ThemeContext'; import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react'; import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, ShieldAlert } from 'lucide-react';
export function Layout() { export function Layout() {
const { user, isAdmin, logout } = useAuth(); const { user, isAdmin, logout } = useAuth();
@@ -23,6 +23,7 @@ export function Layout() {
if (isAdmin) { if (isAdmin) {
navItems.push({ to: '/admin/users', icon: Users, label: 'Users' }); navItems.push({ to: '/admin/users', icon: Users, label: 'Users' });
navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' });
} }
return ( return (
+240
View File
@@ -0,0 +1,240 @@
import { useState } from 'react';
import { useQuery } from '@tanstack/react-query';
import { adminApi } from '../api/client';
import type { RateLimitStatsResponse } from '../api/client';
import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react';
const REFRESH_OPTIONS = [
{ label: '15s', value: 15_000 },
{ label: '30s', value: 30_000 },
{ label: '1m', value: 60_000 },
{ label: 'Off', value: 0 },
];
export function AdminRateLimits() {
const [refreshInterval, setRefreshInterval] = useState(30_000);
const { data, isLoading, isError, dataUpdatedAt } = useQuery<RateLimitStatsResponse>({
queryKey: ['admin-rate-limits'],
queryFn: () => adminApi.getRateLimits(),
refetchInterval: refreshInterval || false,
});
if (isLoading) {
return (
<div className="flex items-center justify-center min-h-[400px]">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div>
</div>
);
}
if (isError) {
return (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load rate limit statistics.</span>
</div>
);
}
const maxThrottledCount = data?.throttled_over_time?.length
? Math.max(...data.throttled_over_time.map((b) => b.count))
: 0;
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between flex-wrap gap-4">
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Rate Limiting Dashboard
</h2>
<p className="text-text-secondary">Monitor API rate limits and throttled requests.</p>
</div>
<div className="flex items-center gap-3">
{/* Last updated */}
{dataUpdatedAt > 0 && (
<span className="text-xs text-text-secondary flex items-center gap-1">
<Clock size={12} />
Updated {new Date(dataUpdatedAt).toLocaleTimeString()}
</span>
)}
{/* Refresh interval selector */}
<div className="flex items-center gap-1 bg-bg-card/60 border border-primary/15 rounded-xl p-1">
<RefreshCw size={14} className="text-text-secondary ml-2" />
{REFRESH_OPTIONS.map((opt) => (
<button
key={opt.value}
onClick={() => setRefreshInterval(opt.value)}
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
refreshInterval === opt.value
? 'bg-primary text-white'
: 'text-text-secondary hover:text-text-primary hover:bg-bg-card-hover'
}`}
>
{opt.label}
</button>
))}
</div>
</div>
</div>
{/* Summary cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<Activity size={18} className="text-primary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-error/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-error" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Throttled (24h)
</span>
</div>
<div className="text-3xl font-bold text-error">
{data?.throttled_24h ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-secondary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-secondary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rate-Limited Endpoints
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.length ?? 0}
</div>
</div>
</div>
{/* Throttled over time chart (simple bar chart) */}
{data?.throttled_over_time && data.throttled_over_time.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider mb-4">
Throttled Requests Over Time (Last 24h)
</h3>
<div className="flex items-end gap-1 h-32">
{data.throttled_over_time.map((bucket) => {
const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0;
const hour = new Date(bucket.timestamp).getHours();
return (
<div key={bucket.timestamp} className="flex-1 flex flex-col items-center gap-1">
<span className="text-xs text-text-secondary">{bucket.count}</span>
<div
className="w-full bg-error/70 rounded-t-sm min-h-[2px] transition-all"
style={{ height: `${Math.max(height, 2)}%` }}
title={`${bucket.timestamp}: ${bucket.count} throttled`}
/>
<span className="text-[10px] text-text-secondary">{hour}:00</span>
</div>
);
})}
</div>
</div>
)}
{/* Per-endpoint table */}
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Limit
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data?.rate_limits.map((rl) => (
<tr key={rl.endpoint} className="hover:bg-bg-card-hover/50 transition-colors">
<td className="px-6 py-4 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-4">
<span className="inline-flex px-2 py-0.5 rounded-full text-xs font-medium bg-primary/10 text-primary border border-primary/20">
{rl.limit}
</span>
</td>
<td className="px-6 py-4 text-right text-text-primary font-semibold">
{rl.total_requests}
</td>
<td className="px-6 py-4 text-right">
<span className={rl.rejected_requests > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{rl.rejected_requests}
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
{/* Per-IP breakdown */}
{data?.rate_limits.some((rl) => rl.by_ip.length > 0) && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="px-6 py-4 border-b border-primary/10">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Per-IP Breakdown
</h3>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
IP Address
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data.rate_limits.flatMap((rl) =>
rl.by_ip.map((ipEntry) => (
<tr
key={`${rl.endpoint}-${ipEntry.ip}`}
className="hover:bg-bg-card-hover/50 transition-colors"
>
<td className="px-6 py-3 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-3 font-mono text-sm text-text-secondary">{ipEntry.ip}</td>
<td className="px-6 py-3 text-right text-text-primary">{ipEntry.total}</td>
<td className="px-6 py-3 text-right">
<span className={ipEntry.rejected > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{ipEntry.rejected}
</span>
</td>
</tr>
))
)}
</tbody>
</table>
</div>
</div>
)}
</div>
);
}
+132
View File
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Migration: add owner_id columns and backfill existing rows.
This script adds an ``owner_id`` column (FK to ``users``) to the
``llm_messages``, ``jobs``, and ``tracked_companies`` tables, then
backfills all existing rows with ``owner_id = 1`` (the default admin user).
It also replaces the old global UNIQUE constraint on
``tracked_companies.company_name`` with a per-owner unique index so that
different users can independently track the same company.
Usage:
python scripts/migrate_add_owner_id.py
The script is idempotent — running it multiple times is safe.
"""
import os
import sys
import psycopg2
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://postgres:postgres@localhost:5432/sparc",
)
DEFAULT_OWNER_ID = 1
def run_migration():
"""Execute the migration."""
conn = psycopg2.connect(DATABASE_URL)
conn.autocommit = False
try:
with conn.cursor() as cur:
# ---------- 1. Add owner_id columns if missing ----------
cur.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'llm_messages' AND column_name = 'owner_id'
) THEN
ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'jobs' AND column_name = 'owner_id'
) THEN
ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'tracked_companies' AND column_name = 'owner_id'
) THEN
ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
END $$;
""")
# ---------- 2. Backfill owner_id = DEFAULT_OWNER_ID ----------
cur.execute(
"UPDATE llm_messages SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
messages_updated = cur.rowcount
print(f" llm_messages: backfilled {messages_updated} rows")
cur.execute(
"UPDATE jobs SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
jobs_updated = cur.rowcount
print(f" jobs: backfilled {jobs_updated} rows")
cur.execute(
"UPDATE tracked_companies SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
tracked_updated = cur.rowcount
print(f" tracked_companies: backfilled {tracked_updated} rows")
# ---------- 3. Create indexes ----------
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_messages_owner
ON llm_messages(owner_id)
""")
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_owner
ON jobs(owner_id)
""")
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner
ON tracked_companies(owner_id)
""")
# ---------- 4. Replace unique constraint on tracked_companies ----------
cur.execute("""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_constraint
WHERE conname = 'tracked_companies_company_name_key'
) THEN
ALTER TABLE tracked_companies
DROP CONSTRAINT tracked_companies_company_name_key;
END IF;
END $$;
""")
cur.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner
ON tracked_companies(LOWER(company_name), owner_id)
""")
conn.commit()
print("Migration completed successfully.")
except Exception:
conn.rollback()
print("Migration FAILED — rolled back.", file=sys.stderr)
raise
finally:
conn.close()
if __name__ == "__main__":
print(f"Running owner_id migration against {DATABASE_URL.split('@')[-1]} ...")
run_migration()
+74 -18
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints.""" """Tests for FastAPI web service endpoints."""
from datetime import datetime from datetime import datetime, timezone
from unittest.mock import Mock from unittest.mock import Mock, MagicMock, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from SPARC.api import app from SPARC.api import app
from SPARC.auth import create_access_token
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -16,6 +17,22 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
@pytest.fixture @pytest.fixture
def mock_analyzer(mocker): def mock_analyzer(mocker):
"""Mock the global analyzer.""" """Mock the global analyzer."""
@@ -24,6 +41,12 @@ def mock_analyzer(mocker):
return mock return mock
def _auth_header(user_id=1, email="user@test.com", role="user"):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_id, email, role)
return {"Authorization": f"Bearer {token}"}
class TestHealthEndpoint: class TestHealthEndpoint:
"""Test health check endpoint.""" """Test health check endpoint."""
@@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint:
) )
mock_analyzer._analyze_company_safe.return_value = mock_result mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia") response = client.get("/analyze/nvidia", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint:
) )
mock_analyzer._analyze_company_safe.return_value = mock_result mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/unknown") response = client.get("/analyze/unknown", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint:
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["nvidia", "amd"], "max_workers": 2}, json={"companies": ["nvidia", "amd"], "max_workers": 2},
headers=_auth_header(),
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint:
def test_batch_analysis_validation(self, client): def test_batch_analysis_validation(self, client):
"""Test batch analysis request validation.""" """Test batch analysis request validation."""
# Empty companies list # Empty companies list
response = client.post("/analyze/batch", json={"companies": []}) response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header())
assert response.status_code == 422 assert response.status_code == 422
# Too many companies # Too many companies
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": [f"company{i}" for i in range(25)]}, json={"companies": [f"company{i}" for i in range(25)]},
headers=_auth_header(),
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint:
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["nvidia"], "max_workers": 10}, json={"companies": ["nvidia"], "max_workers": 10},
headers=_auth_header(),
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint:
class TestAsyncBatchEndpoint: class TestAsyncBatchEndpoint:
"""Test async batch analysis endpoint.""" """Test async batch analysis endpoint."""
def test_async_batch_creates_job(self, client, mock_analyzer): @patch("SPARC.api._get_job_db")
"""Test async endpoint creates a job.""" def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
"""Test async endpoint creates a job with owner_id."""
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j1",
"status": "pending",
"progress": 0,
"total_companies": 2,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": 1,
}
mock_get_db.return_value = job_db
response = client.post( response = client.post(
"/analyze/batch/async", "/analyze/batch/async",
json={"companies": ["nvidia", "amd"]}, json={"companies": ["nvidia", "amd"]},
headers=_auth_header(),
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint:
assert data["status"] == "pending" assert data["status"] == "pending"
assert data["total_companies"] == 2 assert data["total_companies"] == 2
assert data["progress"] == 0 assert data["progress"] == 0
# Verify owner_id was passed
job_db.create_job.assert_called_once()
assert job_db.create_job.call_args.kwargs.get("owner_id") == 1
class TestJobEndpoints: class TestJobEndpoints:
"""Test job management endpoints.""" """Test job management endpoints."""
def test_get_job_not_found(self, client): @patch("SPARC.api._get_job_db")
def test_get_job_not_found(self, mock_get_db, client):
"""Test getting nonexistent job.""" """Test getting nonexistent job."""
response = client.get("/jobs/nonexistent") job_db = MagicMock()
job_db.get_job.return_value = None
mock_get_db.return_value = job_db
response = client.get("/jobs/nonexistent", headers=_auth_header())
assert response.status_code == 404 assert response.status_code == 404
def test_list_jobs(self, client, mocker): @patch("SPARC.api._get_job_db")
def test_list_jobs(self, mock_get_db, client):
"""Test listing jobs.""" """Test listing jobs."""
# Clear existing jobs job_db = MagicMock()
mocker.patch.dict("SPARC.api._jobs", {}, clear=True) job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs") response = client.get("/jobs", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_jobs_with_filter(self, client, mocker): @patch("SPARC.api._get_job_db")
def test_list_jobs_with_filter(self, mock_get_db, client):
"""Test listing jobs with status filter.""" """Test listing jobs with status filter."""
response = client.get("/jobs?status=completed") job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs?status=completed", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
@@ -189,7 +244,7 @@ class TestModelValidation:
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer): def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with unsupported model returns 400.""" """GET /analyze/{company} with unsupported model returns 400."""
response = client.get("/analyze/nvidia?model=fake/nonexistent-model") response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header())
assert response.status_code == 400 assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"] assert "Unsupported model" in response.json()["detail"]
@@ -205,7 +260,7 @@ class TestModelValidation:
) )
mock_analyzer._analyze_company_safe.return_value = mock_result mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet") response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
def test_batch_rejects_unsupported_model(self, client, mock_analyzer): def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
@@ -213,6 +268,7 @@ class TestModelValidation:
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"}, json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
headers=_auth_header(),
) )
assert response.status_code == 400 assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"] assert "Unsupported model" in response.json()["detail"]
+1
View File
@@ -5,6 +5,7 @@ Covers issue #1655:
- GET /export/{company_name}/pdf (PDF export) - GET /export/{company_name}/pdf (PDF export)
All tests mock the database layer and use JWT auth fixtures from test_auth patterns. All tests mock the database layer and use JWT auth fixtures from test_auth patterns.
Export queries are now scoped to the current user's owner_id.
""" """
from datetime import datetime, timezone from datetime import datetime, timezone
+281
View File
@@ -0,0 +1,281 @@
"""Cross-tenant isolation tests for multi-tenant support.
Verifies that:
- User A cannot read, update, or delete User B's analyses, tracked companies, or jobs
- Admin users can access all data via admin endpoints
- owner_id is correctly set on new resources
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user(user_id, email, role="user"):
return {
"id": user_id,
"email": email,
"role": role,
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
USER_A = _make_user(10, "alice@test.com")
USER_B = _make_user(20, "bob@test.com")
ADMIN = _make_user(1, "admin@test.com", role="admin")
def _header_for(user):
token = create_access_token(user["id"], user["email"], user["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db():
"""Mock DB returning the correct user based on user_id."""
db = MagicMock()
def _get_user_by_id(uid):
for u in [USER_A, USER_B, ADMIN]:
if u["id"] == uid:
return u
return None
db.get_user_by_id.side_effect = _get_user_by_id
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
# ==================== Tracked Companies Isolation ====================
class TestTrackedCompanyIsolation:
"""User A's tracked companies are invisible to User B."""
def test_user_a_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User A's companies."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": USER_A["id"]},
]
response = client.get("/tracked", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"])
def test_user_b_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User B's companies."""
mock_db.list_tracked_companies.return_value = []
response = client.get("/tracked", headers=_header_for(USER_B))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"])
def test_user_a_add_sets_owner(self, client, mock_db):
"""POST /tracked sets owner_id to User A."""
mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10}
response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"])
def test_user_b_cannot_remove_user_a_company(self, client, mock_db):
"""DELETE /tracked/{name} filters by owner, so B can't remove A's company."""
mock_db.remove_tracked_company.return_value = False # not found for B
response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B))
assert response.status_code == 404
mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"])
# ==================== Job Isolation ====================
class TestJobIsolation:
"""User A's jobs are invisible to User B."""
def test_user_a_get_own_job(self, client, mock_db):
"""GET /jobs/{id} scoped to User A returns the job."""
mock_db.get_job.return_value = None # mock via _get_job_db
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = {
"job_id": "j1",
"status": "completed",
"progress": 100,
"total_companies": 1,
"completed_companies": 1,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_A))
assert response.status_code == 200
job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"])
def test_user_b_cannot_see_user_a_job(self, client, mock_db):
"""GET /jobs/{id} returns 404 when User B tries to access User A's job."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = None # not found for B's owner_id
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_B))
assert response.status_code == 404
job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"])
def test_list_jobs_scoped_to_user(self, client, mock_db):
"""GET /jobs filters by owner_id."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
def test_async_job_created_with_owner(self, client, mock_db):
"""POST /analyze/batch/async creates job with current user's owner_id."""
mock_analyzer = MagicMock()
with patch("SPARC.api._analyzer", mock_analyzer), \
patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j2",
"status": "pending",
"progress": 0,
"total_companies": 1,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia"]},
headers=_header_for(USER_A),
)
assert response.status_code == 200
create_kwargs = job_db.create_job.call_args
assert create_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Analysis Listing Isolation ====================
class TestAnalysisListIsolation:
"""GET /analyze/batch scoped to current user."""
def test_list_analyses_scoped_to_user(self, client, mock_db):
"""GET /analyze/batch passes owner_id to db.list_analyses."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/analyze/batch", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Admin Cross-Tenant Access ====================
class TestAdminCrossTenantAccess:
"""Admin endpoints return data from all tenants (no owner_id filter)."""
def test_admin_list_tracked_all_tenants(self, client, mock_db):
"""GET /admin/tracked returns all companies (no owner_id filter)."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": 10},
{"company_name": "BobCo", "owner_id": 20},
]
response = client.get("/admin/tracked", headers=_header_for(ADMIN))
assert response.status_code == 200
# Should be called without owner_id filter
mock_db.list_tracked_companies.assert_called_with()
def test_admin_list_analyses_all_tenants(self, client, mock_db):
"""GET /admin/analyses returns all analyses (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/analyses", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
# No owner_id should be passed
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_list_jobs_all_tenants(self, client, mock_db):
"""GET /admin/jobs returns all jobs (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/jobs", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_remove_tracked_any_owner(self, client, mock_db):
"""DELETE /admin/tracked/{name} removes without owner filter."""
mock_db.remove_tracked_company.return_value = True
response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN))
assert response.status_code == 200
# Called without owner_id
mock_db.remove_tracked_company.assert_called_with("SomeCo")
def test_regular_user_cannot_access_admin_analyses(self, client, mock_db):
"""Regular user gets 403 on /admin/analyses."""
response = client.get("/admin/analyses", headers=_header_for(USER_A))
assert response.status_code == 403
def test_regular_user_cannot_access_admin_jobs(self, client, mock_db):
"""Regular user gets 403 on /admin/jobs."""
response = client.get("/admin/jobs", headers=_header_for(USER_A))
assert response.status_code == 403
# ==================== Analytics Isolation ====================
class TestAnalyticsIsolation:
"""GET /analytics scoped to current user."""
def test_analytics_scoped_to_user(self, client, mock_db):
"""GET /analytics passes owner_id to db.get_analytics."""
mock_db.get_analytics.return_value = {
"total_messages": 5,
"by_company": [],
"by_type": [],
"period_days": 30,
}
response = client.get("/analytics", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"])
+35 -13
View File
@@ -1,12 +1,13 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints.""" """Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch from unittest.mock import Mock, MagicMock, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from SPARC.api import app from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture @pytest.fixture
@@ -15,6 +16,27 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"): def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict.""" """Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago) ts = datetime.now() - timedelta(minutes=minutes_ago)
@@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination:
] ]
mock_get_db.return_value = db mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10") response = client.get("/analyze/batch?limit=10", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["items"]) == 2 assert len(data["items"]) == 2
@@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = rows db.list_analyses.return_value = rows
mock_get_db.return_value = db mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3") response = client.get("/analyze/batch?limit=3", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["items"]) == 3 assert len(data["items"]) == 3
@@ -84,7 +106,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = [] db.list_analyses.return_value = []
mock_get_db.return_value = db mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42") client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header())
db.list_analyses.assert_called_once() db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \ assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
@@ -97,19 +119,19 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = [] db.list_analyses.return_value = []
mock_get_db.return_value = db mock_get_db.return_value = db
client.get("/analyze/batch") client.get("/analyze/batch", headers=_auth_header())
call_kwargs = db.list_analyses.call_args call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51 # The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client): def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422.""" """Limit > 200 should be rejected with 422."""
response = client.get("/analyze/batch?limit=201") response = client.get("/analyze/batch?limit=201", headers=_auth_header())
assert response.status_code == 422 assert response.status_code == 422
def test_limit_zero_rejected(self, client): def test_limit_zero_rejected(self, client):
"""Limit < 1 should be rejected with 422.""" """Limit < 1 should be rejected with 422."""
response = client.get("/analyze/batch?limit=0") response = client.get("/analyze/batch?limit=0", headers=_auth_header())
assert response.status_code == 422 assert response.status_code == 422
@patch("SPARC.api._get_job_db") @patch("SPARC.api._get_job_db")
@@ -119,7 +141,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = [] db.list_analyses.return_value = []
mock_get_db.return_value = db mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel") client.get("/analyze/batch?company_name=intel", headers=_auth_header())
call_kwargs = db.list_analyses.call_args call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \ assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else []) "intel" in (call_kwargs.args if call_kwargs.args else [])
@@ -131,7 +153,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = [] db.list_analyses.return_value = []
mock_get_db.return_value = db mock_get_db.return_value = db
response = client.get("/analyze/batch") response = client.get("/analyze/batch", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["items"] == [] assert data["items"] == []
@@ -148,14 +170,14 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = [] db.list_jobs.return_value = []
mock_get_db.return_value = db mock_get_db.return_value = db
client.get("/jobs") client.get("/jobs", headers=_auth_header())
call_kwargs = db.list_jobs.call_args call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51 # Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client): def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422.""" """Limit > 200 should be rejected with 422."""
response = client.get("/jobs?limit=201") response = client.get("/jobs?limit=201", headers=_auth_header())
assert response.status_code == 422 assert response.status_code == 422
@patch("SPARC.api._get_job_db") @patch("SPARC.api._get_job_db")
@@ -165,5 +187,5 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = [] db.list_jobs.return_value = []
mock_get_db.return_value = db mock_get_db.return_value = db
response = client.get("/jobs?limit=200") response = client.get("/jobs?limit=200", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
+71 -2
View File
@@ -20,8 +20,10 @@ def client():
def reset_stats(): def reset_stats():
"""Reset rate limit stats between tests.""" """Reset rate limit stats between tests."""
api._rate_limit_stats.clear() api._rate_limit_stats.clear()
api._rejected_log.clear()
yield yield
api._rate_limit_stats.clear() api._rate_limit_stats.clear()
api._rejected_log.clear()
def _mock_admin(): def _mock_admin():
@@ -50,8 +52,7 @@ class TestRateLimitAdminEndpoint:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_non_admin_rejected(self, client): def test_non_admin_rejected(self, client):
"""Non-admin users should get 403.""" """Non-admin users should get 401/403."""
# Without overriding the dependency, it should fail auth
response = client.get("/admin/rate-limits") response = client.get("/admin/rate-limits")
assert response.status_code in (401, 403) assert response.status_code in (401, 403)
@@ -77,6 +78,9 @@ class TestRateLimitAdminEndpoint:
for rl in data["rate_limits"]: for rl in data["rate_limits"]:
assert rl["total_requests"] == 0 assert rl["total_requests"] == 0
assert rl["rejected_requests"] == 0 assert rl["rejected_requests"] == 0
assert rl["by_ip"] == []
assert data["throttled_24h"] == 0
assert data["throttled_over_time"] == []
finally: finally:
app.dependency_overrides.clear() app.dependency_overrides.clear()
@@ -107,3 +111,68 @@ class TestRateLimitAdminEndpoint:
assert isinstance(rl["limit"], str) assert isinstance(rl["limit"], str)
finally: finally:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_per_ip_breakdown(self, client):
"""Stats should include per-IP breakdown with total and rejected counts."""
api._track_rate_limit_request("/auth/login", "10.0.0.1")
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/login", "10.0.0.2")
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
by_ip = login_stats["by_ip"]
assert len(by_ip) == 2
ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1")
assert ip1["total"] == 2
assert ip1["rejected"] == 1
ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2")
assert ip2["total"] == 1
assert ip2["rejected"] == 0
finally:
app.dependency_overrides.clear()
def test_throttled_24h_count(self, client):
"""Should report total throttled requests in the last 24 hours."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert data["throttled_24h"] == 2
finally:
app.dependency_overrides.clear()
def test_throttled_over_time_structure(self, client):
"""Throttled-over-time should be a list of {timestamp, count} buckets."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert len(data["throttled_over_time"]) >= 1
entry = data["throttled_over_time"][0]
assert "timestamp" in entry
assert "count" in entry
assert entry["count"] >= 1
finally:
app.dependency_overrides.clear()
def test_response_shape_matches_contract(self, client):
"""The full response should match the expected shape for the frontend."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
# Top-level keys
assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"}
# Each rate_limit entry
for rl in data["rate_limits"]:
assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"}
finally:
app.dependency_overrides.clear()
+71 -10
View File
@@ -1,17 +1,18 @@
"""Tests for tracked company admin endpoints and scheduler integration. """Tests for tracked company endpoints and scheduler integration.
Covers issue #1656: Covers:
- GET /admin/tracked (list tracked companies) - GET /tracked (user-scoped list)
- POST /admin/tracked (add a tracked company) - POST /tracked (user-scoped add)
- DELETE /admin/tracked/{company_name} (remove a tracked company) - DELETE /tracked/{company_name} (user-scoped remove)
- GET /admin/tracked (admin: all companies)
- POST /admin/tracked (admin: add)
- DELETE /admin/tracked/{company_name} (admin: remove any)
- GET /admin/alerts (list alerts) - GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration - scheduler.run_scheduled_analysis() integration
All tests mock the database layer and use JWT auth fixtures.
""" """
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call from unittest.mock import MagicMock, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -125,7 +126,7 @@ class TestAddTrackedCompany:
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["company_name"] == "Intel" assert data["company_name"] == "Intel"
mock_db.add_tracked_company.assert_called_once_with("Intel") mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1)
def test_add_duplicate_returns_409(self, client, mock_db): def test_add_duplicate_returns_409(self, client, mock_db):
"""Adding an already-tracked company returns 409.""" """Adding an already-tracked company returns 409."""
@@ -141,7 +142,7 @@ class TestAddTrackedCompany:
assert "already tracked" in response.json()["detail"].lower() assert "already tracked" in response.json()["detail"].lower()
def test_add_tracked_requires_admin(self, client, mock_db): def test_add_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot add tracked companies.""" """Regular user cannot add tracked companies via admin endpoint."""
mock_db.get_user_by_id.return_value = { mock_db.get_user_by_id.return_value = {
"id": 2, "id": 2,
"email": "user@test.com", "email": "user@test.com",
@@ -215,6 +216,66 @@ class TestRemoveTrackedCompany:
assert response.status_code == 403 assert response.status_code == 403
# ---------- User-scoped tracked companies ----------
class TestUserScopedTrackedCompanies:
"""Tests for /tracked user-scoped endpoints."""
def test_user_list_tracked(self, client, mock_db):
"""Regular user can list their own tracked companies."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.list_tracked_companies.return_value = [
{"company_name": "AMD", "owner_id": 2},
]
response = client.get("/tracked", headers=_user_header())
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=2)
def test_user_add_tracked(self, client, mock_db):
"""Regular user can add a company to their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.add_tracked_company.return_value = {
"company_name": "Intel",
"owner_id": 2,
}
response = client.post(
"/tracked",
json={"company_name": "Intel"},
headers=_user_header(),
)
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2)
def test_user_remove_tracked(self, client, mock_db):
"""Regular user can remove a company from their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.remove_tracked_company.return_value = True
response = client.delete("/tracked/Intel", headers=_user_header())
assert response.status_code == 200
mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2)
# ---------- GET /admin/alerts ---------- # ---------- GET /admin/alerts ----------
class TestListAlerts: class TestListAlerts: