Compare commits

..

9 Commits

Author SHA1 Message Date
agent-company e37859dabc Add multi-tenant support with owner_id isolation
- Add owner_id (FK to users) column to llm_messages, jobs, and
  tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
  so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
  return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
  backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
  per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
  registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
  assertions
- Add comprehensive cross-tenant isolation test suite
  (tests/test_multi_tenant.py)

Closes leeworks-agents/SPARC#1677

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 16:04:58 +00:00
agent-company 3dfa651f2d Add rate limiting dashboard to admin panel
- Enhance GET /admin/rate-limits with per-IP breakdown, 24h throttled
  count, and hourly time-series of rejected requests
- Add _rejected_log deque for time-series tracking of throttled requests
- Add AdminRateLimits React page with auto-refresh (configurable 15s/30s/1m),
  summary cards, throttled-over-time bar chart, endpoint table, per-IP table
- Add TypeScript types (RateLimitStatsResponse) and adminApi.getRateLimits()
- Wire up /admin/rate-limits route and nav link (admin-only)
- Expand unit tests to 10 cases: auth, empty state, per-IP breakdown,
  throttled_24h count, time-series structure, response shape contract

Closes leeworks-agents/SPARC#1686

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:39:45 +00:00
AI-Manager 313800215c Merge pull request 'Add rate limit stats to admin panel' (#1682) from feature/1675-rate-limit-admin into main
Merge PR #1682
2026-05-19 00:12:56 +00:00
AI-Manager 222f29deb1 Merge pull request 'Add cursor-based pagination to /analyze/batch and /jobs' (#1681) from feature/1669-cursor-pagination into main
Merge PR #1681
2026-05-19 00:12:48 +00:00
AI-Manager e6d95bbf57 Merge pull request 'Add stricter input validation for company names' (#1680) from feature/1670-company-name-validation into main
Merge PR #1680
2026-05-19 00:12:42 +00:00
AI-Manager 68484ef4b1 Merge pull request 'Update ROADMAP.md: mark completed P1 and P2 items as done' (#1679) from feature/1678-update-roadmap into main
Merge PR #1679
2026-05-19 00:12:34 +00:00
agent-company a0cb9a5773 Add rate limit status and usage statistics to admin panel
Add GET /admin/rate-limits endpoint (admin-only) that returns current
rate limit configuration and request statistics for all rate-limited
endpoints (/auth/register and /auth/login). Tracks total requests and
rejection counts via in-memory counters.

Includes tests for admin access, non-admin rejection, empty state,
request tracking, and configuration display.

Closes leeworks-agents/SPARC#1675

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:53:01 +00:00
agent-company 857b3444df Add cursor-based pagination to GET /analyze/batch and update /jobs defaults
Add a new GET /analyze/batch endpoint that returns stored analysis results
with cursor-based pagination (default limit 50, max 200). Also update the
existing /jobs endpoint defaults from limit=10/max=100 to limit=50/max=200
for consistency.

The database layer gains a list_analyses() method with cursor support using
(timestamp, id) ordering, matching the existing list_jobs() pattern.

Includes tests for pagination behavior, boundary limits, cursor forwarding,
company name filtering, and empty result sets.

Closes leeworks-agents/SPARC#1669

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:49:22 +00:00
agent-company 7c6eed8d72 Update ROADMAP.md to mark completed P1 and P2 items as done
Move seven completed items from the P1 and P2 sections into the
Completed section: in-memory jobs persistence, export endpoint tests,
tracked company admin tests, webhook integration tests, S3 storage
tests, auto-download path tests, and scheduler DatabaseClient refactor.

The P2 section now only lists the two genuinely open items: cursor-based
pagination (Issue #1669) and request validation (Issue #1670).

Closes leeworks-agents/SPARC#1678

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:29:14 +00:00
14 changed files with 1792 additions and 158 deletions
+30 -37
View File
@@ -81,57 +81,50 @@ Items that have been implemented and merged into main.
- ~~OpenAPI client generation.~~ TypeScript API client auto-generated from - ~~OpenAPI client generation.~~ TypeScript API client auto-generated from
FastAPI spec with CI freshness check. FastAPI spec with CI freshness check.
### Resilience
- ~~`_jobs` dict is in-memory only.~~ Database-backed job persistence
implemented using `db.list_jobs()` and `mark_stale_jobs_failed()`. The
in-memory `_jobs` dict has been removed.
### Test coverage (P1/P2)
- ~~Export endpoint tests.~~ Tests added for CSV and PDF export endpoints.
- ~~Tracked company admin endpoint tests.~~ Tests added for `/admin/tracked`
CRUD endpoints and scheduler integration.
- ~~Webhook integration tests.~~ Tests added for retry logic, Slack/Discord
payload format, and multi-URL dispatch.
- ~~S3/MinIO storage backend tests.~~ Unit tests added for the S3 backend
(read, write, exists, delete, error handling).
- ~~`analyze_single_patent` auto-download path tests.~~ Tests added for the
auto-download fallback (cache lookup, PDF download, FileNotFoundError).
### Code quality
- ~~Scheduler creates its own DatabaseClient.~~ Refactored to use the
application-level pooled `get_db_client()`.
--- ---
## P1 -- High Priority ## P1 -- High Priority
These items address correctness, reliability, and coverage gaps that should be No outstanding P1 items. All previously listed items have been completed and
resolved before broader production use. moved to the Completed section above.
### Resilience
- **`_jobs` dict is in-memory only.** Job state is lost on API restart.
Persist job status in PostgreSQL or Redis so async batch results survive
restarts.
### Test coverage gaps
- **Export endpoint tests.** The CSV and PDF export endpoints (`/export/`)
lack test coverage. Add tests covering auth, success, 404, and edge cases.
*(Issue #1655)*
- **Tracked company admin endpoint tests.** The `/admin/tracked` CRUD
endpoints and scheduler integration lack test coverage. *(Issue #1656)*
--- ---
## P2 -- Medium Priority ## P2 -- Medium Priority
Improvements to reliability, test coverage, and code quality. Improvements to the API surface.
### Test coverage
- **Webhook integration tests.** The retry logic, Slack/Discord payload
format, and multi-URL dispatch in `webhooks.py` need test coverage.
*(Issue #1657)*
- **S3/MinIO storage backend tests.** `storage.py` has local filesystem tests
but no unit tests for the S3 backend (read, write, exists, delete,
error handling). *(Issue #1660)*
- **`analyze_single_patent` auto-download path tests.** The auto-download
fallback (cache lookup, PDF download, FileNotFoundError) in
`analyzer.py` lacks test coverage. *(Issue #1661)*
### Code quality
- **Scheduler creates its own DatabaseClient.** `scheduler.py` bypasses the
application-level pooled client, creating a new connection on every tick.
Refactor to use `get_db_client()`. *(Issue #1658)*
### API improvements ### API improvements
- **API pagination.** The `/analyze/batch` and `/jobs` endpoints could benefit - **API pagination.** The `/analyze/batch` endpoint needs cursor-based
from cursor-based pagination for large result sets. pagination for large result sets. The `/jobs` endpoint already has cursor
pagination. *(Issue #1669)*
- **Request validation improvements.** Add stricter input validation for - **Request validation improvements.** Add stricter input validation for
company names (disallow special characters, enforce length limits). company names (disallow special characters, enforce length limits).
*(Issue #1670)*
--- ---
+345 -63
View File
@@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios.
from __future__ import annotations from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -106,6 +107,24 @@ class JobStatus(BaseModel):
error: str | None = None error: str | None = None
class AnalysisRecord(BaseModel):
"""A single stored analysis result."""
id: int
company_name: str | None = None
analysis_type: str | None = None
model: str | None = None
response: str | None = None
timestamp: datetime | None = None
class PaginatedAnalysisResponse(BaseModel):
"""Paginated response for analysis result listings."""
items: list[AnalysisRecord]
next_cursor: str | None = None
class PaginatedJobsResponse(BaseModel): class PaginatedJobsResponse(BaseModel):
"""Paginated response for job listings.""" """Paginated response for job listings."""
@@ -227,10 +246,45 @@ app = FastAPI(
limiter = Limiter(key_func=get_remote_address) limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter app.state.limiter = limiter
# In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {}
# Time-series log of rejected requests (capped to last 24 h worth of entries).
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint."""
key = endpoint
if key not in _rate_limit_stats:
_rate_limit_stats[key] = {
"endpoint": endpoint,
"total_requests": 0,
"rejected_requests": 0,
"by_ip": {},
}
_rate_limit_stats[key]["total_requests"] += 1
if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1
_rejected_log.append({
"endpoint": endpoint,
"ip": ip,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0}
ip_stats[ip]["total"] += 1
if rejected:
ip_stats[ip]["rejected"] += 1
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded): async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded.""" """Return 429 with Retry-After header when rate limit is exceeded."""
endpoint = request.url.path
ip = get_remote_address(request)
_track_rate_limit_request(endpoint, ip, rejected=True)
retry_after = getattr(exc, "retry_after", 60) retry_after = getattr(exc, "retry_after", 60)
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
@@ -259,6 +313,7 @@ async def register(request: Request, body: RegisterRequest):
The first registered user automatically becomes an admin. The first registered user automatically becomes an admin.
""" """
_track_rate_limit_request("/auth/register", get_remote_address(request))
db = get_db_client() db = get_db_client()
# First user becomes admin # First user becomes admin
@@ -289,6 +344,7 @@ async def register(request: Request, body: RegisterRequest):
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest): async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens.""" """Authenticate user and return JWT tokens."""
_track_rate_limit_request("/auth/login", get_remote_address(request))
db = get_db_client() db = get_db_client()
user = db.authenticate_user(body.email, body.password) user = db.authenticate_user(body.email, body.password)
@@ -418,11 +474,46 @@ class TrackCompanyRequest(BaseModel):
company_name: CompanyName = Field(...) company_name: CompanyName = Field(...)
@app.get("/tracked", tags=["Tracked Companies"])
async def list_my_tracked_companies(
current_user: UserResponse = Depends(get_current_user),
):
"""List tracked companies for the current user."""
db = get_db_client()
return db.list_tracked_companies(owner_id=current_user.id)
@app.post("/tracked", tags=["Tracked Companies"])
async def add_my_tracked_company(
request: TrackCompanyRequest,
current_user: UserResponse = Depends(get_current_user),
):
"""Add a company to the current user's tracked list."""
db = get_db_client()
result = db.add_tracked_company(request.company_name, owner_id=current_user.id)
if not result:
raise HTTPException(status_code=409, detail="Company already tracked")
return result
@app.delete("/tracked/{company_name}", tags=["Tracked Companies"])
async def remove_my_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
current_user: UserResponse = Depends(get_current_user),
):
"""Remove a company from the current user's tracked list."""
db = get_db_client()
removed = db.remove_tracked_company(company_name, owner_id=current_user.id)
if not removed:
raise HTTPException(status_code=404, detail="Company not found in tracking list")
return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/tracked", tags=["Admin"]) @app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies( async def list_tracked_companies(
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""List all tracked companies (admin only).""" """List all tracked companies across all users (admin only)."""
db = get_db_client() db = get_db_client()
return db.list_tracked_companies() return db.list_tracked_companies()
@@ -430,11 +521,11 @@ async def list_tracked_companies(
@app.post("/admin/tracked", tags=["Admin"]) @app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company( async def add_tracked_company(
request: TrackCompanyRequest, request: TrackCompanyRequest,
_: UserResponse = Depends(get_current_admin), current_admin: UserResponse = Depends(get_current_admin),
): ):
"""Add a company to the tracked list (admin only).""" """Add a company to the tracked list (admin only, owned by admin)."""
db = get_db_client() db = get_db_client()
result = db.add_tracked_company(request.company_name) result = db.add_tracked_company(request.company_name, owner_id=current_admin.id)
if not result: if not result:
raise HTTPException(status_code=409, detail="Company already tracked") raise HTTPException(status_code=409, detail="Company already tracked")
return result return result
@@ -445,7 +536,7 @@ async def remove_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""Remove a company from the tracked list (admin only).""" """Remove a company from the tracked list (admin only, any owner)."""
db = get_db_client() db = get_db_client()
removed = db.remove_tracked_company(company_name) removed = db.remove_tracked_company(company_name)
if not removed: if not removed:
@@ -453,6 +544,69 @@ async def remove_tracked_company(
return {"message": f"Stopped tracking {company_name}"} return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/rate-limits", tags=["Admin"])
async def get_rate_limit_stats(
_: UserResponse = Depends(get_current_admin),
):
"""Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics
for all rate-limited endpoints, including per-IP breakdown and
a time-series of throttled (rejected) requests in the last 24 hours.
Returns:
Rate limit stats per endpoint, per-IP breakdown, and throttled
request history bucketed by hour.
"""
rate_limits_config = {
"/auth/register": {"limit": "5/minute"},
"/auth/login": {"limit": "10/minute"},
}
results = []
for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {})
by_ip_raw = stats.get("by_ip", {})
by_ip = [
{"ip": ip, "total": counts["total"], "rejected": counts["rejected"]}
for ip, counts in by_ip_raw.items()
]
results.append({
"endpoint": endpoint,
"limit": conf["limit"],
"total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0),
"by_ip": by_ip,
})
# Build hourly buckets of throttled requests for the last 24 hours
now = datetime.now(timezone.utc)
cutoff = now - timedelta(hours=24)
hourly_buckets: dict[str, int] = {}
throttled_24h = 0
for entry in _rejected_log:
ts_str = entry["timestamp"]
try:
ts = datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
continue
if ts >= cutoff:
throttled_24h += 1
bucket = ts.strftime("%Y-%m-%dT%H:00:00Z")
hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1
throttled_over_time = [
{"timestamp": k, "count": v}
for k, v in sorted(hourly_buckets.items())
]
return {
"rate_limits": results,
"throttled_24h": throttled_24h,
"throttled_over_time": throttled_over_time,
}
@app.get("/admin/alerts", tags=["Admin"]) @app.get("/admin/alerts", tags=["Admin"])
async def list_alerts( async def list_alerts(
limit: int = Query(default=50, ge=1, le=200), limit: int = Query(default=50, ge=1, le=200),
@@ -463,17 +617,86 @@ async def list_alerts(
return db.list_alerts(limit=limit) return db.list_alerts(limit=limit)
# ============== Admin-Scoped Data Endpoints ==============
@app.get("/admin/analyses", response_model=PaginatedAnalysisResponse, tags=["Admin"])
async def admin_list_analyses(
company_name: Annotated[
str | None,
Query(description="Filter results by company name"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_admin),
):
"""List all analysis results across all users (admin only)."""
db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
has_next = len(rows) > limit
if has_next:
rows = rows[:limit]
items = [AnalysisRecord(**row) for row in rows]
next_cursor = None
if has_next and rows:
last = rows[-1]
ts = last["timestamp"]
ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
next_cursor = f"{ts_str}|{last['id']}"
return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor)
@app.get("/admin/jobs", response_model=PaginatedJobsResponse, tags=["Admin"])
async def admin_list_jobs(
status: Annotated[
str | None,
Query(description="Filter by status: pending, running, completed, failed"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_admin),
):
"""List all jobs across all users (admin only)."""
db = _get_job_db()
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
has_next = len(job_rows) > limit
if has_next:
job_rows = job_rows[:limit]
items = [_job_row_to_status(row) for row in job_rows]
next_cursor = None
if has_next and job_rows:
last = job_rows[-1]
created = last["created_at"]
ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
next_cursor = f"{ts}|{last['job_id']}"
return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
# ============== Analytics Endpoint ============== # ============== Analytics Endpoint ==============
@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"]) @app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics( async def get_analytics(
days: int = Query(default=30, ge=1, le=365), days: int = Query(default=30, ge=1, le=365),
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Get analytics data (authenticated users only).""" """Get analytics data scoped to the current user."""
db = get_db_client() db = get_db_client()
analytics = db.get_analytics(days=days) analytics = db.get_analytics(days=days, owner_id=current_user.id)
return AnalyticsResponse( return AnalyticsResponse(
total_messages=analytics["total_messages"], total_messages=analytics["total_messages"],
@@ -526,9 +749,9 @@ async def list_models():
@app.get("/analytics/trends", tags=["Analytics"]) @app.get("/analytics/trends", tags=["Analytics"])
async def get_analytics_trends( async def get_analytics_trends(
days: int = Query(default=90, ge=7, le=365), days: int = Query(default=90, ge=7, le=365),
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Get trend data for patent analysis over time. """Get trend data for patent analysis over time (scoped to current user).
Returns two datasets: Returns two datasets:
- ``by_month``: analysis count per company per month - ``by_month``: analysis count per company per month
@@ -542,11 +765,14 @@ async def get_analytics_trends(
""" """
db = get_db_client() db = get_db_client()
owner_filter = " AND owner_id = %s" if current_user else ""
owner_params = (current_user.id,) if current_user else ()
with db.get_conn() as conn: with db.get_conn() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
# Analyses per company per month # Analyses per company per month
cur.execute( cur.execute(
""" f"""
SELECT SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month, TO_CHAR(timestamp, 'YYYY-MM') AS month,
company_name, company_name,
@@ -555,16 +781,17 @@ async def get_analytics_trends(
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE AND is_cached = FALSE
AND company_name IS NOT NULL AND company_name IS NOT NULL
{owner_filter}
GROUP BY month, company_name GROUP BY month, company_name
ORDER BY month ORDER BY month
""", """,
(days,), (days, *owner_params),
) )
by_month_rows = cur.fetchall() by_month_rows = cur.fetchall()
# Analysis type distribution per month # Analysis type distribution per month
cur.execute( cur.execute(
""" f"""
SELECT SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month, TO_CHAR(timestamp, 'YYYY-MM') AS month,
analysis_type, analysis_type,
@@ -572,10 +799,11 @@ async def get_analytics_trends(
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE AND is_cached = FALSE
{owner_filter}
GROUP BY month, analysis_type GROUP BY month, analysis_type
ORDER BY month ORDER BY month
""", """,
(days,), (days, *owner_params),
) )
by_type_rows = cur.fetchall() by_type_rows = cur.fetchall()
@@ -601,9 +829,9 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"]) @app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv( async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a CSV file. """Export analysis results for a company as a CSV file (scoped to current user).
Returns all stored analysis records for the given company, including Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp. analysis type, model used, response text, and timestamp.
@@ -618,7 +846,7 @@ async def export_company_csv(
import io import io
db = get_db_client() db = get_db_client()
# Query all non-cached analysis results for this company # Query all non-cached analysis results for this company owned by current user
with db.get_conn() as conn: with db.get_conn() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute( cur.execute(
@@ -626,9 +854,10 @@ async def export_company_csv(
SELECT company_name, analysis_type, model, response, timestamp SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
AND owner_id = %s
ORDER BY timestamp DESC ORDER BY timestamp DESC
""", """,
(company_name,), (company_name, current_user.id),
) )
rows = cur.fetchall() rows = cur.fetchall()
@@ -653,9 +882,9 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"]) @app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf( async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a formatted PDF report. """Export analysis results for a company as a formatted PDF report (scoped to current user).
Returns all stored analysis records for the given company, including Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted analysis type, model used, response text, and timestamp, formatted
@@ -689,9 +918,10 @@ async def export_company_pdf(
SELECT company_name, analysis_type, model, response, timestamp SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
AND owner_id = %s
ORDER BY timestamp DESC ORDER BY timestamp DESC
""", """,
(company_name,), (company_name, current_user.id),
) )
rows = cur.fetchall() rows = cur.fetchall()
@@ -821,33 +1051,87 @@ async def health_check():
@app.get( @app.get(
"/analyze/{company_name}", "/analyze/batch",
response_model=CompanyAnalysisResponse, response_model=PaginatedAnalysisResponse,
tags=["Analysis"], tags=["Analysis"],
) )
async def analyze_company( async def list_analysis_results(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), str | None,
_: UserResponse = Depends(get_current_user), Query(description="Filter results by company name"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
current_user: UserResponse = Depends(get_current_user),
): ):
"""Analyze a single company's patent portfolio. """List stored analysis results with cursor-based pagination (scoped to current user).
This endpoint retrieves recent patents for the specified company, Returns past analysis results ordered by timestamp descending. Use
parses them, and uses AI to generate a comprehensive analysis. ``limit`` to control page size (default 50, max 200). The response
includes a ``next_cursor`` field; pass it back as the ``cursor`` query
parameter to fetch the next page. When ``next_cursor`` is ``null``,
there are no more results.
Args: Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel") company_name: Optional filter by company name
model: Optional LLM model override limit: Maximum number of results to return (default 50, max 200)
cursor: Opaque pagination cursor from a previous response
Returns: Returns:
Analysis results including patent count, AI insights, and success status Paginated list of analysis results
""" """
_validate_model(model) db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, owner_id=current_user.id)
has_next = len(rows) > limit
if has_next:
rows = rows[:limit]
items = [AnalysisRecord(**row) for row in rows]
next_cursor = None
if has_next and rows:
last = rows[-1]
ts = last["timestamp"]
ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
next_cursor = f"{ts_str}|{last['id']}"
return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor)
@app.post(
"/analyze/batch",
response_model=BatchAnalysisResponse,
tags=["Analysis"],
)
async def analyze_companies_batch(
request: BatchAnalysisRequest,
_: UserResponse = Depends(get_current_user),
):
"""Analyze multiple companies' patent portfolios.
Processes companies concurrently for improved performance.
Limited to 20 companies per request.
Args:
request: List of company names and optional worker count
Returns:
Batch results with individual company analyses and summary statistics
"""
_validate_model(request.model)
if not _analyzer: if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized") raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name, model=model) result = _analyzer.analyze_companies(
return _convert_result(result) companies=request.companies,
max_workers=request.max_workers,
model=request.model,
)
return _convert_batch_result(result)
@app.get( @app.get(
@@ -882,36 +1166,34 @@ async def analyze_single_patent(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@app.post( @app.get(
"/analyze/batch", "/analyze/{company_name}",
response_model=BatchAnalysisResponse, response_model=CompanyAnalysisResponse,
tags=["Analysis"], tags=["Analysis"],
) )
async def analyze_companies_batch( async def analyze_company(
request: BatchAnalysisRequest, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Analyze multiple companies' patent portfolios. """Analyze a single company's patent portfolio.
Processes companies concurrently for improved performance. This endpoint retrieves recent patents for the specified company,
Limited to 20 companies per request. parses them, and uses AI to generate a comprehensive analysis.
Args: Args:
request: List of company names and optional worker count company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns: Returns:
Batch results with individual company analyses and summary statistics Analysis results including patent count, AI insights, and success status
""" """
_validate_model(request.model) _validate_model(model)
if not _analyzer: if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized") raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer.analyze_companies( result = _analyzer._analyze_company_safe(company_name, model=model)
companies=request.companies, return _convert_result(result)
max_workers=request.max_workers,
model=request.model,
)
return _convert_batch_result(result)
def _get_job_db() -> "DatabaseClient": def _get_job_db() -> "DatabaseClient":
@@ -1000,7 +1282,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
async def analyze_companies_async( async def analyze_companies_async(
request: BatchAnalysisRequest, request: BatchAnalysisRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Start an asynchronous batch analysis job. """Start an asynchronous batch analysis job.
@@ -1020,7 +1302,7 @@ async def analyze_companies_async(
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}" job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
db = _get_job_db() db = _get_job_db()
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) job_row = db.create_job(job_id=job_id, total_companies=len(request.companies), owner_id=current_user.id)
background_tasks.add_task( background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers, request.model _run_batch_job, job_id, request.companies, request.max_workers, request.model
@@ -1032,9 +1314,9 @@ async def analyze_companies_async(
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status( async def get_job_status(
job_id: str, job_id: str,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""Get the status of a background analysis job. """Get the status of a background analysis job (scoped to current user).
Args: Args:
job_id: The job ID returned from the async batch endpoint job_id: The job ID returned from the async batch endpoint
@@ -1043,7 +1325,7 @@ async def get_job_status(
Current job status including progress and results when complete Current job status including progress and results when complete
""" """
db = _get_job_db() db = _get_job_db()
job_row = db.get_job(job_id) job_row = db.get_job(job_id, owner_id=current_user.id)
if not job_row: if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found") raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@@ -1057,14 +1339,14 @@ async def list_jobs(
str | None, str | None,
Query(description="Filter by status: pending, running, completed, failed"), Query(description="Filter by status: pending, running, completed, failed"),
] = None, ] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 10, limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[ cursor: Annotated[
str | None, str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"), Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None, ] = None,
_: UserResponse = Depends(get_current_user), current_user: UserResponse = Depends(get_current_user),
): ):
"""List analysis jobs with cursor-based pagination. """List analysis jobs with cursor-based pagination (scoped to current user).
Pass ``limit`` to control page size. The response includes a ``next_cursor`` Pass ``limit`` to control page size. The response includes a ``next_cursor``
field; pass it back as the ``cursor`` query parameter to fetch the next page. field; pass it back as the ``cursor`` query parameter to fetch the next page.
@@ -1083,7 +1365,7 @@ async def list_jobs(
""" """
db = _get_job_db() db = _get_job_db()
# Fetch one extra to determine if there is a next page # Fetch one extra to determine if there is a next page
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor) job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor, owner_id=current_user.id)
has_next = len(job_rows) > limit has_next = len(job_rows) > limit
if has_next: if has_next:
+207 -29
View File
@@ -196,7 +196,7 @@ class DatabaseClient:
cursor.execute(""" cursor.execute("""
CREATE TABLE IF NOT EXISTS tracked_companies ( CREATE TABLE IF NOT EXISTS tracked_companies (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
company_name VARCHAR(255) UNIQUE NOT NULL, company_name VARCHAR(255) NOT NULL,
last_patent_count INTEGER DEFAULT 0, last_patent_count INTEGER DEFAULT 0,
last_analysis_at TIMESTAMP, last_analysis_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -221,6 +221,68 @@ class DatabaseClient:
ON alerts(company_name) ON alerts(company_name)
""") """)
# ---- Multi-tenant: add owner_id columns if missing ----
cursor.execute("""
DO $$
BEGIN
-- llm_messages.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'llm_messages' AND column_name = 'owner_id'
) THEN
ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
-- jobs.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'jobs' AND column_name = 'owner_id'
) THEN
ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
-- tracked_companies.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'tracked_companies' AND column_name = 'owner_id'
) THEN
ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
END $$;
""")
# Indexes for owner_id filtering
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_messages_owner
ON llm_messages(owner_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_owner
ON jobs(owner_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner
ON tracked_companies(owner_id)
""")
# Drop the old unique constraint on company_name alone (if it exists)
# and replace with a per-owner unique constraint so different users
# can track the same company independently.
cursor.execute("""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_constraint
WHERE conname = 'tracked_companies_company_name_key'
) THEN
ALTER TABLE tracked_companies
DROP CONSTRAINT tracked_companies_company_name_key;
END IF;
END $$;
""")
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner
ON tracked_companies(LOWER(company_name), owner_id)
""")
self.conn.commit() self.conn.commit()
@staticmethod @staticmethod
@@ -289,6 +351,7 @@ class DatabaseClient:
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
token_usage: Optional[Dict] = None, token_usage: Optional[Dict] = None,
is_cached: bool = False, is_cached: bool = False,
owner_id: Optional[int] = None,
) -> int: ) -> int:
"""Store an LLM message exchange in the database. """Store an LLM message exchange in the database.
@@ -301,6 +364,7 @@ class DatabaseClient:
metadata: Additional metadata as dict metadata: Additional metadata as dict
token_usage: Token usage information token_usage: Token usage information
is_cached: Whether this response was served from cache is_cached: Whether this response was served from cache
owner_id: ID of the user who owns this record
Returns: Returns:
The ID of the inserted record The ID of the inserted record
@@ -312,8 +376,8 @@ class DatabaseClient:
cursor.execute( cursor.execute(
""" """
INSERT INTO llm_messages INSERT INTO llm_messages
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached, owner_id)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id RETURNING id
""", """,
( (
@@ -326,6 +390,7 @@ class DatabaseClient:
json.dumps(metadata) if metadata else None, json.dumps(metadata) if metadata else None,
json.dumps(token_usage) if token_usage else None, json.dumps(token_usage) if token_usage else None,
is_cached, is_cached,
owner_id,
), ),
) )
@@ -340,6 +405,7 @@ class DatabaseClient:
analysis_type: Optional[str] = None, analysis_type: Optional[str] = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
owner_id: Optional[int] = None,
) -> List[Dict]: ) -> List[Dict]:
"""Retrieve messages from the database. """Retrieve messages from the database.
@@ -348,6 +414,7 @@ class DatabaseClient:
analysis_type: Filter by analysis type analysis_type: Filter by analysis type
limit: Maximum number of records to return limit: Maximum number of records to return
offset: Number of records to skip offset: Number of records to skip
owner_id: Filter by owner (None returns all, for admin use)
Returns: Returns:
List of message dictionaries List of message dictionaries
@@ -355,6 +422,10 @@ class DatabaseClient:
query = "SELECT * FROM llm_messages WHERE 1=1" query = "SELECT * FROM llm_messages WHERE 1=1"
params = [] params = []
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
if company_name: if company_name:
query += " AND company_name = %s" query += " AND company_name = %s"
params.append(company_name) params.append(company_name)
@@ -371,52 +442,110 @@ class DatabaseClient:
cursor.execute(query, params) cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
def get_analytics(self, days: int = 30) -> Dict: def list_analyses(
self,
company_name: Optional[str] = None,
limit: int = 50,
cursor: Optional[str] = None,
owner_id: Optional[int] = None,
) -> List[Dict]:
"""List analysis results with cursor-based pagination.
Args:
company_name: Optional filter by company name.
limit: Maximum number of records to return.
cursor: Opaque cursor (``timestamp|id``) from a previous response.
owner_id: Filter by owner (None returns all, for admin use).
Returns:
List of analysis dicts ordered by timestamp descending.
"""
conditions: list[str] = ["is_cached = FALSE"]
params: list = []
if owner_id is not None:
conditions.append("owner_id = %s")
params.append(owner_id)
if company_name:
conditions.append("LOWER(company_name) = LOWER(%s)")
params.append(company_name)
if cursor:
try:
ts_str, cursor_id = cursor.rsplit("|", 1)
conditions.append("(timestamp, id) < (%s, %s)")
params.extend([ts_str, int(cursor_id)])
except (ValueError, TypeError):
pass # Ignore malformed cursors; return from start
query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages"
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY timestamp DESC, id DESC LIMIT %s"
params.append(limit)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, params)
return [dict(row) for row in cur.fetchall()]
def get_analytics(self, days: int = 30, owner_id: Optional[int] = None) -> Dict:
"""Get analytics on message usage. """Get analytics on message usage.
Args: Args:
days: Number of days to look back days: Number of days to look back
owner_id: Filter by owner (None returns all, for admin use)
Returns: Returns:
Dictionary with analytics data Dictionary with analytics data
""" """
owner_filter = ""
owner_params: list = []
if owner_id is not None:
owner_filter = " AND owner_id = %s"
owner_params = [owner_id]
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages # Total messages
cursor.execute( cursor.execute(
""" f"""
SELECT COUNT(*) as total_messages SELECT COUNT(*) as total_messages
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
""", """,
(days,), (days, *owner_params),
) )
total = cursor.fetchone()["total_messages"] total = cursor.fetchone()["total_messages"]
# Messages by company # Messages by company
cursor.execute( cursor.execute(
""" f"""
SELECT company_name, COUNT(*) as count SELECT company_name, COUNT(*) as count
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
GROUP BY company_name GROUP BY company_name
ORDER BY count DESC ORDER BY count DESC
LIMIT 10 LIMIT 10
""", """,
(days,), (days, *owner_params),
) )
by_company = cursor.fetchall() by_company = cursor.fetchall()
# Messages by type # Messages by type
cursor.execute( cursor.execute(
""" f"""
SELECT analysis_type, COUNT(*) as count SELECT analysis_type, COUNT(*) as count
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
GROUP BY analysis_type GROUP BY analysis_type
ORDER BY count DESC ORDER BY count DESC
""", """,
(days,), (days, *owner_params),
) )
by_type = cursor.fetchall() by_type = cursor.fetchall()
@@ -514,12 +643,14 @@ class DatabaseClient:
self, self,
job_id: str, job_id: str,
total_companies: int, total_companies: int,
owner_id: Optional[int] = None,
) -> Dict: ) -> Dict:
"""Create a new job record. """Create a new job record.
Args: Args:
job_id: Unique job identifier job_id: Unique job identifier
total_companies: Number of companies in the batch total_companies: Number of companies in the batch
owner_id: ID of the user who owns this job
Returns: Returns:
Job dict Job dict
@@ -528,11 +659,11 @@ class DatabaseClient:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies) INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies, owner_id)
VALUES (%s, 'pending', 0, %s, 0) VALUES (%s, 'pending', 0, %s, 0, %s)
RETURNING * RETURNING *
""", """,
(job_id, total_companies), (job_id, total_companies, owner_id),
) )
job = cursor.fetchone() job = cursor.fetchone()
conn.commit() conn.commit()
@@ -585,11 +716,22 @@ class DatabaseClient:
conn.commit() conn.commit()
return dict(job) if job else None return dict(job) if job else None
def get_job(self, job_id: str) -> Optional[Dict]: def get_job(self, job_id: str, owner_id: Optional[int] = None) -> Optional[Dict]:
"""Get a job by ID.""" """Get a job by ID.
Args:
job_id: Job identifier.
owner_id: When provided, only return the job if it belongs to this owner.
"""
query = "SELECT * FROM jobs WHERE job_id = %s"
params: list = [job_id]
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,)) cursor.execute(query, params)
job = cursor.fetchone() job = cursor.fetchone()
return dict(job) if job else None return dict(job) if job else None
@@ -598,6 +740,7 @@ class DatabaseClient:
status: Optional[str] = None, status: Optional[str] = None,
limit: int = 10, limit: int = 10,
cursor: Optional[str] = None, cursor: Optional[str] = None,
owner_id: Optional[int] = None,
) -> List[Dict]: ) -> List[Dict]:
"""List jobs with optional status filter and cursor-based pagination. """List jobs with optional status filter and cursor-based pagination.
@@ -607,6 +750,7 @@ class DatabaseClient:
cursor: Opaque cursor (``created_at|job_id``) from a previous cursor: Opaque cursor (``created_at|job_id``) from a previous
response. When provided, only jobs older than the cursor are response. When provided, only jobs older than the cursor are
returned. returned.
owner_id: Filter by owner (None returns all, for admin use).
Returns: Returns:
List of job dicts ordered by created_at descending. List of job dicts ordered by created_at descending.
@@ -614,6 +758,10 @@ class DatabaseClient:
conditions: list[str] = [] conditions: list[str] = []
params: list = [] params: list = []
if owner_id is not None:
conditions.append("owner_id = %s")
params.append(owner_id)
if status: if status:
conditions.append("status = %s") conditions.append("status = %s")
params.append(status) params.append(status)
@@ -860,14 +1008,21 @@ class DatabaseClient:
# Tracked Companies Methods # Tracked Companies Methods
def add_tracked_company(self, company_name: str) -> Optional[Dict]: def add_tracked_company(
"""Add a company to the tracking list.""" self, company_name: str, owner_id: Optional[int] = None
) -> Optional[Dict]:
"""Add a company to the tracking list.
Args:
company_name: Company name to track.
owner_id: ID of the user who owns this tracked company.
"""
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
try: try:
cursor.execute( cursor.execute(
"INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *", "INSERT INTO tracked_companies (company_name, owner_id) VALUES (%s, %s) RETURNING *",
(company_name,), (company_name, owner_id),
) )
row = cursor.fetchone() row = cursor.fetchone()
conn.commit() conn.commit()
@@ -876,22 +1031,45 @@ class DatabaseClient:
conn.rollback() conn.rollback()
return None return None
def remove_tracked_company(self, company_name: str) -> bool: def remove_tracked_company(
"""Remove a company from the tracking list.""" self, company_name: str, owner_id: Optional[int] = None
) -> bool:
"""Remove a company from the tracking list.
Args:
company_name: Company name to remove.
owner_id: When provided, only remove if owned by this user.
"""
query = "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)"
params: list = [company_name]
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute( cursor.execute(query, params)
"DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
(company_name,),
)
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
def list_tracked_companies(self) -> List[Dict]: def list_tracked_companies(
"""List all tracked companies.""" self, owner_id: Optional[int] = None
) -> List[Dict]:
"""List tracked companies.
Args:
owner_id: Filter by owner (None returns all, for admin/scheduler use).
"""
query = "SELECT * FROM tracked_companies"
params: list = []
if owner_id is not None:
query += " WHERE owner_id = %s"
params.append(owner_id)
query += " ORDER BY company_name"
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name") cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
def update_tracked_company( def update_tracked_company(
+9
View File
@@ -11,6 +11,7 @@ import { Batch } from './pages/Batch';
import { AnalyticsPage } from './pages/Analytics'; import { AnalyticsPage } from './pages/Analytics';
import { About } from './pages/About'; import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers'; import { AdminUsers } from './pages/AdminUsers';
import { AdminRateLimits } from './pages/AdminRateLimits';
import { Compare } from './pages/Compare'; import { Compare } from './pages/Compare';
const queryClient = new QueryClient({ const queryClient = new QueryClient({
@@ -56,6 +57,14 @@ function App() {
</ProtectedRoute> </ProtectedRoute>
} }
/> />
<Route
path="/admin/rate-limits"
element={
<ProtectedRoute requireAdmin>
<AdminRateLimits />
</ProtectedRoute>
}
/>
</Route> </Route>
{/* Default redirect */} {/* Default redirect */}
+31
View File
@@ -201,6 +201,32 @@ export const analyticsApi = {
}, },
}; };
// Rate limit types
export interface RateLimitIpEntry {
ip: string;
total: number;
rejected: number;
}
export interface RateLimitEndpointStats {
endpoint: string;
limit: string;
total_requests: number;
rejected_requests: number;
by_ip: RateLimitIpEntry[];
}
export interface ThrottledBucket {
timestamp: string;
count: number;
}
export interface RateLimitStatsResponse {
rate_limits: RateLimitEndpointStats[];
throttled_24h: number;
throttled_over_time: ThrottledBucket[];
}
// Admin API // Admin API
export const adminApi = { export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => { listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
@@ -216,6 +242,11 @@ export const adminApi = {
deleteUser: async (userId: number): Promise<void> => { deleteUser: async (userId: number): Promise<void> => {
await api.delete(`/admin/users/${userId}`); await api.delete(`/admin/users/${userId}`);
}, },
getRateLimits: async (): Promise<RateLimitStatsResponse> => {
const response = await api.get<RateLimitStatsResponse>('/admin/rate-limits');
return response.data;
},
}; };
export default api; export default api;
+2 -1
View File
@@ -1,7 +1,7 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom'; import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext'; import { useAuth } from '../context/AuthContext';
import { useTheme } from '../context/ThemeContext'; import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react'; import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, ShieldAlert } from 'lucide-react';
export function Layout() { export function Layout() {
const { user, isAdmin, logout } = useAuth(); const { user, isAdmin, logout } = useAuth();
@@ -23,6 +23,7 @@ export function Layout() {
if (isAdmin) { if (isAdmin) {
navItems.push({ to: '/admin/users', icon: Users, label: 'Users' }); navItems.push({ to: '/admin/users', icon: Users, label: 'Users' });
navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' });
} }
return ( return (
+240
View File
@@ -0,0 +1,240 @@
import { useState } from 'react';
import { useQuery } from '@tanstack/react-query';
import { adminApi } from '../api/client';
import type { RateLimitStatsResponse } from '../api/client';
import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react';
const REFRESH_OPTIONS = [
{ label: '15s', value: 15_000 },
{ label: '30s', value: 30_000 },
{ label: '1m', value: 60_000 },
{ label: 'Off', value: 0 },
];
export function AdminRateLimits() {
const [refreshInterval, setRefreshInterval] = useState(30_000);
const { data, isLoading, isError, dataUpdatedAt } = useQuery<RateLimitStatsResponse>({
queryKey: ['admin-rate-limits'],
queryFn: () => adminApi.getRateLimits(),
refetchInterval: refreshInterval || false,
});
if (isLoading) {
return (
<div className="flex items-center justify-center min-h-[400px]">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div>
</div>
);
}
if (isError) {
return (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load rate limit statistics.</span>
</div>
);
}
const maxThrottledCount = data?.throttled_over_time?.length
? Math.max(...data.throttled_over_time.map((b) => b.count))
: 0;
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between flex-wrap gap-4">
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Rate Limiting Dashboard
</h2>
<p className="text-text-secondary">Monitor API rate limits and throttled requests.</p>
</div>
<div className="flex items-center gap-3">
{/* Last updated */}
{dataUpdatedAt > 0 && (
<span className="text-xs text-text-secondary flex items-center gap-1">
<Clock size={12} />
Updated {new Date(dataUpdatedAt).toLocaleTimeString()}
</span>
)}
{/* Refresh interval selector */}
<div className="flex items-center gap-1 bg-bg-card/60 border border-primary/15 rounded-xl p-1">
<RefreshCw size={14} className="text-text-secondary ml-2" />
{REFRESH_OPTIONS.map((opt) => (
<button
key={opt.value}
onClick={() => setRefreshInterval(opt.value)}
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
refreshInterval === opt.value
? 'bg-primary text-white'
: 'text-text-secondary hover:text-text-primary hover:bg-bg-card-hover'
}`}
>
{opt.label}
</button>
))}
</div>
</div>
</div>
{/* Summary cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<Activity size={18} className="text-primary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-error/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-error" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Throttled (24h)
</span>
</div>
<div className="text-3xl font-bold text-error">
{data?.throttled_24h ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-secondary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-secondary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rate-Limited Endpoints
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.length ?? 0}
</div>
</div>
</div>
{/* Throttled over time chart (simple bar chart) */}
{data?.throttled_over_time && data.throttled_over_time.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider mb-4">
Throttled Requests Over Time (Last 24h)
</h3>
<div className="flex items-end gap-1 h-32">
{data.throttled_over_time.map((bucket) => {
const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0;
const hour = new Date(bucket.timestamp).getHours();
return (
<div key={bucket.timestamp} className="flex-1 flex flex-col items-center gap-1">
<span className="text-xs text-text-secondary">{bucket.count}</span>
<div
className="w-full bg-error/70 rounded-t-sm min-h-[2px] transition-all"
style={{ height: `${Math.max(height, 2)}%` }}
title={`${bucket.timestamp}: ${bucket.count} throttled`}
/>
<span className="text-[10px] text-text-secondary">{hour}:00</span>
</div>
);
})}
</div>
</div>
)}
{/* Per-endpoint table */}
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Limit
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data?.rate_limits.map((rl) => (
<tr key={rl.endpoint} className="hover:bg-bg-card-hover/50 transition-colors">
<td className="px-6 py-4 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-4">
<span className="inline-flex px-2 py-0.5 rounded-full text-xs font-medium bg-primary/10 text-primary border border-primary/20">
{rl.limit}
</span>
</td>
<td className="px-6 py-4 text-right text-text-primary font-semibold">
{rl.total_requests}
</td>
<td className="px-6 py-4 text-right">
<span className={rl.rejected_requests > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{rl.rejected_requests}
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
{/* Per-IP breakdown */}
{data?.rate_limits.some((rl) => rl.by_ip.length > 0) && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="px-6 py-4 border-b border-primary/10">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Per-IP Breakdown
</h3>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
IP Address
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data.rate_limits.flatMap((rl) =>
rl.by_ip.map((ipEntry) => (
<tr
key={`${rl.endpoint}-${ipEntry.ip}`}
className="hover:bg-bg-card-hover/50 transition-colors"
>
<td className="px-6 py-3 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-3 font-mono text-sm text-text-secondary">{ipEntry.ip}</td>
<td className="px-6 py-3 text-right text-text-primary">{ipEntry.total}</td>
<td className="px-6 py-3 text-right">
<span className={ipEntry.rejected > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{ipEntry.rejected}
</span>
</td>
</tr>
))
)}
</tbody>
</table>
</div>
</div>
)}
</div>
);
}
+132
View File
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Migration: add owner_id columns and backfill existing rows.
This script adds an ``owner_id`` column (FK to ``users``) to the
``llm_messages``, ``jobs``, and ``tracked_companies`` tables, then
backfills all existing rows with ``owner_id = 1`` (the default admin user).
It also replaces the old global UNIQUE constraint on
``tracked_companies.company_name`` with a per-owner unique index so that
different users can independently track the same company.
Usage:
python scripts/migrate_add_owner_id.py
The script is idempotent — running it multiple times is safe.
"""
import os
import sys
import psycopg2
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://postgres:postgres@localhost:5432/sparc",
)
DEFAULT_OWNER_ID = 1
def run_migration():
"""Execute the migration."""
conn = psycopg2.connect(DATABASE_URL)
conn.autocommit = False
try:
with conn.cursor() as cur:
# ---------- 1. Add owner_id columns if missing ----------
cur.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'llm_messages' AND column_name = 'owner_id'
) THEN
ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'jobs' AND column_name = 'owner_id'
) THEN
ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'tracked_companies' AND column_name = 'owner_id'
) THEN
ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
END $$;
""")
# ---------- 2. Backfill owner_id = DEFAULT_OWNER_ID ----------
cur.execute(
"UPDATE llm_messages SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
messages_updated = cur.rowcount
print(f" llm_messages: backfilled {messages_updated} rows")
cur.execute(
"UPDATE jobs SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
jobs_updated = cur.rowcount
print(f" jobs: backfilled {jobs_updated} rows")
cur.execute(
"UPDATE tracked_companies SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
tracked_updated = cur.rowcount
print(f" tracked_companies: backfilled {tracked_updated} rows")
# ---------- 3. Create indexes ----------
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_messages_owner
ON llm_messages(owner_id)
""")
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_owner
ON jobs(owner_id)
""")
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner
ON tracked_companies(owner_id)
""")
# ---------- 4. Replace unique constraint on tracked_companies ----------
cur.execute("""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_constraint
WHERE conname = 'tracked_companies_company_name_key'
) THEN
ALTER TABLE tracked_companies
DROP CONSTRAINT tracked_companies_company_name_key;
END IF;
END $$;
""")
cur.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner
ON tracked_companies(LOWER(company_name), owner_id)
""")
conn.commit()
print("Migration completed successfully.")
except Exception:
conn.rollback()
print("Migration FAILED — rolled back.", file=sys.stderr)
raise
finally:
conn.close()
if __name__ == "__main__":
print(f"Running owner_id migration against {DATABASE_URL.split('@')[-1]} ...")
run_migration()
+74 -18
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints.""" """Tests for FastAPI web service endpoints."""
from datetime import datetime from datetime import datetime, timezone
from unittest.mock import Mock from unittest.mock import Mock, MagicMock, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from SPARC.api import app from SPARC.api import app
from SPARC.auth import create_access_token
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -16,6 +17,22 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
@pytest.fixture @pytest.fixture
def mock_analyzer(mocker): def mock_analyzer(mocker):
"""Mock the global analyzer.""" """Mock the global analyzer."""
@@ -24,6 +41,12 @@ def mock_analyzer(mocker):
return mock return mock
def _auth_header(user_id=1, email="user@test.com", role="user"):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_id, email, role)
return {"Authorization": f"Bearer {token}"}
class TestHealthEndpoint: class TestHealthEndpoint:
"""Test health check endpoint.""" """Test health check endpoint."""
@@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint:
) )
mock_analyzer._analyze_company_safe.return_value = mock_result mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia") response = client.get("/analyze/nvidia", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint:
) )
mock_analyzer._analyze_company_safe.return_value = mock_result mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/unknown") response = client.get("/analyze/unknown", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint:
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["nvidia", "amd"], "max_workers": 2}, json={"companies": ["nvidia", "amd"], "max_workers": 2},
headers=_auth_header(),
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint:
def test_batch_analysis_validation(self, client): def test_batch_analysis_validation(self, client):
"""Test batch analysis request validation.""" """Test batch analysis request validation."""
# Empty companies list # Empty companies list
response = client.post("/analyze/batch", json={"companies": []}) response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header())
assert response.status_code == 422 assert response.status_code == 422
# Too many companies # Too many companies
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": [f"company{i}" for i in range(25)]}, json={"companies": [f"company{i}" for i in range(25)]},
headers=_auth_header(),
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint:
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["nvidia"], "max_workers": 10}, json={"companies": ["nvidia"], "max_workers": 10},
headers=_auth_header(),
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint:
class TestAsyncBatchEndpoint: class TestAsyncBatchEndpoint:
"""Test async batch analysis endpoint.""" """Test async batch analysis endpoint."""
def test_async_batch_creates_job(self, client, mock_analyzer): @patch("SPARC.api._get_job_db")
"""Test async endpoint creates a job.""" def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
"""Test async endpoint creates a job with owner_id."""
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j1",
"status": "pending",
"progress": 0,
"total_companies": 2,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": 1,
}
mock_get_db.return_value = job_db
response = client.post( response = client.post(
"/analyze/batch/async", "/analyze/batch/async",
json={"companies": ["nvidia", "amd"]}, json={"companies": ["nvidia", "amd"]},
headers=_auth_header(),
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint:
assert data["status"] == "pending" assert data["status"] == "pending"
assert data["total_companies"] == 2 assert data["total_companies"] == 2
assert data["progress"] == 0 assert data["progress"] == 0
# Verify owner_id was passed
job_db.create_job.assert_called_once()
assert job_db.create_job.call_args.kwargs.get("owner_id") == 1
class TestJobEndpoints: class TestJobEndpoints:
"""Test job management endpoints.""" """Test job management endpoints."""
def test_get_job_not_found(self, client): @patch("SPARC.api._get_job_db")
def test_get_job_not_found(self, mock_get_db, client):
"""Test getting nonexistent job.""" """Test getting nonexistent job."""
response = client.get("/jobs/nonexistent") job_db = MagicMock()
job_db.get_job.return_value = None
mock_get_db.return_value = job_db
response = client.get("/jobs/nonexistent", headers=_auth_header())
assert response.status_code == 404 assert response.status_code == 404
def test_list_jobs(self, client, mocker): @patch("SPARC.api._get_job_db")
def test_list_jobs(self, mock_get_db, client):
"""Test listing jobs.""" """Test listing jobs."""
# Clear existing jobs job_db = MagicMock()
mocker.patch.dict("SPARC.api._jobs", {}, clear=True) job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs") response = client.get("/jobs", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_jobs_with_filter(self, client, mocker): @patch("SPARC.api._get_job_db")
def test_list_jobs_with_filter(self, mock_get_db, client):
"""Test listing jobs with status filter.""" """Test listing jobs with status filter."""
response = client.get("/jobs?status=completed") job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs?status=completed", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
@@ -189,7 +244,7 @@ class TestModelValidation:
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer): def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with unsupported model returns 400.""" """GET /analyze/{company} with unsupported model returns 400."""
response = client.get("/analyze/nvidia?model=fake/nonexistent-model") response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header())
assert response.status_code == 400 assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"] assert "Unsupported model" in response.json()["detail"]
@@ -205,7 +260,7 @@ class TestModelValidation:
) )
mock_analyzer._analyze_company_safe.return_value = mock_result mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet") response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header())
assert response.status_code == 200 assert response.status_code == 200
def test_batch_rejects_unsupported_model(self, client, mock_analyzer): def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
@@ -213,6 +268,7 @@ class TestModelValidation:
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"}, json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
headers=_auth_header(),
) )
assert response.status_code == 400 assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"] assert "Unsupported model" in response.json()["detail"]
+1
View File
@@ -5,6 +5,7 @@ Covers issue #1655:
- GET /export/{company_name}/pdf (PDF export) - GET /export/{company_name}/pdf (PDF export)
All tests mock the database layer and use JWT auth fixtures from test_auth patterns. All tests mock the database layer and use JWT auth fixtures from test_auth patterns.
Export queries are now scoped to the current user's owner_id.
""" """
from datetime import datetime, timezone from datetime import datetime, timezone
+281
View File
@@ -0,0 +1,281 @@
"""Cross-tenant isolation tests for multi-tenant support.
Verifies that:
- User A cannot read, update, or delete User B's analyses, tracked companies, or jobs
- Admin users can access all data via admin endpoints
- owner_id is correctly set on new resources
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user(user_id, email, role="user"):
return {
"id": user_id,
"email": email,
"role": role,
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
USER_A = _make_user(10, "alice@test.com")
USER_B = _make_user(20, "bob@test.com")
ADMIN = _make_user(1, "admin@test.com", role="admin")
def _header_for(user):
token = create_access_token(user["id"], user["email"], user["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db():
"""Mock DB returning the correct user based on user_id."""
db = MagicMock()
def _get_user_by_id(uid):
for u in [USER_A, USER_B, ADMIN]:
if u["id"] == uid:
return u
return None
db.get_user_by_id.side_effect = _get_user_by_id
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
# ==================== Tracked Companies Isolation ====================
class TestTrackedCompanyIsolation:
"""User A's tracked companies are invisible to User B."""
def test_user_a_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User A's companies."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": USER_A["id"]},
]
response = client.get("/tracked", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"])
def test_user_b_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User B's companies."""
mock_db.list_tracked_companies.return_value = []
response = client.get("/tracked", headers=_header_for(USER_B))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"])
def test_user_a_add_sets_owner(self, client, mock_db):
"""POST /tracked sets owner_id to User A."""
mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10}
response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"])
def test_user_b_cannot_remove_user_a_company(self, client, mock_db):
"""DELETE /tracked/{name} filters by owner, so B can't remove A's company."""
mock_db.remove_tracked_company.return_value = False # not found for B
response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B))
assert response.status_code == 404
mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"])
# ==================== Job Isolation ====================
class TestJobIsolation:
"""User A's jobs are invisible to User B."""
def test_user_a_get_own_job(self, client, mock_db):
"""GET /jobs/{id} scoped to User A returns the job."""
mock_db.get_job.return_value = None # mock via _get_job_db
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = {
"job_id": "j1",
"status": "completed",
"progress": 100,
"total_companies": 1,
"completed_companies": 1,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_A))
assert response.status_code == 200
job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"])
def test_user_b_cannot_see_user_a_job(self, client, mock_db):
"""GET /jobs/{id} returns 404 when User B tries to access User A's job."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = None # not found for B's owner_id
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_B))
assert response.status_code == 404
job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"])
def test_list_jobs_scoped_to_user(self, client, mock_db):
"""GET /jobs filters by owner_id."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
def test_async_job_created_with_owner(self, client, mock_db):
"""POST /analyze/batch/async creates job with current user's owner_id."""
mock_analyzer = MagicMock()
with patch("SPARC.api._analyzer", mock_analyzer), \
patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j2",
"status": "pending",
"progress": 0,
"total_companies": 1,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia"]},
headers=_header_for(USER_A),
)
assert response.status_code == 200
create_kwargs = job_db.create_job.call_args
assert create_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Analysis Listing Isolation ====================
class TestAnalysisListIsolation:
"""GET /analyze/batch scoped to current user."""
def test_list_analyses_scoped_to_user(self, client, mock_db):
"""GET /analyze/batch passes owner_id to db.list_analyses."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/analyze/batch", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Admin Cross-Tenant Access ====================
class TestAdminCrossTenantAccess:
"""Admin endpoints return data from all tenants (no owner_id filter)."""
def test_admin_list_tracked_all_tenants(self, client, mock_db):
"""GET /admin/tracked returns all companies (no owner_id filter)."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": 10},
{"company_name": "BobCo", "owner_id": 20},
]
response = client.get("/admin/tracked", headers=_header_for(ADMIN))
assert response.status_code == 200
# Should be called without owner_id filter
mock_db.list_tracked_companies.assert_called_with()
def test_admin_list_analyses_all_tenants(self, client, mock_db):
"""GET /admin/analyses returns all analyses (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/analyses", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
# No owner_id should be passed
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_list_jobs_all_tenants(self, client, mock_db):
"""GET /admin/jobs returns all jobs (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/jobs", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_remove_tracked_any_owner(self, client, mock_db):
"""DELETE /admin/tracked/{name} removes without owner filter."""
mock_db.remove_tracked_company.return_value = True
response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN))
assert response.status_code == 200
# Called without owner_id
mock_db.remove_tracked_company.assert_called_with("SomeCo")
def test_regular_user_cannot_access_admin_analyses(self, client, mock_db):
"""Regular user gets 403 on /admin/analyses."""
response = client.get("/admin/analyses", headers=_header_for(USER_A))
assert response.status_code == 403
def test_regular_user_cannot_access_admin_jobs(self, client, mock_db):
"""Regular user gets 403 on /admin/jobs."""
response = client.get("/admin/jobs", headers=_header_for(USER_A))
assert response.status_code == 403
# ==================== Analytics Isolation ====================
class TestAnalyticsIsolation:
"""GET /analytics scoped to current user."""
def test_analytics_scoped_to_user(self, client, mock_db):
"""GET /analytics passes owner_id to db.get_analytics."""
mock_db.get_analytics.return_value = {
"total_messages": 5,
"by_company": [],
"by_type": [],
"period_days": 30,
}
response = client.get("/analytics", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"])
+191
View File
@@ -0,0 +1,191 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
return {
"id": id_,
"company_name": company,
"analysis_type": "patent_portfolio",
"model": "openai/gpt-4o",
"response": f"Analysis for {company}",
"timestamp": ts,
}
def _make_job_row(job_id: str, minutes_ago: int = 0, status: str = "completed"):
"""Create a fake job row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
return {
"job_id": job_id,
"status": status,
"progress": 100 if status == "completed" else 0,
"total_companies": 1,
"completed_companies": 1 if status == "completed" else 0,
"result": None,
"error": None,
"created_at": ts,
}
class TestAnalyzeBatchGetPagination:
"""Test cursor-based pagination on GET /analyze/batch."""
@patch("SPARC.api._get_job_db")
def test_returns_items_and_no_cursor_when_less_than_limit(self, mock_get_db, client):
"""When fewer results than limit, next_cursor should be null."""
db = Mock()
db.list_analyses.return_value = [
_make_analysis_row(1, minutes_ago=10),
_make_analysis_row(2, minutes_ago=20),
]
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 2
assert data["next_cursor"] is None
@patch("SPARC.api._get_job_db")
def test_returns_cursor_when_more_results_exist(self, mock_get_db, client):
"""When more results exist than limit, next_cursor should be set."""
db = Mock()
# Return limit+1 rows to simulate more data
rows = [_make_analysis_row(i, minutes_ago=i) for i in range(4)]
db.list_analyses.return_value = rows
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 3
assert data["next_cursor"] is not None
@patch("SPARC.api._get_job_db")
def test_cursor_passed_to_db(self, mock_get_db, client):
"""The cursor query param should be forwarded to the database layer."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header())
db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
(call_kwargs[1].get("cursor") == "2025-01-01T00:00:00|42" if len(call_kwargs) > 1 else False)
@patch("SPARC.api._get_job_db")
def test_default_limit_is_50(self, mock_get_db, client):
"""Default limit should be 50."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/analyze/batch?limit=201", headers=_auth_header())
assert response.status_code == 422
def test_limit_zero_rejected(self, client):
"""Limit < 1 should be rejected with 422."""
response = client.get("/analyze/batch?limit=0", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
def test_company_name_filter(self, mock_get_db, client):
"""The company_name filter should be forwarded to the database."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else [])
@patch("SPARC.api._get_job_db")
def test_empty_result_set(self, mock_get_db, client):
"""Empty result set returns empty items and null cursor."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
response = client.get("/analyze/batch", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert data["items"] == []
assert data["next_cursor"] is None
class TestJobsPaginationDefaults:
"""Test that /jobs endpoint uses updated defaults."""
@patch("SPARC.api._get_job_db")
def test_default_limit_is_50(self, mock_get_db, client):
"""Default limit should now be 50."""
db = Mock()
db.list_jobs.return_value = []
mock_get_db.return_value = db
client.get("/jobs", headers=_auth_header())
call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/jobs?limit=201", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
def test_limit_200_accepted(self, mock_get_db, client):
"""Limit of exactly 200 should be accepted."""
db = Mock()
db.list_jobs.return_value = []
mock_get_db.return_value = db
response = client.get("/jobs?limit=200", headers=_auth_header())
assert response.status_code == 200
+178
View File
@@ -0,0 +1,178 @@
"""Tests for the /admin/rate-limits endpoint."""
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from SPARC import api
from SPARC.api import app
from SPARC.auth import UserResponse
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_stats():
"""Reset rate limit stats between tests."""
api._rate_limit_stats.clear()
api._rejected_log.clear()
yield
api._rate_limit_stats.clear()
api._rejected_log.clear()
def _mock_admin():
"""Return a mock admin user."""
return UserResponse(id=1, email="admin@test.com", role="admin", created_at="2025-01-01T00:00:00")
def _mock_user():
"""Return a mock non-admin user."""
return UserResponse(id=2, email="user@test.com", role="user", created_at="2025-01-01T00:00:00")
class TestRateLimitAdminEndpoint:
"""Test GET /admin/rate-limits."""
def test_admin_can_access(self, client):
"""Admin users should be able to access the rate-limits endpoint."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
assert response.status_code == 200
data = response.json()
assert "rate_limits" in data
assert isinstance(data["rate_limits"], list)
finally:
app.dependency_overrides.clear()
def test_non_admin_rejected(self, client):
"""Non-admin users should get 401/403."""
response = client.get("/admin/rate-limits")
assert response.status_code in (401, 403)
def test_returns_configured_endpoints(self, client):
"""Should list all rate-limited endpoints."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
assert response.status_code == 200
data = response.json()
endpoints = [rl["endpoint"] for rl in data["rate_limits"]]
assert "/auth/register" in endpoints
assert "/auth/login" in endpoints
finally:
app.dependency_overrides.clear()
def test_empty_state_shows_zero_counts(self, client):
"""When no requests have been made, counts should be zero."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
for rl in data["rate_limits"]:
assert rl["total_requests"] == 0
assert rl["rejected_requests"] == 0
assert rl["by_ip"] == []
assert data["throttled_24h"] == 0
assert data["throttled_over_time"] == []
finally:
app.dependency_overrides.clear()
def test_tracks_requests(self, client):
"""After making requests, the stats should reflect them."""
api._track_rate_limit_request("/auth/login", "127.0.0.1")
api._track_rate_limit_request("/auth/login", "127.0.0.1")
api._track_rate_limit_request("/auth/login", "192.168.1.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
assert login_stats["total_requests"] == 3
assert login_stats["rejected_requests"] == 1
finally:
app.dependency_overrides.clear()
def test_includes_limit_config(self, client):
"""Each endpoint entry should include the rate limit config string."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
for rl in data["rate_limits"]:
assert "limit" in rl
assert isinstance(rl["limit"], str)
finally:
app.dependency_overrides.clear()
def test_per_ip_breakdown(self, client):
"""Stats should include per-IP breakdown with total and rejected counts."""
api._track_rate_limit_request("/auth/login", "10.0.0.1")
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/login", "10.0.0.2")
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
by_ip = login_stats["by_ip"]
assert len(by_ip) == 2
ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1")
assert ip1["total"] == 2
assert ip1["rejected"] == 1
ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2")
assert ip2["total"] == 1
assert ip2["rejected"] == 0
finally:
app.dependency_overrides.clear()
def test_throttled_24h_count(self, client):
"""Should report total throttled requests in the last 24 hours."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert data["throttled_24h"] == 2
finally:
app.dependency_overrides.clear()
def test_throttled_over_time_structure(self, client):
"""Throttled-over-time should be a list of {timestamp, count} buckets."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert len(data["throttled_over_time"]) >= 1
entry = data["throttled_over_time"][0]
assert "timestamp" in entry
assert "count" in entry
assert entry["count"] >= 1
finally:
app.dependency_overrides.clear()
def test_response_shape_matches_contract(self, client):
"""The full response should match the expected shape for the frontend."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
# Top-level keys
assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"}
# Each rate_limit entry
for rl in data["rate_limits"]:
assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"}
finally:
app.dependency_overrides.clear()
+71 -10
View File
@@ -1,17 +1,18 @@
"""Tests for tracked company admin endpoints and scheduler integration. """Tests for tracked company endpoints and scheduler integration.
Covers issue #1656: Covers:
- GET /admin/tracked (list tracked companies) - GET /tracked (user-scoped list)
- POST /admin/tracked (add a tracked company) - POST /tracked (user-scoped add)
- DELETE /admin/tracked/{company_name} (remove a tracked company) - DELETE /tracked/{company_name} (user-scoped remove)
- GET /admin/tracked (admin: all companies)
- POST /admin/tracked (admin: add)
- DELETE /admin/tracked/{company_name} (admin: remove any)
- GET /admin/alerts (list alerts) - GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration - scheduler.run_scheduled_analysis() integration
All tests mock the database layer and use JWT auth fixtures.
""" """
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call from unittest.mock import MagicMock, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -125,7 +126,7 @@ class TestAddTrackedCompany:
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["company_name"] == "Intel" assert data["company_name"] == "Intel"
mock_db.add_tracked_company.assert_called_once_with("Intel") mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1)
def test_add_duplicate_returns_409(self, client, mock_db): def test_add_duplicate_returns_409(self, client, mock_db):
"""Adding an already-tracked company returns 409.""" """Adding an already-tracked company returns 409."""
@@ -141,7 +142,7 @@ class TestAddTrackedCompany:
assert "already tracked" in response.json()["detail"].lower() assert "already tracked" in response.json()["detail"].lower()
def test_add_tracked_requires_admin(self, client, mock_db): def test_add_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot add tracked companies.""" """Regular user cannot add tracked companies via admin endpoint."""
mock_db.get_user_by_id.return_value = { mock_db.get_user_by_id.return_value = {
"id": 2, "id": 2,
"email": "user@test.com", "email": "user@test.com",
@@ -215,6 +216,66 @@ class TestRemoveTrackedCompany:
assert response.status_code == 403 assert response.status_code == 403
# ---------- User-scoped tracked companies ----------
class TestUserScopedTrackedCompanies:
"""Tests for /tracked user-scoped endpoints."""
def test_user_list_tracked(self, client, mock_db):
"""Regular user can list their own tracked companies."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.list_tracked_companies.return_value = [
{"company_name": "AMD", "owner_id": 2},
]
response = client.get("/tracked", headers=_user_header())
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=2)
def test_user_add_tracked(self, client, mock_db):
"""Regular user can add a company to their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.add_tracked_company.return_value = {
"company_name": "Intel",
"owner_id": 2,
}
response = client.post(
"/tracked",
json={"company_name": "Intel"},
headers=_user_header(),
)
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2)
def test_user_remove_tracked(self, client, mock_db):
"""Regular user can remove a company from their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.remove_tracked_company.return_value = True
response = client.delete("/tracked/Intel", headers=_user_header())
assert response.status_code == 200
mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2)
# ---------- GET /admin/alerts ---------- # ---------- GET /admin/alerts ----------
class TestListAlerts: class TestListAlerts: