Compare commits

..

2 Commits

Author SHA1 Message Date
agent-company e37859dabc Add multi-tenant support with owner_id isolation
- Add owner_id (FK to users) column to llm_messages, jobs, and
  tracked_companies tables via schema migration in initialize_schema()
- Filter all read/write operations by authenticated user's owner_id
  so users cannot see or modify each other's data
- Add user-scoped /tracked endpoints alongside existing admin ones
- Add admin-scoped /admin/analyses and /admin/jobs endpoints that
  return cross-tenant data without owner filtering
- Create migration script (scripts/migrate_add_owner_id.py) that
  backfills owner_id=1 for all existing rows
- Replace global UNIQUE on tracked_companies.company_name with
  per-owner unique index (company_name, owner_id)
- Fix route ordering: /analyze/batch and /analyze/patent routes now
  registered before /analyze/{company_name} to prevent path conflicts
- Update all existing API tests with proper auth headers and owner_id
  assertions
- Add comprehensive cross-tenant isolation test suite
  (tests/test_multi_tenant.py)

Closes leeworks-agents/SPARC#1677

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 16:04:58 +00:00
agent-company 3dfa651f2d Add rate limiting dashboard to admin panel
- Enhance GET /admin/rate-limits with per-IP breakdown, 24h throttled
  count, and hourly time-series of rejected requests
- Add _rejected_log deque for time-series tracking of throttled requests
- Add AdminRateLimits React page with auto-refresh (configurable 15s/30s/1m),
  summary cards, throttled-over-time bar chart, endpoint table, per-IP table
- Add TypeScript types (RateLimitStatsResponse) and adminApi.getRateLimits()
- Wire up /admin/rate-limits route and nav link (admin-only)
- Expand unit tests to 10 cases: auth, empty state, per-IP breakdown,
  throttled_24h count, time-series structure, response shape contract

Closes leeworks-agents/SPARC#1686

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:39:45 +00:00
16 changed files with 964 additions and 1522 deletions
+205 -348
View File
@@ -30,11 +30,9 @@ from SPARC.auth import (
close_db_client,
create_tokens,
decode_token,
generate_api_key,
get_current_admin,
get_current_user,
get_db_client,
hash_api_key,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -142,31 +140,6 @@ class HealthResponse(BaseModel):
timestamp: datetime
# Historical diff models
class AnalysisDiffResponse(BaseModel):
"""Response model for diffing two analysis runs of the same company."""
company_name: str
from_id: int
to_id: int
from_timestamp: datetime
to_timestamp: datetime
patent_count_delta: int
added_patents: list[str]
removed_patents: list[str]
changed_fields: dict[str, dict]
summary: str
class CompanyAnalysisHistoryItem(BaseModel):
"""A summary item from a company's analysis history."""
id: int
analysis_type: str | None = None
model: str | None = None
timestamp: datetime
# Auth request/response models
class RegisterRequest(BaseModel):
"""User registration request."""
@@ -414,92 +387,6 @@ async def get_me(current_user: UserResponse = Depends(get_current_user)):
return current_user
# ============== API Key Endpoints ==============
class CreateApiKeyRequest(BaseModel):
"""Request to create a new API key."""
label: str | None = Field(default=None, max_length=100, description="Optional label for the key")
class ApiKeyResponse(BaseModel):
"""Response after creating an API key (includes plaintext key)."""
id: int
key: str # plaintext key, shown only at creation time
label: str | None = None
created_at: datetime
class ApiKeyInfo(BaseModel):
"""API key metadata (no secret)."""
id: int
label: str | None = None
created_at: datetime
@app.post("/auth/apikeys", response_model=ApiKeyResponse, tags=["Auth"])
async def create_api_key_endpoint(
body: CreateApiKeyRequest | None = None,
current_user: UserResponse = Depends(get_current_user),
):
"""Generate a new API key for the authenticated user.
The plaintext key is returned **only once** in the response.
Store it securely; it cannot be retrieved again.
"""
plaintext_key = generate_api_key()
key_hash = hash_api_key(plaintext_key)
db = get_db_client()
label = body.label if body else None
row = db.create_api_key(
user_id=current_user.id,
key_hash=key_hash,
label=label,
)
return ApiKeyResponse(
id=row["id"],
key=plaintext_key,
label=row["label"],
created_at=row["created_at"],
)
@app.get("/auth/apikeys", response_model=list[ApiKeyInfo], tags=["Auth"])
async def list_api_keys_endpoint(
current_user: UserResponse = Depends(get_current_user),
):
"""List active API key IDs and labels for the authenticated user.
Does **not** return the secret keys.
"""
db = get_db_client()
keys = db.list_api_keys(current_user.id)
return [ApiKeyInfo(**k) for k in keys]
@app.delete("/auth/apikeys/{key_id}", tags=["Auth"])
async def revoke_api_key_endpoint(
key_id: int,
current_user: UserResponse = Depends(get_current_user),
):
"""Revoke (delete) an API key by its ID.
The key must belong to the authenticated user.
"""
db = get_db_client()
deleted = db.delete_api_key(key_id, current_user.id)
if not deleted:
raise HTTPException(status_code=404, detail="API key not found")
return {"message": "API key revoked"}
# ============== Admin Endpoints ==============
@@ -587,11 +474,46 @@ class TrackCompanyRequest(BaseModel):
company_name: CompanyName = Field(...)
@app.get("/tracked", tags=["Tracked Companies"])
async def list_my_tracked_companies(
current_user: UserResponse = Depends(get_current_user),
):
"""List tracked companies for the current user."""
db = get_db_client()
return db.list_tracked_companies(owner_id=current_user.id)
@app.post("/tracked", tags=["Tracked Companies"])
async def add_my_tracked_company(
request: TrackCompanyRequest,
current_user: UserResponse = Depends(get_current_user),
):
"""Add a company to the current user's tracked list."""
db = get_db_client()
result = db.add_tracked_company(request.company_name, owner_id=current_user.id)
if not result:
raise HTTPException(status_code=409, detail="Company already tracked")
return result
@app.delete("/tracked/{company_name}", tags=["Tracked Companies"])
async def remove_my_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
current_user: UserResponse = Depends(get_current_user),
):
"""Remove a company from the current user's tracked list."""
db = get_db_client()
removed = db.remove_tracked_company(company_name, owner_id=current_user.id)
if not removed:
raise HTTPException(status_code=404, detail="Company not found in tracking list")
return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies(
_: UserResponse = Depends(get_current_admin),
):
"""List all tracked companies (admin only)."""
"""List all tracked companies across all users (admin only)."""
db = get_db_client()
return db.list_tracked_companies()
@@ -599,11 +521,11 @@ async def list_tracked_companies(
@app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company(
request: TrackCompanyRequest,
_: UserResponse = Depends(get_current_admin),
current_admin: UserResponse = Depends(get_current_admin),
):
"""Add a company to the tracked list (admin only)."""
"""Add a company to the tracked list (admin only, owned by admin)."""
db = get_db_client()
result = db.add_tracked_company(request.company_name)
result = db.add_tracked_company(request.company_name, owner_id=current_admin.id)
if not result:
raise HTTPException(status_code=409, detail="Company already tracked")
return result
@@ -614,7 +536,7 @@ async def remove_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin),
):
"""Remove a company from the tracked list (admin only)."""
"""Remove a company from the tracked list (admin only, any owner)."""
db = get_db_client()
removed = db.remove_tracked_company(company_name)
if not removed:
@@ -695,17 +617,86 @@ async def list_alerts(
return db.list_alerts(limit=limit)
# ============== Admin-Scoped Data Endpoints ==============
@app.get("/admin/analyses", response_model=PaginatedAnalysisResponse, tags=["Admin"])
async def admin_list_analyses(
company_name: Annotated[
str | None,
Query(description="Filter results by company name"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_admin),
):
"""List all analysis results across all users (admin only)."""
db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
has_next = len(rows) > limit
if has_next:
rows = rows[:limit]
items = [AnalysisRecord(**row) for row in rows]
next_cursor = None
if has_next and rows:
last = rows[-1]
ts = last["timestamp"]
ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
next_cursor = f"{ts_str}|{last['id']}"
return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor)
@app.get("/admin/jobs", response_model=PaginatedJobsResponse, tags=["Admin"])
async def admin_list_jobs(
status: Annotated[
str | None,
Query(description="Filter by status: pending, running, completed, failed"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_admin),
):
"""List all jobs across all users (admin only)."""
db = _get_job_db()
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
has_next = len(job_rows) > limit
if has_next:
job_rows = job_rows[:limit]
items = [_job_row_to_status(row) for row in job_rows]
next_cursor = None
if has_next and job_rows:
last = job_rows[-1]
created = last["created_at"]
ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
next_cursor = f"{ts}|{last['job_id']}"
return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
# ============== Analytics Endpoint ==============
@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics(
days: int = Query(default=30, ge=1, le=365),
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""Get analytics data (authenticated users only)."""
"""Get analytics data scoped to the current user."""
db = get_db_client()
analytics = db.get_analytics(days=days)
analytics = db.get_analytics(days=days, owner_id=current_user.id)
return AnalyticsResponse(
total_messages=analytics["total_messages"],
@@ -758,9 +749,9 @@ async def list_models():
@app.get("/analytics/trends", tags=["Analytics"])
async def get_analytics_trends(
days: int = Query(default=90, ge=7, le=365),
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""Get trend data for patent analysis over time.
"""Get trend data for patent analysis over time (scoped to current user).
Returns two datasets:
- ``by_month``: analysis count per company per month
@@ -774,11 +765,14 @@ async def get_analytics_trends(
"""
db = get_db_client()
owner_filter = " AND owner_id = %s" if current_user else ""
owner_params = (current_user.id,) if current_user else ()
with db.get_conn() as conn:
with conn.cursor() as cur:
# Analyses per company per month
cur.execute(
"""
f"""
SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month,
company_name,
@@ -787,16 +781,17 @@ async def get_analytics_trends(
WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE
AND company_name IS NOT NULL
{owner_filter}
GROUP BY month, company_name
ORDER BY month
""",
(days,),
(days, *owner_params),
)
by_month_rows = cur.fetchall()
# Analysis type distribution per month
cur.execute(
"""
f"""
SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month,
analysis_type,
@@ -804,10 +799,11 @@ async def get_analytics_trends(
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE
{owner_filter}
GROUP BY month, analysis_type
ORDER BY month
""",
(days,),
(days, *owner_params),
)
by_type_rows = cur.fetchall()
@@ -833,9 +829,9 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
"""Export analysis results for a company as a CSV file (scoped to current user).
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
@@ -850,7 +846,7 @@ async def export_company_csv(
import io
db = get_db_client()
# Query all non-cached analysis results for this company
# Query all non-cached analysis results for this company owned by current user
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
@@ -858,9 +854,10 @@ async def export_company_csv(
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
AND owner_id = %s
ORDER BY timestamp DESC
""",
(company_name,),
(company_name, current_user.id),
)
rows = cur.fetchall()
@@ -885,9 +882,9 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
"""Export analysis results for a company as a formatted PDF report (scoped to current user).
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
@@ -921,9 +918,10 @@ async def export_company_pdf(
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
AND owner_id = %s
ORDER BY timestamp DESC
""",
(company_name,),
(company_name, current_user.id),
)
rows = cur.fetchall()
@@ -1052,209 +1050,6 @@ async def health_check():
)
@app.get(
"/analyze/{company_name}",
response_model=CompanyAnalysisResponse,
tags=["Analysis"],
)
async def analyze_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single company's patent portfolio.
This endpoint retrieves recent patents for the specified company,
parses them, and uses AI to generate a comprehensive analysis.
Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns:
Analysis results including patent count, AI insights, and success status
"""
_validate_model(model)
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name, model=model)
return _convert_result(result)
def _extract_patent_ids(response_text: str) -> set[str]:
"""Extract patent IDs from an analysis response text.
Looks for patterns like US-12345678-B2, US12345678B2, etc.
"""
import re
pattern = r"US[-\s]?\d{7,8}[-\s]?[A-Z]\d?"
return set(re.findall(pattern, response_text or ""))
def _compute_analysis_diff(from_rec: dict, to_rec: dict) -> AnalysisDiffResponse:
"""Compute a structured diff between two analysis records."""
from_patents = _extract_patent_ids(from_rec.get("response", "") or "")
to_patents = _extract_patent_ids(to_rec.get("response", "") or "")
added = sorted(to_patents - from_patents)
removed = sorted(from_patents - to_patents)
patent_count_delta = len(to_patents) - len(from_patents)
changed_fields: dict[str, dict] = {}
if from_rec.get("model") != to_rec.get("model"):
changed_fields["model"] = {
"from": from_rec.get("model"),
"to": to_rec.get("model"),
}
if from_rec.get("analysis_type") != to_rec.get("analysis_type"):
changed_fields["analysis_type"] = {
"from": from_rec.get("analysis_type"),
"to": to_rec.get("analysis_type"),
}
# Build a human-readable summary
parts: list[str] = []
if added:
parts.append(f"{len(added)} new patent(s) appeared")
if removed:
parts.append(f"{len(removed)} patent(s) no longer referenced")
if patent_count_delta > 0:
parts.append(f"patent mention count increased by {patent_count_delta}")
elif patent_count_delta < 0:
parts.append(f"patent mention count decreased by {abs(patent_count_delta)}")
if changed_fields:
parts.append(f"field(s) changed: {', '.join(changed_fields.keys())}")
summary = "; ".join(parts) if parts else "No significant differences detected."
return AnalysisDiffResponse(
company_name=to_rec["company_name"],
from_id=from_rec["id"],
to_id=to_rec["id"],
from_timestamp=from_rec["timestamp"],
to_timestamp=to_rec["timestamp"],
patent_count_delta=patent_count_delta,
added_patents=added,
removed_patents=removed,
changed_fields=changed_fields,
summary=summary,
)
@app.get(
"/analyze/{company_name}/history",
response_model=list[CompanyAnalysisHistoryItem],
tags=["Analysis"],
)
async def list_company_analysis_history(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
limit: int = Query(default=20, ge=1, le=100),
_: UserResponse = Depends(get_current_user),
):
"""List previous analysis runs for a company.
Returns a list of analysis records ordered by timestamp descending,
useful for selecting which runs to compare via the diff endpoint.
Args:
company_name: Company name to look up
limit: Maximum number of results
Returns:
List of analysis history items
"""
db = _get_job_db()
rows = db.list_company_analyses(company_name, limit=limit)
return [
CompanyAnalysisHistoryItem(
id=r["id"],
analysis_type=r.get("analysis_type"),
model=r.get("model"),
timestamp=r["timestamp"],
)
for r in rows
]
@app.get(
"/analyze/{company_name}/diff",
response_model=AnalysisDiffResponse,
tags=["Analysis"],
)
async def diff_company_analyses(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
from_id: int = Query(..., alias="from", description="Analysis ID of the older run"),
to_id: int = Query(..., alias="to", description="Analysis ID of the newer run"),
_: UserResponse = Depends(get_current_user),
):
"""Compare two analysis runs for the same company.
Returns a structured diff showing added/removed patents, score delta,
and a summary narrative.
Args:
company_name: Company name (must match both analysis records)
from_id: ID of the older analysis run
to_id: ID of the newer analysis run
Returns:
AnalysisDiffResponse with added/removed/changed fields
Raises:
404: If either analysis ID does not exist or belongs to a different company
"""
db = _get_job_db()
from_rec = db.get_analysis_by_id(from_id)
if not from_rec or (from_rec["company_name"] or "").lower() != company_name.lower():
raise HTTPException(
status_code=404,
detail=f"Analysis ID {from_id} not found for company '{company_name}'",
)
to_rec = db.get_analysis_by_id(to_id)
if not to_rec or (to_rec["company_name"] or "").lower() != company_name.lower():
raise HTTPException(
status_code=404,
detail=f"Analysis ID {to_id} not found for company '{company_name}'",
)
return _compute_analysis_diff(from_rec, to_rec)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
If the patent PDF is not already cached locally, the system will attempt
to download it automatically from a previously cached link. If no link
is available, a 404 error is returned.
Args:
patent_id: Patent publication ID (e.g. "US-11234567-B2")
company_name: Company name for analysis context
Returns:
Analysis text for the patent
"""
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
try:
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@app.get(
"/analyze/batch",
response_model=PaginatedAnalysisResponse,
@@ -1270,9 +1065,9 @@ async def list_analysis_results(
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""List stored analysis results with cursor-based pagination.
"""List stored analysis results with cursor-based pagination (scoped to current user).
Returns past analysis results ordered by timestamp descending. Use
``limit`` to control page size (default 50, max 200). The response
@@ -1289,7 +1084,7 @@ async def list_analysis_results(
Paginated list of analysis results
"""
db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, owner_id=current_user.id)
has_next = len(rows) > limit
if has_next:
@@ -1339,6 +1134,68 @@ async def analyze_companies_batch(
return _convert_batch_result(result)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
If the patent PDF is not already cached locally, the system will attempt
to download it automatically from a previously cached link. If no link
is available, a 404 error is returned.
Args:
patent_id: Patent publication ID (e.g. "US-11234567-B2")
company_name: Company name for analysis context
Returns:
Analysis text for the patent
"""
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
try:
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@app.get(
"/analyze/{company_name}",
response_model=CompanyAnalysisResponse,
tags=["Analysis"],
)
async def analyze_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single company's patent portfolio.
This endpoint retrieves recent patents for the specified company,
parses them, and uses AI to generate a comprehensive analysis.
Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns:
Analysis results including patent count, AI insights, and success status
"""
_validate_model(model)
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name, model=model)
return _convert_result(result)
def _get_job_db() -> "DatabaseClient":
"""Get a DatabaseClient for job persistence."""
from SPARC.database import DatabaseClient
@@ -1425,7 +1282,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
async def analyze_companies_async(
request: BatchAnalysisRequest,
background_tasks: BackgroundTasks,
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""Start an asynchronous batch analysis job.
@@ -1445,7 +1302,7 @@ async def analyze_companies_async(
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
db = _get_job_db()
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies), owner_id=current_user.id)
background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers, request.model
@@ -1457,9 +1314,9 @@ async def analyze_companies_async(
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status(
job_id: str,
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""Get the status of a background analysis job.
"""Get the status of a background analysis job (scoped to current user).
Args:
job_id: The job ID returned from the async batch endpoint
@@ -1468,7 +1325,7 @@ async def get_job_status(
Current job status including progress and results when complete
"""
db = _get_job_db()
job_row = db.get_job(job_id)
job_row = db.get_job(job_id, owner_id=current_user.id)
if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@@ -1487,9 +1344,9 @@ async def list_jobs(
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_user),
current_user: UserResponse = Depends(get_current_user),
):
"""List analysis jobs with cursor-based pagination.
"""List analysis jobs with cursor-based pagination (scoped to current user).
Pass ``limit`` to control page size. The response includes a ``next_cursor``
field; pass it back as the ``cursor`` query parameter to fetch the next page.
@@ -1508,7 +1365,7 @@ async def list_jobs(
"""
db = _get_job_db()
# Fetch one extra to determine if there is a next page
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor, owner_id=current_user.id)
has_next = len(job_rows) > limit
if has_next:
+7 -96
View File
@@ -1,13 +1,11 @@
"""JWT and API key authentication utilities for SPARC API."""
"""JWT authentication utilities for SPARC API."""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
@@ -34,7 +32,7 @@ def check_jwt_secret() -> None:
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer(auto_error=False)
security = HTTPBearer()
class TokenPayload(BaseModel):
@@ -180,107 +178,20 @@ def get_db_client() -> DatabaseClient:
return _db_client
def generate_api_key() -> str:
"""Generate a random 32-byte hex API key.
Returns:
64-character hex string
"""
return secrets.token_hex(32)
def hash_api_key(key: str) -> str:
"""Hash an API key using bcrypt.
Args:
key: Plaintext API key
Returns:
bcrypt hash string
"""
return bcrypt.hashpw(key.encode(), bcrypt.gensalt()).decode()
def verify_api_key(key: str, key_hash: str) -> bool:
"""Verify a plaintext API key against its bcrypt hash.
Args:
key: Plaintext API key
key_hash: Stored bcrypt hash
Returns:
True if key matches
"""
return bcrypt.checkpw(key.encode(), key_hash.encode())
def _authenticate_via_api_key(api_key: str) -> Optional[UserResponse]:
"""Look up a user by raw API key.
Iterates over all stored key hashes (small table) and returns the
corresponding user when a match is found.
Args:
api_key: Plaintext API key from X-API-Key header
Returns:
UserResponse if valid key, None otherwise
"""
db = get_db_client()
key_rows = db.get_all_api_key_hashes()
for row in key_rows:
if verify_api_key(api_key, row["key_hash"]):
user = db.get_user_by_id(row["user_id"])
if user:
return UserResponse(
id=user["id"],
email=user["email"],
role=user["role"],
created_at=user["created_at"],
)
return None
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> UserResponse:
"""Get the current authenticated user from JWT token or API key.
Supports two authentication methods:
1. Bearer JWT token via Authorization header
2. API key via X-API-Key header
"""Get the current authenticated user from JWT token.
Args:
request: The incoming request (used for X-API-Key header)
credentials: Optional Bearer token from request
credentials: Bearer token from request
Returns:
UserResponse with user details
Raises:
HTTPException: If no valid credentials are provided
HTTPException: If token is invalid or expired
"""
# Try X-API-Key header first
api_key = request.headers.get("X-API-Key")
if api_key:
user = _authenticate_via_api_key(api_key)
if user:
return user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
# Fall back to JWT Bearer token
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
payload = decode_token(token)
+156 -176
View File
@@ -196,7 +196,7 @@ class DatabaseClient:
cursor.execute("""
CREATE TABLE IF NOT EXISTS tracked_companies (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255) UNIQUE NOT NULL,
company_name VARCHAR(255) NOT NULL,
last_patent_count INTEGER DEFAULT 0,
last_analysis_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -221,25 +221,66 @@ class DatabaseClient:
ON alerts(company_name)
""")
# Create API keys table for programmatic access
# ---- Multi-tenant: add owner_id columns if missing ----
cursor.execute("""
CREATE TABLE IF NOT EXISTS api_keys (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_hash VARCHAR(255) NOT NULL,
label VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
DO $$
BEGIN
-- llm_messages.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'llm_messages' AND column_name = 'owner_id'
) THEN
ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
-- jobs.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'jobs' AND column_name = 'owner_id'
) THEN
ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
-- tracked_companies.owner_id
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'tracked_companies' AND column_name = 'owner_id'
) THEN
ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
END $$;
""")
# Indexes for owner_id filtering
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
ON api_keys(user_id)
CREATE INDEX IF NOT EXISTS idx_messages_owner
ON llm_messages(owner_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_owner
ON jobs(owner_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner
ON tracked_companies(owner_id)
""")
# Drop the old unique constraint on company_name alone (if it exists)
# and replace with a per-owner unique constraint so different users
# can track the same company independently.
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_api_keys_key_hash
ON api_keys(key_hash)
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_constraint
WHERE conname = 'tracked_companies_company_name_key'
) THEN
ALTER TABLE tracked_companies
DROP CONSTRAINT tracked_companies_company_name_key;
END IF;
END $$;
""")
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner
ON tracked_companies(LOWER(company_name), owner_id)
""")
self.conn.commit()
@@ -310,6 +351,7 @@ class DatabaseClient:
metadata: Optional[Dict] = None,
token_usage: Optional[Dict] = None,
is_cached: bool = False,
owner_id: Optional[int] = None,
) -> int:
"""Store an LLM message exchange in the database.
@@ -322,6 +364,7 @@ class DatabaseClient:
metadata: Additional metadata as dict
token_usage: Token usage information
is_cached: Whether this response was served from cache
owner_id: ID of the user who owns this record
Returns:
The ID of the inserted record
@@ -333,8 +376,8 @@ class DatabaseClient:
cursor.execute(
"""
INSERT INTO llm_messages
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached, owner_id)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id
""",
(
@@ -347,6 +390,7 @@ class DatabaseClient:
json.dumps(metadata) if metadata else None,
json.dumps(token_usage) if token_usage else None,
is_cached,
owner_id,
),
)
@@ -361,6 +405,7 @@ class DatabaseClient:
analysis_type: Optional[str] = None,
limit: int = 100,
offset: int = 0,
owner_id: Optional[int] = None,
) -> List[Dict]:
"""Retrieve messages from the database.
@@ -369,6 +414,7 @@ class DatabaseClient:
analysis_type: Filter by analysis type
limit: Maximum number of records to return
offset: Number of records to skip
owner_id: Filter by owner (None returns all, for admin use)
Returns:
List of message dictionaries
@@ -376,6 +422,10 @@ class DatabaseClient:
query = "SELECT * FROM llm_messages WHERE 1=1"
params = []
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
if company_name:
query += " AND company_name = %s"
params.append(company_name)
@@ -397,6 +447,7 @@ class DatabaseClient:
company_name: Optional[str] = None,
limit: int = 50,
cursor: Optional[str] = None,
owner_id: Optional[int] = None,
) -> List[Dict]:
"""List analysis results with cursor-based pagination.
@@ -404,6 +455,7 @@ class DatabaseClient:
company_name: Optional filter by company name.
limit: Maximum number of records to return.
cursor: Opaque cursor (``timestamp|id``) from a previous response.
owner_id: Filter by owner (None returns all, for admin use).
Returns:
List of analysis dicts ordered by timestamp descending.
@@ -411,6 +463,10 @@ class DatabaseClient:
conditions: list[str] = ["is_cached = FALSE"]
params: list = []
if owner_id is not None:
conditions.append("owner_id = %s")
params.append(owner_id)
if company_name:
conditions.append("LOWER(company_name) = LOWER(%s)")
params.append(company_name)
@@ -434,52 +490,62 @@ class DatabaseClient:
cur.execute(query, params)
return [dict(row) for row in cur.fetchall()]
def get_analytics(self, days: int = 30) -> Dict:
def get_analytics(self, days: int = 30, owner_id: Optional[int] = None) -> Dict:
"""Get analytics on message usage.
Args:
days: Number of days to look back
owner_id: Filter by owner (None returns all, for admin use)
Returns:
Dictionary with analytics data
"""
owner_filter = ""
owner_params: list = []
if owner_id is not None:
owner_filter = " AND owner_id = %s"
owner_params = [owner_id]
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
f"""
SELECT COUNT(*) as total_messages
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
""",
(days,),
(days, *owner_params),
)
total = cursor.fetchone()["total_messages"]
# Messages by company
cursor.execute(
"""
f"""
SELECT company_name, COUNT(*) as count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
GROUP BY company_name
ORDER BY count DESC
LIMIT 10
""",
(days,),
(days, *owner_params),
)
by_company = cursor.fetchall()
# Messages by type
cursor.execute(
"""
f"""
SELECT analysis_type, COUNT(*) as count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
{owner_filter}
GROUP BY analysis_type
ORDER BY count DESC
""",
(days,),
(days, *owner_params),
)
by_type = cursor.fetchall()
@@ -577,12 +643,14 @@ class DatabaseClient:
self,
job_id: str,
total_companies: int,
owner_id: Optional[int] = None,
) -> Dict:
"""Create a new job record.
Args:
job_id: Unique job identifier
total_companies: Number of companies in the batch
owner_id: ID of the user who owns this job
Returns:
Job dict
@@ -591,11 +659,11 @@ class DatabaseClient:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
VALUES (%s, 'pending', 0, %s, 0)
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies, owner_id)
VALUES (%s, 'pending', 0, %s, 0, %s)
RETURNING *
""",
(job_id, total_companies),
(job_id, total_companies, owner_id),
)
job = cursor.fetchone()
conn.commit()
@@ -648,11 +716,22 @@ class DatabaseClient:
conn.commit()
return dict(job) if job else None
def get_job(self, job_id: str) -> Optional[Dict]:
"""Get a job by ID."""
def get_job(self, job_id: str, owner_id: Optional[int] = None) -> Optional[Dict]:
"""Get a job by ID.
Args:
job_id: Job identifier.
owner_id: When provided, only return the job if it belongs to this owner.
"""
query = "SELECT * FROM jobs WHERE job_id = %s"
params: list = [job_id]
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
cursor.execute(query, params)
job = cursor.fetchone()
return dict(job) if job else None
@@ -661,6 +740,7 @@ class DatabaseClient:
status: Optional[str] = None,
limit: int = 10,
cursor: Optional[str] = None,
owner_id: Optional[int] = None,
) -> List[Dict]:
"""List jobs with optional status filter and cursor-based pagination.
@@ -670,6 +750,7 @@ class DatabaseClient:
cursor: Opaque cursor (``created_at|job_id``) from a previous
response. When provided, only jobs older than the cursor are
returned.
owner_id: Filter by owner (None returns all, for admin use).
Returns:
List of job dicts ordered by created_at descending.
@@ -677,6 +758,10 @@ class DatabaseClient:
conditions: list[str] = []
params: list = []
if owner_id is not None:
conditions.append("owner_id = %s")
params.append(owner_id)
if status:
conditions.append("status = %s")
params.append(status)
@@ -923,14 +1008,21 @@ class DatabaseClient:
# Tracked Companies Methods
def add_tracked_company(self, company_name: str) -> Optional[Dict]:
"""Add a company to the tracking list."""
def add_tracked_company(
self, company_name: str, owner_id: Optional[int] = None
) -> Optional[Dict]:
"""Add a company to the tracking list.
Args:
company_name: Company name to track.
owner_id: ID of the user who owns this tracked company.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
try:
cursor.execute(
"INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
(company_name,),
"INSERT INTO tracked_companies (company_name, owner_id) VALUES (%s, %s) RETURNING *",
(company_name, owner_id),
)
row = cursor.fetchone()
conn.commit()
@@ -939,22 +1031,45 @@ class DatabaseClient:
conn.rollback()
return None
def remove_tracked_company(self, company_name: str) -> bool:
"""Remove a company from the tracking list."""
def remove_tracked_company(
self, company_name: str, owner_id: Optional[int] = None
) -> bool:
"""Remove a company from the tracking list.
Args:
company_name: Company name to remove.
owner_id: When provided, only remove if owned by this user.
"""
query = "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)"
params: list = [company_name]
if owner_id is not None:
query += " AND owner_id = %s"
params.append(owner_id)
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
(company_name,),
)
cursor.execute(query, params)
conn.commit()
return cursor.rowcount > 0
def list_tracked_companies(self) -> List[Dict]:
"""List all tracked companies."""
def list_tracked_companies(
self, owner_id: Optional[int] = None
) -> List[Dict]:
"""List tracked companies.
Args:
owner_id: Filter by owner (None returns all, for admin/scheduler use).
"""
query = "SELECT * FROM tracked_companies"
params: list = []
if owner_id is not None:
query += " WHERE owner_id = %s"
params.append(owner_id)
query += " ORDER BY company_name"
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name")
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def update_tracked_company(
@@ -998,138 +1113,3 @@ class DatabaseClient:
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
# Historical Analysis Diff Methods
def get_analysis_by_id(self, analysis_id: int) -> Optional[Dict]:
"""Get a single analysis record by its ID.
Args:
analysis_id: The primary key of the llm_messages row.
Returns:
Dict with analysis fields, or None if not found.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE id = %s AND is_cached = FALSE
""",
(analysis_id,),
)
row = cursor.fetchone()
return dict(row) if row else None
def list_company_analyses(
self, company_name: str, limit: int = 20
) -> List[Dict]:
"""List past analysis runs for a given company.
Returns records ordered by timestamp descending so callers can
identify which previous runs are available for diffing.
Args:
company_name: Company name (case-insensitive match).
limit: Maximum number of records.
Returns:
List of analysis dicts.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
LIMIT %s
""",
(company_name, limit),
)
return [dict(row) for row in cursor.fetchall()]
# API Key Methods
def create_api_key(
self,
user_id: int,
key_hash: str,
label: Optional[str] = None,
) -> Dict:
"""Store a new API key hash for a user.
Args:
user_id: The owning user's ID
key_hash: bcrypt hash of the plaintext key
label: Optional human-readable label
Returns:
Dict with id, user_id, label, created_at
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO api_keys (user_id, key_hash, label)
VALUES (%s, %s, %s)
RETURNING id, user_id, label, created_at
""",
(user_id, key_hash, label),
)
row = cursor.fetchone()
conn.commit()
return dict(row)
def list_api_keys(self, user_id: int) -> List[Dict]:
"""List active API key metadata for a user (no secrets).
Args:
user_id: The user's ID
Returns:
List of dicts with id, label, created_at
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, label, created_at FROM api_keys WHERE user_id = %s ORDER BY created_at DESC",
(user_id,),
)
return [dict(row) for row in cursor.fetchall()]
def delete_api_key(self, key_id: int, user_id: int) -> bool:
"""Revoke an API key by ID (must belong to user).
Args:
key_id: The API key row ID
user_id: The owning user's ID
Returns:
True if a key was deleted
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM api_keys WHERE id = %s AND user_id = %s",
(key_id, user_id),
)
deleted = cursor.rowcount > 0
conn.commit()
return deleted
def get_all_api_key_hashes(self) -> List[Dict]:
"""Return all API key hashes with their associated user IDs.
Used by the auth layer to validate an incoming API key.
Returns:
List of dicts with key_hash, user_id
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT key_hash, user_id FROM api_keys")
return [dict(row) for row in cursor.fetchall()]
-2
View File
@@ -13,7 +13,6 @@ import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers';
import { AdminRateLimits } from './pages/AdminRateLimits';
import { Compare } from './pages/Compare';
import { HistoryDiff } from './pages/HistoryDiff';
const queryClient = new QueryClient({
defaultOptions: {
@@ -47,7 +46,6 @@ function App() {
<Route path="/batch" element={<Batch />} />
<Route path="/analytics" element={<AnalyticsPage />} />
<Route path="/compare" element={<Compare />} />
<Route path="/history-diff" element={<HistoryDiff />} />
<Route path="/about" element={<About />} />
{/* Admin routes */}
-35
View File
@@ -148,43 +148,8 @@ export const analysisApi = {
const response = await api.get<JobStatus[]>(`/jobs?${params}`);
return response.data;
},
getCompanyHistory: async (companyName: string, limit = 20): Promise<AnalysisHistoryItem[]> => {
const response = await api.get<AnalysisHistoryItem[]>(
`/analyze/${encodeURIComponent(companyName)}/history?limit=${limit}`
);
return response.data;
},
diffAnalyses: async (companyName: string, fromId: number, toId: number): Promise<AnalysisDiff> => {
const response = await api.get<AnalysisDiff>(
`/analyze/${encodeURIComponent(companyName)}/diff?from=${fromId}&to=${toId}`
);
return response.data;
},
};
// Analysis diff types
export interface AnalysisHistoryItem {
id: number;
analysis_type: string | null;
model: string | null;
timestamp: string;
}
export interface AnalysisDiff {
company_name: string;
from_id: number;
to_id: number;
from_timestamp: string;
to_timestamp: string;
patent_count_delta: number;
added_patents: string[];
removed_patents: string[];
changed_fields: Record<string, { from: string | null; to: string | null }>;
summary: string;
}
// Export API
export const exportApi = {
exportCsv: async (companyName: string): Promise<void> => {
+1 -2
View File
@@ -1,7 +1,7 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext';
import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, History, ShieldAlert } from 'lucide-react';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, ShieldAlert } from 'lucide-react';
export function Layout() {
const { user, isAdmin, logout } = useAuth();
@@ -18,7 +18,6 @@ export function Layout() {
{ to: '/batch', icon: Layers, label: 'Batch' },
{ to: '/analytics', icon: BarChart3, label: 'Analytics' },
{ to: '/compare', icon: GitCompareArrows, label: 'Compare' },
{ to: '/history-diff', icon: History, label: 'Diff' },
{ to: '/about', icon: Info, label: 'About' },
];
+1 -10
View File
@@ -1,12 +1,10 @@
import { useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { useMutation, useQuery } from '@tanstack/react-query';
import { analysisApi, exportApi } from '../api/client';
import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown, History } from 'lucide-react';
import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown } from 'lucide-react';
import type { CompanyAnalysis } from '../types';
export function Analysis() {
const navigate = useNavigate();
const [companyName, setCompanyName] = useState('');
const [selectedModel, setSelectedModel] = useState('');
const [result, setResult] = useState<CompanyAnalysis | null>(null);
@@ -159,13 +157,6 @@ export function Analysis() {
<FileText size={14} />
Export PDF
</button>
<button
onClick={() => navigate(`/history-diff?company=${encodeURIComponent(result.company_name)}`)}
className="flex items-center gap-2 text-sm bg-secondary/20 hover:bg-secondary/30 text-secondary font-medium px-3 py-1.5 rounded-lg transition-colors"
>
<History size={14} />
Compare with previous
</button>
</div>
</div>
<div className="prose dark:prose-invert max-w-none">
-249
View File
@@ -1,249 +0,0 @@
import { useState } from 'react';
import { useSearchParams } from 'react-router-dom';
import { useQuery } from '@tanstack/react-query';
import { analysisApi } from '../api/client';
import type { AnalysisHistoryItem, AnalysisDiff } from '../api/client';
import { History, ArrowRight, Plus, Minus, AlertCircle, Search } from 'lucide-react';
export function HistoryDiff() {
const [searchParams, setSearchParams] = useSearchParams();
const [companyInput, setCompanyInput] = useState(searchParams.get('company') || '');
const company = searchParams.get('company') || '';
const fromId = searchParams.get('from');
const toId = searchParams.get('to');
// Fetch history when a company is selected
const historyQuery = useQuery({
queryKey: ['history', company],
queryFn: () => analysisApi.getCompanyHistory(company),
enabled: !!company,
});
// Fetch diff when both IDs are selected
const diffQuery = useQuery<AnalysisDiff>({
queryKey: ['diff', company, fromId, toId],
queryFn: () => analysisApi.diffAnalyses(company, Number(fromId), Number(toId)),
enabled: !!company && !!fromId && !!toId,
});
const handleSearch = (e: React.FormEvent) => {
e.preventDefault();
const name = companyInput.trim();
if (name) {
setSearchParams({ company: name });
}
};
const handleSelectRuns = (from: number, to: number) => {
setSearchParams({ company, from: String(from), to: String(to) });
};
const history: AnalysisHistoryItem[] = historyQuery.data || [];
return (
<div className="space-y-6">
{/* Header */}
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Historical Analysis Diff
</h2>
<p className="text-text-secondary">
Compare analysis runs for the same company to see what changed between them.
</p>
</div>
{/* Company Search */}
<form onSubmit={handleSearch} className="flex gap-4">
<div className="flex-1 relative">
<Search className="absolute left-4 top-1/2 -translate-y-1/2 text-text-secondary" size={18} />
<input
type="text"
value={companyInput}
onChange={(e) => setCompanyInput(e.target.value)}
placeholder="Enter company name (e.g., nvidia)"
className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all"
/>
</div>
<button
type="submit"
disabled={!companyInput.trim()}
className="bg-gradient-to-r from-primary to-primary-dark text-white font-semibold py-3 px-6 rounded-xl hover:shadow-lg hover:shadow-primary/30 transition-all disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
<History size={18} />
Load History
</button>
</form>
{/* History list */}
{company && historyQuery.isLoading && (
<div className="text-text-secondary animate-pulse">Loading analysis history...</div>
)}
{company && historyQuery.isError && (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load history. Check the company name and try again.</span>
</div>
)}
{company && history.length === 0 && !historyQuery.isLoading && (
<div className="text-text-secondary">No analysis history found for "{company}".</div>
)}
{history.length >= 2 && (
<div className="bg-bg-card/60 backdrop-blur-lg border border-primary/15 rounded-2xl p-6">
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-4">
Select Two Runs to Compare
</h3>
<div className="space-y-2">
{history.map((item, idx) => {
const next = history[idx + 1];
if (!next) return null;
const isSelected =
fromId === String(next.id) && toId === String(item.id);
return (
<button
key={item.id}
onClick={() => handleSelectRuns(next.id, item.id)}
className={`w-full text-left flex items-center gap-3 px-4 py-3 rounded-xl border transition-all ${
isSelected
? 'border-primary bg-primary/10'
: 'border-primary/15 hover:border-primary/40 hover:bg-primary/5'
}`}
>
<span className="text-sm text-text-secondary font-mono">
#{next.id}
</span>
<span className="text-xs text-text-secondary">
{new Date(next.timestamp).toLocaleString()}
</span>
<ArrowRight size={14} className="text-primary" />
<span className="text-sm text-text-secondary font-mono">
#{item.id}
</span>
<span className="text-xs text-text-secondary">
{new Date(item.timestamp).toLocaleString()}
</span>
{item.model && (
<span className="ml-auto text-xs bg-primary/20 text-primary px-2 py-0.5 rounded">
{item.model}
</span>
)}
</button>
);
})}
</div>
</div>
)}
{/* Diff Results */}
{diffQuery.isLoading && (
<div className="text-text-secondary animate-pulse">Computing diff...</div>
)}
{diffQuery.isError && (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to compute diff. One or both analysis IDs may not exist.</span>
</div>
)}
{diffQuery.data && <DiffView diff={diffQuery.data} />}
</div>
);
}
function DiffView({ diff }: { diff: AnalysisDiff }) {
return (
<div className="bg-bg-card/60 backdrop-blur-lg border border-primary/15 rounded-2xl p-6 space-y-6">
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2">
Diff: #{diff.from_id} &rarr; #{diff.to_id}
</h3>
{/* Summary */}
<div className="bg-primary/5 border border-primary/20 rounded-xl p-4">
<div className="text-sm font-medium text-text-primary">{diff.summary}</div>
<div className="flex items-center gap-4 mt-2 text-xs text-text-secondary">
<span>{new Date(diff.from_timestamp).toLocaleString()}</span>
<ArrowRight size={12} />
<span>{new Date(diff.to_timestamp).toLocaleString()}</span>
</div>
</div>
{/* Patent count delta */}
<div className="flex items-center gap-3">
<span className="text-sm text-text-secondary">Patent mention delta:</span>
<span
className={`text-lg font-bold ${
diff.patent_count_delta > 0
? 'text-success'
: diff.patent_count_delta < 0
? 'text-error'
: 'text-text-secondary'
}`}
>
{diff.patent_count_delta > 0 ? '+' : ''}
{diff.patent_count_delta}
</span>
</div>
{/* Added patents */}
{diff.added_patents.length > 0 && (
<div>
<h4 className="text-sm font-semibold text-success flex items-center gap-1 mb-2">
<Plus size={14} />
New Patents ({diff.added_patents.length})
</h4>
<div className="flex flex-wrap gap-2">
{diff.added_patents.map((p) => (
<span
key={p}
className="text-xs bg-success/10 border border-success/20 text-success px-2 py-1 rounded font-mono"
>
{p}
</span>
))}
</div>
</div>
)}
{/* Removed patents */}
{diff.removed_patents.length > 0 && (
<div>
<h4 className="text-sm font-semibold text-error flex items-center gap-1 mb-2">
<Minus size={14} />
Removed Patents ({diff.removed_patents.length})
</h4>
<div className="flex flex-wrap gap-2">
{diff.removed_patents.map((p) => (
<span
key={p}
className="text-xs bg-error/10 border border-error/20 text-error px-2 py-1 rounded font-mono"
>
{p}
</span>
))}
</div>
</div>
)}
{/* Changed fields */}
{Object.keys(diff.changed_fields).length > 0 && (
<div>
<h4 className="text-sm font-semibold text-text-primary mb-2">Changed Fields</h4>
<div className="space-y-1">
{Object.entries(diff.changed_fields).map(([field, vals]) => (
<div key={field} className="flex items-center gap-2 text-sm">
<span className="text-text-secondary font-mono">{field}:</span>
<span className="text-error line-through">{vals.from || 'null'}</span>
<ArrowRight size={12} className="text-text-secondary" />
<span className="text-success">{vals.to || 'null'}</span>
</div>
))}
</div>
</div>
)}
</div>
);
}
+132
View File
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Migration: add owner_id columns and backfill existing rows.
This script adds an ``owner_id`` column (FK to ``users``) to the
``llm_messages``, ``jobs``, and ``tracked_companies`` tables, then
backfills all existing rows with ``owner_id = 1`` (the default admin user).
It also replaces the old global UNIQUE constraint on
``tracked_companies.company_name`` with a per-owner unique index so that
different users can independently track the same company.
Usage:
python scripts/migrate_add_owner_id.py
The script is idempotent running it multiple times is safe.
"""
import os
import sys
import psycopg2
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://postgres:postgres@localhost:5432/sparc",
)
DEFAULT_OWNER_ID = 1
def run_migration():
"""Execute the migration."""
conn = psycopg2.connect(DATABASE_URL)
conn.autocommit = False
try:
with conn.cursor() as cur:
# ---------- 1. Add owner_id columns if missing ----------
cur.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'llm_messages' AND column_name = 'owner_id'
) THEN
ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'jobs' AND column_name = 'owner_id'
) THEN
ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'tracked_companies' AND column_name = 'owner_id'
) THEN
ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id);
END IF;
END $$;
""")
# ---------- 2. Backfill owner_id = DEFAULT_OWNER_ID ----------
cur.execute(
"UPDATE llm_messages SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
messages_updated = cur.rowcount
print(f" llm_messages: backfilled {messages_updated} rows")
cur.execute(
"UPDATE jobs SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
jobs_updated = cur.rowcount
print(f" jobs: backfilled {jobs_updated} rows")
cur.execute(
"UPDATE tracked_companies SET owner_id = %s WHERE owner_id IS NULL",
(DEFAULT_OWNER_ID,),
)
tracked_updated = cur.rowcount
print(f" tracked_companies: backfilled {tracked_updated} rows")
# ---------- 3. Create indexes ----------
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_messages_owner
ON llm_messages(owner_id)
""")
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_owner
ON jobs(owner_id)
""")
cur.execute("""
CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner
ON tracked_companies(owner_id)
""")
# ---------- 4. Replace unique constraint on tracked_companies ----------
cur.execute("""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_constraint
WHERE conname = 'tracked_companies_company_name_key'
) THEN
ALTER TABLE tracked_companies
DROP CONSTRAINT tracked_companies_company_name_key;
END IF;
END $$;
""")
cur.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner
ON tracked_companies(LOWER(company_name), owner_id)
""")
conn.commit()
print("Migration completed successfully.")
except Exception:
conn.rollback()
print("Migration FAILED — rolled back.", file=sys.stderr)
raise
finally:
conn.close()
if __name__ == "__main__":
print(f"Running owner_id migration against {DATABASE_URL.split('@')[-1]} ...")
run_migration()
-244
View File
@@ -1,244 +0,0 @@
"""Tests for historical analysis diff endpoint."""
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import AnalysisDiffResponse, _compute_analysis_diff, _extract_patent_ids, app
from SPARC.auth import UserResponse, get_current_user
# ---------- helpers ----------
def _mock_user():
"""Return a fake authenticated user for dependency override."""
return UserResponse(
id=1,
email="test@example.com",
role="user",
created_at=datetime(2025, 1, 1),
)
@pytest.fixture
def auth_client():
"""TestClient with auth dependency overridden."""
app.dependency_overrides[get_current_user] = _mock_user
client = TestClient(app, raise_server_exceptions=False)
yield client
app.dependency_overrides.clear()
# ---------- unit tests for helpers ----------
class TestExtractPatentIds:
"""Test _extract_patent_ids utility."""
def test_extracts_standard_ids(self):
text = "Patent US-12345678-B2 covers the device. Also see US-9876543-A1."
ids = _extract_patent_ids(text)
assert "US-12345678-B2" in ids
assert "US-9876543-A1" in ids
def test_empty_text(self):
assert _extract_patent_ids("") == set()
assert _extract_patent_ids(None) == set() # type: ignore[arg-type]
class TestComputeAnalysisDiff:
"""Test _compute_analysis_diff logic."""
def test_identical_analyses(self):
rec = {
"id": 1,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 is notable.",
"timestamp": datetime(2025, 5, 1),
}
diff = _compute_analysis_diff(rec, dict(rec, id=2, timestamp=datetime(2025, 5, 2)))
assert diff.patent_count_delta == 0
assert diff.added_patents == []
assert diff.removed_patents == []
def test_added_and_removed_patents(self):
from_rec = {
"id": 1,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 and US-11111111-A1.",
"timestamp": datetime(2025, 5, 1),
}
to_rec = {
"id": 2,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 and US-99999999-B1.",
"timestamp": datetime(2025, 5, 2),
}
diff = _compute_analysis_diff(from_rec, to_rec)
assert "US-99999999-B1" in diff.added_patents
assert "US-11111111-A1" in diff.removed_patents
assert diff.patent_count_delta == 0 # one added, one removed
def test_model_change_detected(self):
from_rec = {
"id": 1,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "",
"timestamp": datetime(2025, 5, 1),
}
to_rec = {
"id": 2,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "anthropic/claude-3.5-sonnet",
"response": "",
"timestamp": datetime(2025, 5, 2),
}
diff = _compute_analysis_diff(from_rec, to_rec)
assert "model" in diff.changed_fields
assert diff.changed_fields["model"]["from"] == "openai/gpt-4o"
assert diff.changed_fields["model"]["to"] == "anthropic/claude-3.5-sonnet"
# ---------- API endpoint tests ----------
class TestDiffEndpoint:
"""Test GET /analyze/{company_name}/diff."""
@patch("SPARC.api._get_job_db")
def test_happy_path(self, mock_get_db, auth_client):
"""Diff returns structured response when both IDs exist."""
db = MagicMock()
mock_get_db.return_value = db
from_rec = {
"id": 10,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 found.",
"timestamp": datetime(2025, 5, 1),
}
to_rec = {
"id": 20,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 and US-99999999-A1 found.",
"timestamp": datetime(2025, 5, 10),
}
db.get_analysis_by_id.side_effect = lambda aid: from_rec if aid == 10 else to_rec
response = auth_client.get("/analyze/nvidia/diff?from=10&to=20")
assert response.status_code == 200
data = response.json()
assert data["company_name"] == "nvidia"
assert data["from_id"] == 10
assert data["to_id"] == 20
assert "US-99999999-A1" in data["added_patents"]
assert data["patent_count_delta"] == 1
@patch("SPARC.api._get_job_db")
def test_from_id_not_found(self, mock_get_db, auth_client):
"""Returns 404 when 'from' analysis ID doesn't exist."""
db = MagicMock()
mock_get_db.return_value = db
db.get_analysis_by_id.return_value = None
response = auth_client.get("/analyze/nvidia/diff?from=999&to=1000")
assert response.status_code == 404
assert "999" in response.json()["detail"]
@patch("SPARC.api._get_job_db")
def test_to_id_not_found(self, mock_get_db, auth_client):
"""Returns 404 when 'to' analysis ID doesn't exist."""
db = MagicMock()
mock_get_db.return_value = db
from_rec = {
"id": 10,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "",
"timestamp": datetime(2025, 5, 1),
}
db.get_analysis_by_id.side_effect = lambda aid: from_rec if aid == 10 else None
response = auth_client.get("/analyze/nvidia/diff?from=10&to=999")
assert response.status_code == 404
assert "999" in response.json()["detail"]
@patch("SPARC.api._get_job_db")
def test_company_mismatch(self, mock_get_db, auth_client):
"""Returns 404 when analysis belongs to a different company."""
db = MagicMock()
mock_get_db.return_value = db
rec = {
"id": 10,
"company_name": "intel",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "",
"timestamp": datetime(2025, 5, 1),
}
db.get_analysis_by_id.return_value = rec
response = auth_client.get("/analyze/nvidia/diff?from=10&to=20")
assert response.status_code == 404
class TestHistoryEndpoint:
"""Test GET /analyze/{company_name}/history."""
@patch("SPARC.api._get_job_db")
def test_returns_history_list(self, mock_get_db, auth_client):
"""History endpoint returns list of past analysis runs."""
db = MagicMock()
mock_get_db.return_value = db
db.list_company_analyses.return_value = [
{
"id": 20,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "...",
"timestamp": datetime(2025, 5, 10),
},
{
"id": 10,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "...",
"timestamp": datetime(2025, 5, 1),
},
]
response = auth_client.get("/analyze/nvidia/history")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["id"] == 20
assert data[1]["id"] == 10
@patch("SPARC.api._get_job_db")
def test_empty_history(self, mock_get_db, auth_client):
"""History endpoint returns empty list when no analyses exist."""
db = MagicMock()
mock_get_db.return_value = db
db.list_company_analyses.return_value = []
response = auth_client.get("/analyze/nvidia/history")
assert response.status_code == 200
assert response.json() == []
+74 -18
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints."""
from datetime import datetime
from unittest.mock import Mock
from datetime import datetime, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -16,6 +17,22 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer."""
@@ -24,6 +41,12 @@ def mock_analyzer(mocker):
return mock
def _auth_header(user_id=1, email="user@test.com", role="user"):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_id, email, role)
return {"Authorization": f"Bearer {token}"}
class TestHealthEndpoint:
"""Test health check endpoint."""
@@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia")
response = client.get("/analyze/nvidia", headers=_auth_header())
assert response.status_code == 200
data = response.json()
@@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/unknown")
response = client.get("/analyze/unknown", headers=_auth_header())
assert response.status_code == 200
data = response.json()
@@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "amd"], "max_workers": 2},
headers=_auth_header(),
)
assert response.status_code == 200
@@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint:
def test_batch_analysis_validation(self, client):
"""Test batch analysis request validation."""
# Empty companies list
response = client.post("/analyze/batch", json={"companies": []})
response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header())
assert response.status_code == 422
# Too many companies
response = client.post(
"/analyze/batch",
json={"companies": [f"company{i}" for i in range(25)]},
headers=_auth_header(),
)
assert response.status_code == 422
@@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "max_workers": 10},
headers=_auth_header(),
)
assert response.status_code == 422
@@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint:
class TestAsyncBatchEndpoint:
"""Test async batch analysis endpoint."""
def test_async_batch_creates_job(self, client, mock_analyzer):
"""Test async endpoint creates a job."""
@patch("SPARC.api._get_job_db")
def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer):
"""Test async endpoint creates a job with owner_id."""
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j1",
"status": "pending",
"progress": 0,
"total_companies": 2,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": 1,
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia", "amd"]},
headers=_auth_header(),
)
assert response.status_code == 200
@@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint:
assert data["status"] == "pending"
assert data["total_companies"] == 2
assert data["progress"] == 0
# Verify owner_id was passed
job_db.create_job.assert_called_once()
assert job_db.create_job.call_args.kwargs.get("owner_id") == 1
class TestJobEndpoints:
"""Test job management endpoints."""
def test_get_job_not_found(self, client):
@patch("SPARC.api._get_job_db")
def test_get_job_not_found(self, mock_get_db, client):
"""Test getting nonexistent job."""
response = client.get("/jobs/nonexistent")
job_db = MagicMock()
job_db.get_job.return_value = None
mock_get_db.return_value = job_db
response = client.get("/jobs/nonexistent", headers=_auth_header())
assert response.status_code == 404
def test_list_jobs(self, client, mocker):
@patch("SPARC.api._get_job_db")
def test_list_jobs(self, mock_get_db, client):
"""Test listing jobs."""
# Clear existing jobs
mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs")
response = client.get("/jobs", headers=_auth_header())
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_jobs_with_filter(self, client, mocker):
@patch("SPARC.api._get_job_db")
def test_list_jobs_with_filter(self, mock_get_db, client):
"""Test listing jobs with status filter."""
response = client.get("/jobs?status=completed")
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs?status=completed", headers=_auth_header())
assert response.status_code == 200
@@ -189,7 +244,7 @@ class TestModelValidation:
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with unsupported model returns 400."""
response = client.get("/analyze/nvidia?model=fake/nonexistent-model")
response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header())
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
@@ -205,7 +260,7 @@ class TestModelValidation:
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet")
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header())
assert response.status_code == 200
def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
@@ -213,6 +268,7 @@ class TestModelValidation:
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
headers=_auth_header(),
)
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
-319
View File
@@ -1,319 +0,0 @@
"""Tests for user-level API key generation, listing, revocation, and authentication.
Covers all acceptance criteria from issue #1673:
1. Users can create API keys (POST /auth/apikeys)
2. Users can list their active key IDs (GET /auth/apikeys)
3. Users can revoke keys (DELETE /auth/apikeys/{key_id})
4. API requests authenticated with a valid API key work on protected endpoints
5. Revoked keys are immediately rejected
6. Plaintext key is shown only at creation time
All tests use mocked DB fixtures and require no live database.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import (
create_access_token,
generate_api_key,
hash_api_key,
verify_api_key,
)
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user():
return {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _auth_header(user_dict):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db(monkeypatch):
"""Mock the database client used by auth and api endpoints."""
db = MagicMock()
db.get_user_count.return_value = 0
db.get_user_by_id.return_value = None
db.get_user_by_email.return_value = None
db.authenticate_user.return_value = None
db.create_user.return_value = None
db.get_all_users.return_value = []
db.update_user_role.return_value = None
db.delete_user.return_value = False
db.create_api_key.return_value = None
db.list_api_keys.return_value = []
db.delete_api_key.return_value = False
db.get_all_api_key_hashes.return_value = []
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
class TestCreateApiKey:
"""POST /auth/apikeys"""
def test_create_key_returns_plaintext_and_id(self, client, mock_db):
"""Creating a key returns the plaintext key and metadata."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.create_api_key.return_value = {
"id": 42,
"user_id": user["id"],
"label": "my-ci-key",
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/apikeys",
json={"label": "my-ci-key"},
headers=_auth_header(user),
)
assert response.status_code == 200
data = response.json()
assert data["id"] == 42
assert len(data["key"]) == 64 # 32 bytes hex = 64 chars
assert data["label"] == "my-ci-key"
assert "created_at" in data
# Verify the hash passed to DB is valid for the returned key
call_args = mock_db.create_api_key.call_args
stored_hash = call_args.kwargs.get("key_hash") or call_args[1].get("key_hash") or call_args[0][1]
assert verify_api_key(data["key"], stored_hash)
def test_create_key_without_label(self, client, mock_db):
"""Creating a key without a label should work."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.create_api_key.return_value = {
"id": 1,
"user_id": user["id"],
"label": None,
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/apikeys",
headers=_auth_header(user),
)
assert response.status_code == 200
assert response.json()["label"] is None
def test_create_key_requires_auth(self, client):
"""Creating a key without auth should fail."""
response = client.post("/auth/apikeys")
assert response.status_code == 401
class TestListApiKeys:
"""GET /auth/apikeys"""
def test_list_keys_returns_metadata_only(self, client, mock_db):
"""Listing keys should return IDs and labels, not secrets."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.list_api_keys.return_value = [
{"id": 1, "label": "key-1", "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc)},
{"id": 2, "label": None, "created_at": datetime(2025, 2, 1, tzinfo=timezone.utc)},
]
response = client.get("/auth/apikeys", headers=_auth_header(user))
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["id"] == 1
assert data[0]["label"] == "key-1"
# Ensure no secret key is exposed
for item in data:
assert "key" not in item
assert "key_hash" not in item
def test_list_keys_empty(self, client, mock_db):
"""User with no keys gets an empty list."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.list_api_keys.return_value = []
response = client.get("/auth/apikeys", headers=_auth_header(user))
assert response.status_code == 200
assert response.json() == []
class TestRevokeApiKey:
"""DELETE /auth/apikeys/{key_id}"""
def test_revoke_existing_key(self, client, mock_db):
"""Revoking an owned key should succeed."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.delete_api_key.return_value = True
response = client.delete("/auth/apikeys/42", headers=_auth_header(user))
assert response.status_code == 200
assert "revoked" in response.json()["message"].lower()
mock_db.delete_api_key.assert_called_once_with(42, user["id"])
def test_revoke_nonexistent_key_returns_404(self, client, mock_db):
"""Revoking a key that doesn't exist (or isn't owned) returns 404."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.delete_api_key.return_value = False
response = client.delete("/auth/apikeys/999", headers=_auth_header(user))
assert response.status_code == 404
class TestApiKeyAuthentication:
"""Using X-API-Key header on protected endpoints."""
def test_valid_api_key_accesses_protected_endpoint(self, client, mock_db):
"""A valid API key should authenticate and access /auth/me."""
user = _make_user()
plaintext = generate_api_key()
hashed = hash_api_key(plaintext)
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": hashed, "user_id": user["id"]},
]
mock_db.get_user_by_id.return_value = user
response = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert response.status_code == 200
data = response.json()
assert data["email"] == user["email"]
assert data["id"] == user["id"]
def test_invalid_api_key_returns_401(self, client, mock_db):
"""An invalid API key should return 401."""
mock_db.get_all_api_key_hashes.return_value = []
response = client.get("/auth/me", headers={"X-API-Key": "bad-key"})
assert response.status_code == 401
assert "invalid api key" in response.json()["detail"].lower()
def test_revoked_key_returns_401(self, client, mock_db):
"""After revocation, using the key should return 401."""
# Simulate revoked key: no matching hashes in DB
mock_db.get_all_api_key_hashes.return_value = []
response = client.get("/auth/me", headers={"X-API-Key": "a" * 64})
assert response.status_code == 401
def test_api_key_for_deleted_user_returns_401(self, client, mock_db):
"""An API key whose user no longer exists should return 401."""
plaintext = generate_api_key()
hashed = hash_api_key(plaintext)
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": hashed, "user_id": 999},
]
mock_db.get_user_by_id.return_value = None # user deleted
response = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert response.status_code == 401
def test_no_auth_at_all_returns_401(self, client, mock_db):
"""No auth header at all should return 401."""
response = client.get("/auth/me")
assert response.status_code == 401
class TestApiKeyFullFlow:
"""End-to-end flow: create key, use it, revoke it, try again."""
def test_create_use_revoke_flow(self, client, mock_db):
"""Simulate full lifecycle of an API key."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
# Step 1: Create key
mock_db.create_api_key.return_value = {
"id": 10,
"user_id": user["id"],
"label": "test",
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
create_resp = client.post(
"/auth/apikeys",
json={"label": "test"},
headers=_auth_header(user),
)
assert create_resp.status_code == 200
plaintext = create_resp.json()["key"]
# Capture the hash that was stored
call_args = mock_db.create_api_key.call_args
stored_hash = call_args.kwargs.get("key_hash") or call_args[0][1]
# Step 2: Use key on protected endpoint
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": stored_hash, "user_id": user["id"]},
]
use_resp = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert use_resp.status_code == 200
assert use_resp.json()["email"] == user["email"]
# Step 3: Revoke key
mock_db.delete_api_key.return_value = True
revoke_resp = client.delete("/auth/apikeys/10", headers=_auth_header(user))
assert revoke_resp.status_code == 200
# Step 4: Try using revoked key
mock_db.get_all_api_key_hashes.return_value = [] # key removed from DB
rejected_resp = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert rejected_resp.status_code == 401
class TestApiKeyHelpers:
"""Unit tests for key generation and hashing helpers."""
def test_generate_api_key_length(self):
"""Generated key should be 64 hex characters (32 bytes)."""
key = generate_api_key()
assert len(key) == 64
# Should be valid hex
int(key, 16)
def test_generate_api_key_uniqueness(self):
"""Two generated keys should be different."""
k1 = generate_api_key()
k2 = generate_api_key()
assert k1 != k2
def test_hash_and_verify(self):
"""hash_api_key and verify_api_key should round-trip correctly."""
key = generate_api_key()
hashed = hash_api_key(key)
assert verify_api_key(key, hashed)
assert not verify_api_key("wrong-key", hashed)
+1
View File
@@ -5,6 +5,7 @@ Covers issue #1655:
- GET /export/{company_name}/pdf (PDF export)
All tests mock the database layer and use JWT auth fixtures from test_auth patterns.
Export queries are now scoped to the current user's owner_id.
"""
from datetime import datetime, timezone
+281
View File
@@ -0,0 +1,281 @@
"""Cross-tenant isolation tests for multi-tenant support.
Verifies that:
- User A cannot read, update, or delete User B's analyses, tracked companies, or jobs
- Admin users can access all data via admin endpoints
- owner_id is correctly set on new resources
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user(user_id, email, role="user"):
return {
"id": user_id,
"email": email,
"role": role,
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
USER_A = _make_user(10, "alice@test.com")
USER_B = _make_user(20, "bob@test.com")
ADMIN = _make_user(1, "admin@test.com", role="admin")
def _header_for(user):
token = create_access_token(user["id"], user["email"], user["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db():
"""Mock DB returning the correct user based on user_id."""
db = MagicMock()
def _get_user_by_id(uid):
for u in [USER_A, USER_B, ADMIN]:
if u["id"] == uid:
return u
return None
db.get_user_by_id.side_effect = _get_user_by_id
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
# ==================== Tracked Companies Isolation ====================
class TestTrackedCompanyIsolation:
"""User A's tracked companies are invisible to User B."""
def test_user_a_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User A's companies."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": USER_A["id"]},
]
response = client.get("/tracked", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"])
def test_user_b_list_scoped_to_own(self, client, mock_db):
"""GET /tracked returns only User B's companies."""
mock_db.list_tracked_companies.return_value = []
response = client.get("/tracked", headers=_header_for(USER_B))
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"])
def test_user_a_add_sets_owner(self, client, mock_db):
"""POST /tracked sets owner_id to User A."""
mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10}
response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"])
def test_user_b_cannot_remove_user_a_company(self, client, mock_db):
"""DELETE /tracked/{name} filters by owner, so B can't remove A's company."""
mock_db.remove_tracked_company.return_value = False # not found for B
response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B))
assert response.status_code == 404
mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"])
# ==================== Job Isolation ====================
class TestJobIsolation:
"""User A's jobs are invisible to User B."""
def test_user_a_get_own_job(self, client, mock_db):
"""GET /jobs/{id} scoped to User A returns the job."""
mock_db.get_job.return_value = None # mock via _get_job_db
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = {
"job_id": "j1",
"status": "completed",
"progress": 100,
"total_companies": 1,
"completed_companies": 1,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_A))
assert response.status_code == 200
job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"])
def test_user_b_cannot_see_user_a_job(self, client, mock_db):
"""GET /jobs/{id} returns 404 when User B tries to access User A's job."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.get_job.return_value = None # not found for B's owner_id
mock_get_db.return_value = job_db
response = client.get("/jobs/j1", headers=_header_for(USER_B))
assert response.status_code == 404
job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"])
def test_list_jobs_scoped_to_user(self, client, mock_db):
"""GET /jobs filters by owner_id."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/jobs", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
def test_async_job_created_with_owner(self, client, mock_db):
"""POST /analyze/batch/async creates job with current user's owner_id."""
mock_analyzer = MagicMock()
with patch("SPARC.api._analyzer", mock_analyzer), \
patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.create_job.return_value = {
"job_id": "j2",
"status": "pending",
"progress": 0,
"total_companies": 1,
"completed_companies": 0,
"result_json": None,
"error": None,
"owner_id": USER_A["id"],
}
mock_get_db.return_value = job_db
response = client.post(
"/analyze/batch/async",
json={"companies": ["nvidia"]},
headers=_header_for(USER_A),
)
assert response.status_code == 200
create_kwargs = job_db.create_job.call_args
assert create_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Analysis Listing Isolation ====================
class TestAnalysisListIsolation:
"""GET /analyze/batch scoped to current user."""
def test_list_analyses_scoped_to_user(self, client, mock_db):
"""GET /analyze/batch passes owner_id to db.list_analyses."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/analyze/batch", headers=_header_for(USER_A))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
assert call_kwargs.kwargs.get("owner_id") == USER_A["id"]
# ==================== Admin Cross-Tenant Access ====================
class TestAdminCrossTenantAccess:
"""Admin endpoints return data from all tenants (no owner_id filter)."""
def test_admin_list_tracked_all_tenants(self, client, mock_db):
"""GET /admin/tracked returns all companies (no owner_id filter)."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "AliceCo", "owner_id": 10},
{"company_name": "BobCo", "owner_id": 20},
]
response = client.get("/admin/tracked", headers=_header_for(ADMIN))
assert response.status_code == 200
# Should be called without owner_id filter
mock_db.list_tracked_companies.assert_called_with()
def test_admin_list_analyses_all_tenants(self, client, mock_db):
"""GET /admin/analyses returns all analyses (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_analyses.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/analyses", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_analyses.call_args
# No owner_id should be passed
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_list_jobs_all_tenants(self, client, mock_db):
"""GET /admin/jobs returns all jobs (no owner_id filter)."""
with patch("SPARC.api._get_job_db") as mock_get_db:
job_db = MagicMock()
job_db.list_jobs.return_value = []
mock_get_db.return_value = job_db
response = client.get("/admin/jobs", headers=_header_for(ADMIN))
assert response.status_code == 200
call_kwargs = job_db.list_jobs.call_args
assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None
def test_admin_remove_tracked_any_owner(self, client, mock_db):
"""DELETE /admin/tracked/{name} removes without owner filter."""
mock_db.remove_tracked_company.return_value = True
response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN))
assert response.status_code == 200
# Called without owner_id
mock_db.remove_tracked_company.assert_called_with("SomeCo")
def test_regular_user_cannot_access_admin_analyses(self, client, mock_db):
"""Regular user gets 403 on /admin/analyses."""
response = client.get("/admin/analyses", headers=_header_for(USER_A))
assert response.status_code == 403
def test_regular_user_cannot_access_admin_jobs(self, client, mock_db):
"""Regular user gets 403 on /admin/jobs."""
response = client.get("/admin/jobs", headers=_header_for(USER_A))
assert response.status_code == 403
# ==================== Analytics Isolation ====================
class TestAnalyticsIsolation:
"""GET /analytics scoped to current user."""
def test_analytics_scoped_to_user(self, client, mock_db):
"""GET /analytics passes owner_id to db.get_analytics."""
mock_db.get_analytics.return_value = {
"total_messages": 5,
"by_company": [],
"by_type": [],
"period_days": 30,
}
response = client.get("/analytics", headers=_header_for(USER_A))
assert response.status_code == 200
mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"])
+35 -13
View File
@@ -1,12 +1,13 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
@@ -15,6 +16,27 @@ def client():
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by auth endpoints."""
db = MagicMock()
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
@@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination:
]
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10")
response = client.get("/analyze/batch?limit=10", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 2
@@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = rows
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3")
response = client.get("/analyze/batch?limit=3", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 3
@@ -84,7 +106,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42")
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header())
db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
@@ -97,19 +119,19 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch")
client.get("/analyze/batch", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/analyze/batch?limit=201")
response = client.get("/analyze/batch?limit=201", headers=_auth_header())
assert response.status_code == 422
def test_limit_zero_rejected(self, client):
"""Limit < 1 should be rejected with 422."""
response = client.get("/analyze/batch?limit=0")
response = client.get("/analyze/batch?limit=0", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -119,7 +141,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel")
client.get("/analyze/batch?company_name=intel", headers=_auth_header())
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else [])
@@ -131,7 +153,7 @@ class TestAnalyzeBatchGetPagination:
db.list_analyses.return_value = []
mock_get_db.return_value = db
response = client.get("/analyze/batch")
response = client.get("/analyze/batch", headers=_auth_header())
assert response.status_code == 200
data = response.json()
assert data["items"] == []
@@ -148,14 +170,14 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
client.get("/jobs")
client.get("/jobs", headers=_auth_header())
call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/jobs?limit=201")
response = client.get("/jobs?limit=201", headers=_auth_header())
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
@@ -165,5 +187,5 @@ class TestJobsPaginationDefaults:
db.list_jobs.return_value = []
mock_get_db.return_value = db
response = client.get("/jobs?limit=200")
response = client.get("/jobs?limit=200", headers=_auth_header())
assert response.status_code == 200
+71 -10
View File
@@ -1,17 +1,18 @@
"""Tests for tracked company admin endpoints and scheduler integration.
"""Tests for tracked company endpoints and scheduler integration.
Covers issue #1656:
- GET /admin/tracked (list tracked companies)
- POST /admin/tracked (add a tracked company)
- DELETE /admin/tracked/{company_name} (remove a tracked company)
Covers:
- GET /tracked (user-scoped list)
- POST /tracked (user-scoped add)
- DELETE /tracked/{company_name} (user-scoped remove)
- GET /admin/tracked (admin: all companies)
- POST /admin/tracked (admin: add)
- DELETE /admin/tracked/{company_name} (admin: remove any)
- GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration
All tests mock the database layer and use JWT auth fixtures.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
@@ -125,7 +126,7 @@ class TestAddTrackedCompany:
assert response.status_code == 200
data = response.json()
assert data["company_name"] == "Intel"
mock_db.add_tracked_company.assert_called_once_with("Intel")
mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1)
def test_add_duplicate_returns_409(self, client, mock_db):
"""Adding an already-tracked company returns 409."""
@@ -141,7 +142,7 @@ class TestAddTrackedCompany:
assert "already tracked" in response.json()["detail"].lower()
def test_add_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot add tracked companies."""
"""Regular user cannot add tracked companies via admin endpoint."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
@@ -215,6 +216,66 @@ class TestRemoveTrackedCompany:
assert response.status_code == 403
# ---------- User-scoped tracked companies ----------
class TestUserScopedTrackedCompanies:
"""Tests for /tracked user-scoped endpoints."""
def test_user_list_tracked(self, client, mock_db):
"""Regular user can list their own tracked companies."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.list_tracked_companies.return_value = [
{"company_name": "AMD", "owner_id": 2},
]
response = client.get("/tracked", headers=_user_header())
assert response.status_code == 200
mock_db.list_tracked_companies.assert_called_with(owner_id=2)
def test_user_add_tracked(self, client, mock_db):
"""Regular user can add a company to their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.add_tracked_company.return_value = {
"company_name": "Intel",
"owner_id": 2,
}
response = client.post(
"/tracked",
json={"company_name": "Intel"},
headers=_user_header(),
)
assert response.status_code == 200
mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2)
def test_user_remove_tracked(self, client, mock_db):
"""Regular user can remove a company from their own tracked list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
mock_db.remove_tracked_company.return_value = True
response = client.delete("/tracked/Intel", headers=_user_header())
assert response.status_code == 200
mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2)
# ---------- GET /admin/alerts ----------
class TestListAlerts: