From 3d8922366ee12e5870a56b35e21d71afbcc5cc08 Mon Sep 17 00:00:00 2001 From: agent-company Date: Tue, 19 May 2026 15:18:34 +0000 Subject: [PATCH 1/3] Add user-level API key generation for programmatic access - Add api_keys table (id, user_id, key_hash, label, created_at) to schema - Add POST /auth/apikeys to generate 32-byte hex API keys (bcrypt-hashed) - Add GET /auth/apikeys to list active key metadata (no secrets) - Add DELETE /auth/apikeys/{key_id} to revoke keys - Extend get_current_user to accept either JWT Bearer or X-API-Key header - Plaintext key returned only at creation time - 16 new tests covering creation, listing, revocation, auth, and full flow Closes leeworks-agents/SPARC#1673 Co-Authored-By: Claude Opus 4.6 --- SPARC/api.py | 304 ++++++++++++++++++++++++++++++++++++++- SPARC/auth.py | 107 ++++++++++++-- SPARC/database.py | 156 ++++++++++++++++++++ tests/test_api_keys.py | 319 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 873 insertions(+), 13 deletions(-) create mode 100644 tests/test_api_keys.py diff --git a/SPARC/api.py b/SPARC/api.py index 1b29d38..ea0674a 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios. from __future__ import annotations +from collections import deque from contextlib import asynccontextmanager -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Annotated, List if TYPE_CHECKING: @@ -29,9 +30,11 @@ from SPARC.auth import ( close_db_client, create_tokens, decode_token, + generate_api_key, get_current_admin, get_current_user, get_db_client, + hash_api_key, init_db_client, ) from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult @@ -139,6 +142,31 @@ class HealthResponse(BaseModel): timestamp: datetime +# Historical diff models +class AnalysisDiffResponse(BaseModel): + """Response model for diffing two analysis runs of the same company.""" + + company_name: str + from_id: int + to_id: int + from_timestamp: datetime + to_timestamp: datetime + patent_count_delta: int + added_patents: list[str] + removed_patents: list[str] + changed_fields: dict[str, dict] + summary: str + + +class CompanyAnalysisHistoryItem(BaseModel): + """A summary item from a company's analysis history.""" + + id: int + analysis_type: str | None = None + model: str | None = None + timestamp: datetime + + # Auth request/response models class RegisterRequest(BaseModel): """User registration request.""" @@ -248,6 +276,9 @@ app.state.limiter = limiter # In-memory rate limit statistics _rate_limit_stats: dict[str, dict] = {} +# Time-series log of rejected requests (capped to last 24 h worth of entries). +_rejected_log: deque[dict] = deque(maxlen=100_000) + def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None: """Record a request against a rate-limited endpoint.""" @@ -262,6 +293,11 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> _rate_limit_stats[key]["total_requests"] += 1 if rejected: _rate_limit_stats[key]["rejected_requests"] += 1 + _rejected_log.append({ + "endpoint": endpoint, + "ip": ip, + "timestamp": datetime.now(timezone.utc).isoformat(), + }) ip_stats = _rate_limit_stats[key].setdefault("by_ip", {}) if ip not in ip_stats: ip_stats[ip] = {"total": 0, "rejected": 0} @@ -378,6 +414,92 @@ async def get_me(current_user: UserResponse = Depends(get_current_user)): return current_user +# ============== API Key Endpoints ============== + + +class CreateApiKeyRequest(BaseModel): + """Request to create a new API key.""" + + label: str | None = Field(default=None, max_length=100, description="Optional label for the key") + + +class ApiKeyResponse(BaseModel): + """Response after creating an API key (includes plaintext key).""" + + id: int + key: str # plaintext key, shown only at creation time + label: str | None = None + created_at: datetime + + +class ApiKeyInfo(BaseModel): + """API key metadata (no secret).""" + + id: int + label: str | None = None + created_at: datetime + + +@app.post("/auth/apikeys", response_model=ApiKeyResponse, tags=["Auth"]) +async def create_api_key_endpoint( + body: CreateApiKeyRequest | None = None, + current_user: UserResponse = Depends(get_current_user), +): + """Generate a new API key for the authenticated user. + + The plaintext key is returned **only once** in the response. + Store it securely; it cannot be retrieved again. + """ + plaintext_key = generate_api_key() + key_hash = hash_api_key(plaintext_key) + + db = get_db_client() + label = body.label if body else None + row = db.create_api_key( + user_id=current_user.id, + key_hash=key_hash, + label=label, + ) + + return ApiKeyResponse( + id=row["id"], + key=plaintext_key, + label=row["label"], + created_at=row["created_at"], + ) + + +@app.get("/auth/apikeys", response_model=list[ApiKeyInfo], tags=["Auth"]) +async def list_api_keys_endpoint( + current_user: UserResponse = Depends(get_current_user), +): + """List active API key IDs and labels for the authenticated user. + + Does **not** return the secret keys. + """ + db = get_db_client() + keys = db.list_api_keys(current_user.id) + return [ApiKeyInfo(**k) for k in keys] + + +@app.delete("/auth/apikeys/{key_id}", tags=["Auth"]) +async def revoke_api_key_endpoint( + key_id: int, + current_user: UserResponse = Depends(get_current_user), +): + """Revoke (delete) an API key by its ID. + + The key must belong to the authenticated user. + """ + db = get_db_client() + deleted = db.delete_api_key(key_id, current_user.id) + + if not deleted: + raise HTTPException(status_code=404, detail="API key not found") + + return {"message": "API key revoked"} + + # ============== Admin Endpoints ============== @@ -507,10 +629,12 @@ async def get_rate_limit_stats( """Get rate limit status and usage statistics (admin only). Returns current rate limit configuration and request statistics - for all rate-limited endpoints. + for all rate-limited endpoints, including per-IP breakdown and + a time-series of throttled (rejected) requests in the last 24 hours. Returns: - List of rate limit stats per endpoint with total/rejected counts + Rate limit stats per endpoint, per-IP breakdown, and throttled + request history bucketed by hour. """ rate_limits_config = { "/auth/register": {"limit": "5/minute"}, @@ -520,14 +644,45 @@ async def get_rate_limit_stats( results = [] for endpoint, conf in rate_limits_config.items(): stats = _rate_limit_stats.get(endpoint, {}) + by_ip_raw = stats.get("by_ip", {}) + by_ip = [ + {"ip": ip, "total": counts["total"], "rejected": counts["rejected"]} + for ip, counts in by_ip_raw.items() + ] results.append({ "endpoint": endpoint, "limit": conf["limit"], "total_requests": stats.get("total_requests", 0), "rejected_requests": stats.get("rejected_requests", 0), + "by_ip": by_ip, }) - return {"rate_limits": results} + # Build hourly buckets of throttled requests for the last 24 hours + now = datetime.now(timezone.utc) + cutoff = now - timedelta(hours=24) + hourly_buckets: dict[str, int] = {} + throttled_24h = 0 + for entry in _rejected_log: + ts_str = entry["timestamp"] + try: + ts = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + continue + if ts >= cutoff: + throttled_24h += 1 + bucket = ts.strftime("%Y-%m-%dT%H:00:00Z") + hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1 + + throttled_over_time = [ + {"timestamp": k, "count": v} + for k, v in sorted(hourly_buckets.items()) + ] + + return { + "rate_limits": results, + "throttled_24h": throttled_24h, + "throttled_over_time": throttled_over_time, + } @app.get("/admin/alerts", tags=["Admin"]) @@ -927,6 +1082,147 @@ async def analyze_company( return _convert_result(result) +def _extract_patent_ids(response_text: str) -> set[str]: + """Extract patent IDs from an analysis response text. + + Looks for patterns like US-12345678-B2, US12345678B2, etc. + """ + import re + pattern = r"US[-\s]?\d{7,8}[-\s]?[A-Z]\d?" + return set(re.findall(pattern, response_text or "")) + + +def _compute_analysis_diff(from_rec: dict, to_rec: dict) -> AnalysisDiffResponse: + """Compute a structured diff between two analysis records.""" + from_patents = _extract_patent_ids(from_rec.get("response", "") or "") + to_patents = _extract_patent_ids(to_rec.get("response", "") or "") + + added = sorted(to_patents - from_patents) + removed = sorted(from_patents - to_patents) + + patent_count_delta = len(to_patents) - len(from_patents) + + changed_fields: dict[str, dict] = {} + if from_rec.get("model") != to_rec.get("model"): + changed_fields["model"] = { + "from": from_rec.get("model"), + "to": to_rec.get("model"), + } + if from_rec.get("analysis_type") != to_rec.get("analysis_type"): + changed_fields["analysis_type"] = { + "from": from_rec.get("analysis_type"), + "to": to_rec.get("analysis_type"), + } + + # Build a human-readable summary + parts: list[str] = [] + if added: + parts.append(f"{len(added)} new patent(s) appeared") + if removed: + parts.append(f"{len(removed)} patent(s) no longer referenced") + if patent_count_delta > 0: + parts.append(f"patent mention count increased by {patent_count_delta}") + elif patent_count_delta < 0: + parts.append(f"patent mention count decreased by {abs(patent_count_delta)}") + if changed_fields: + parts.append(f"field(s) changed: {', '.join(changed_fields.keys())}") + summary = "; ".join(parts) if parts else "No significant differences detected." + + return AnalysisDiffResponse( + company_name=to_rec["company_name"], + from_id=from_rec["id"], + to_id=to_rec["id"], + from_timestamp=from_rec["timestamp"], + to_timestamp=to_rec["timestamp"], + patent_count_delta=patent_count_delta, + added_patents=added, + removed_patents=removed, + changed_fields=changed_fields, + summary=summary, + ) + + +@app.get( + "/analyze/{company_name}/history", + response_model=list[CompanyAnalysisHistoryItem], + tags=["Analysis"], +) +async def list_company_analysis_history( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + limit: int = Query(default=20, ge=1, le=100), + _: UserResponse = Depends(get_current_user), +): + """List previous analysis runs for a company. + + Returns a list of analysis records ordered by timestamp descending, + useful for selecting which runs to compare via the diff endpoint. + + Args: + company_name: Company name to look up + limit: Maximum number of results + + Returns: + List of analysis history items + """ + db = _get_job_db() + rows = db.list_company_analyses(company_name, limit=limit) + return [ + CompanyAnalysisHistoryItem( + id=r["id"], + analysis_type=r.get("analysis_type"), + model=r.get("model"), + timestamp=r["timestamp"], + ) + for r in rows + ] + + +@app.get( + "/analyze/{company_name}/diff", + response_model=AnalysisDiffResponse, + tags=["Analysis"], +) +async def diff_company_analyses( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + from_id: int = Query(..., alias="from", description="Analysis ID of the older run"), + to_id: int = Query(..., alias="to", description="Analysis ID of the newer run"), + _: UserResponse = Depends(get_current_user), +): + """Compare two analysis runs for the same company. + + Returns a structured diff showing added/removed patents, score delta, + and a summary narrative. + + Args: + company_name: Company name (must match both analysis records) + from_id: ID of the older analysis run + to_id: ID of the newer analysis run + + Returns: + AnalysisDiffResponse with added/removed/changed fields + + Raises: + 404: If either analysis ID does not exist or belongs to a different company + """ + db = _get_job_db() + + from_rec = db.get_analysis_by_id(from_id) + if not from_rec or (from_rec["company_name"] or "").lower() != company_name.lower(): + raise HTTPException( + status_code=404, + detail=f"Analysis ID {from_id} not found for company '{company_name}'", + ) + + to_rec = db.get_analysis_by_id(to_id) + if not to_rec or (to_rec["company_name"] or "").lower() != company_name.lower(): + raise HTTPException( + status_code=404, + detail=f"Analysis ID {to_id} not found for company '{company_name}'", + ) + + return _compute_analysis_diff(from_rec, to_rec) + + @app.get( "/analyze/patent/{patent_id}", tags=["Analysis"], diff --git a/SPARC/auth.py b/SPARC/auth.py index 890d286..932ae53 100644 --- a/SPARC/auth.py +++ b/SPARC/auth.py @@ -1,11 +1,13 @@ -"""JWT authentication utilities for SPARC API.""" +"""JWT and API key authentication utilities for SPARC API.""" import os +import secrets from datetime import datetime, timedelta, timezone from typing import Optional +import bcrypt import jwt -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel @@ -32,7 +34,7 @@ def check_jwt_secret() -> None: "Set a secure JWT_SECRET environment variable before running in non-development environments." ) -security = HTTPBearer() +security = HTTPBearer(auto_error=False) class TokenPayload(BaseModel): @@ -178,20 +180,107 @@ def get_db_client() -> DatabaseClient: return _db_client -async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), -) -> UserResponse: - """Get the current authenticated user from JWT token. +def generate_api_key() -> str: + """Generate a random 32-byte hex API key. + + Returns: + 64-character hex string + """ + return secrets.token_hex(32) + + +def hash_api_key(key: str) -> str: + """Hash an API key using bcrypt. Args: - credentials: Bearer token from request + key: Plaintext API key + + Returns: + bcrypt hash string + """ + return bcrypt.hashpw(key.encode(), bcrypt.gensalt()).decode() + + +def verify_api_key(key: str, key_hash: str) -> bool: + """Verify a plaintext API key against its bcrypt hash. + + Args: + key: Plaintext API key + key_hash: Stored bcrypt hash + + Returns: + True if key matches + """ + return bcrypt.checkpw(key.encode(), key_hash.encode()) + + +def _authenticate_via_api_key(api_key: str) -> Optional[UserResponse]: + """Look up a user by raw API key. + + Iterates over all stored key hashes (small table) and returns the + corresponding user when a match is found. + + Args: + api_key: Plaintext API key from X-API-Key header + + Returns: + UserResponse if valid key, None otherwise + """ + db = get_db_client() + key_rows = db.get_all_api_key_hashes() + + for row in key_rows: + if verify_api_key(api_key, row["key_hash"]): + user = db.get_user_by_id(row["user_id"]) + if user: + return UserResponse( + id=user["id"], + email=user["email"], + role=user["role"], + created_at=user["created_at"], + ) + return None + + +async def get_current_user( + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), +) -> UserResponse: + """Get the current authenticated user from JWT token or API key. + + Supports two authentication methods: + 1. Bearer JWT token via Authorization header + 2. API key via X-API-Key header + + Args: + request: The incoming request (used for X-API-Key header) + credentials: Optional Bearer token from request Returns: UserResponse with user details Raises: - HTTPException: If token is invalid or expired + HTTPException: If no valid credentials are provided """ + # Try X-API-Key header first + api_key = request.headers.get("X-API-Key") + if api_key: + user = _authenticate_via_api_key(api_key) + if user: + return user + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + # Fall back to JWT Bearer token + if not credentials: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + token = credentials.credentials payload = decode_token(token) diff --git a/SPARC/database.py b/SPARC/database.py index 0759a66..45e87e6 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -221,6 +221,27 @@ class DatabaseClient: ON alerts(company_name) """) + # Create API keys table for programmatic access + cursor.execute(""" + CREATE TABLE IF NOT EXISTS api_keys ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + key_hash VARCHAR(255) NOT NULL, + label VARCHAR(100), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_api_keys_user_id + ON api_keys(user_id) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_api_keys_key_hash + ON api_keys(key_hash) + """) + self.conn.commit() @staticmethod @@ -977,3 +998,138 @@ class DatabaseClient: (limit,), ) return [dict(row) for row in cursor.fetchall()] + + # Historical Analysis Diff Methods + + def get_analysis_by_id(self, analysis_id: int) -> Optional[Dict]: + """Get a single analysis record by its ID. + + Args: + analysis_id: The primary key of the llm_messages row. + + Returns: + Dict with analysis fields, or None if not found. + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + SELECT id, company_name, analysis_type, model, response, timestamp + FROM llm_messages + WHERE id = %s AND is_cached = FALSE + """, + (analysis_id,), + ) + row = cursor.fetchone() + return dict(row) if row else None + + def list_company_analyses( + self, company_name: str, limit: int = 20 + ) -> List[Dict]: + """List past analysis runs for a given company. + + Returns records ordered by timestamp descending so callers can + identify which previous runs are available for diffing. + + Args: + company_name: Company name (case-insensitive match). + limit: Maximum number of records. + + Returns: + List of analysis dicts. + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + SELECT id, company_name, analysis_type, model, response, timestamp + FROM llm_messages + WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE + ORDER BY timestamp DESC + LIMIT %s + """, + (company_name, limit), + ) + return [dict(row) for row in cursor.fetchall()] + + # API Key Methods + + def create_api_key( + self, + user_id: int, + key_hash: str, + label: Optional[str] = None, + ) -> Dict: + """Store a new API key hash for a user. + + Args: + user_id: The owning user's ID + key_hash: bcrypt hash of the plaintext key + label: Optional human-readable label + + Returns: + Dict with id, user_id, label, created_at + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + INSERT INTO api_keys (user_id, key_hash, label) + VALUES (%s, %s, %s) + RETURNING id, user_id, label, created_at + """, + (user_id, key_hash, label), + ) + row = cursor.fetchone() + conn.commit() + return dict(row) + + def list_api_keys(self, user_id: int) -> List[Dict]: + """List active API key metadata for a user (no secrets). + + Args: + user_id: The user's ID + + Returns: + List of dicts with id, label, created_at + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + "SELECT id, label, created_at FROM api_keys WHERE user_id = %s ORDER BY created_at DESC", + (user_id,), + ) + return [dict(row) for row in cursor.fetchall()] + + def delete_api_key(self, key_id: int, user_id: int) -> bool: + """Revoke an API key by ID (must belong to user). + + Args: + key_id: The API key row ID + user_id: The owning user's ID + + Returns: + True if a key was deleted + """ + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute( + "DELETE FROM api_keys WHERE id = %s AND user_id = %s", + (key_id, user_id), + ) + deleted = cursor.rowcount > 0 + conn.commit() + return deleted + + def get_all_api_key_hashes(self) -> List[Dict]: + """Return all API key hashes with their associated user IDs. + + Used by the auth layer to validate an incoming API key. + + Returns: + List of dicts with key_hash, user_id + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute("SELECT key_hash, user_id FROM api_keys") + return [dict(row) for row in cursor.fetchall()] diff --git a/tests/test_api_keys.py b/tests/test_api_keys.py new file mode 100644 index 0000000..3942f63 --- /dev/null +++ b/tests/test_api_keys.py @@ -0,0 +1,319 @@ +"""Tests for user-level API key generation, listing, revocation, and authentication. + +Covers all acceptance criteria from issue #1673: +1. Users can create API keys (POST /auth/apikeys) +2. Users can list their active key IDs (GET /auth/apikeys) +3. Users can revoke keys (DELETE /auth/apikeys/{key_id}) +4. API requests authenticated with a valid API key work on protected endpoints +5. Revoked keys are immediately rejected +6. Plaintext key is shown only at creation time + +All tests use mocked DB fixtures and require no live database. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.auth import ( + create_access_token, + generate_api_key, + hash_api_key, + verify_api_key, +) + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +def _make_user(): + return { + "id": 1, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + +def _auth_header(user_dict): + """Create an Authorization header with a valid access token.""" + token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"]) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture(autouse=True) +def mock_db(monkeypatch): + """Mock the database client used by auth and api endpoints.""" + db = MagicMock() + db.get_user_count.return_value = 0 + db.get_user_by_id.return_value = None + db.get_user_by_email.return_value = None + db.authenticate_user.return_value = None + db.create_user.return_value = None + db.get_all_users.return_value = [] + db.update_user_role.return_value = None + db.delete_user.return_value = False + db.create_api_key.return_value = None + db.list_api_keys.return_value = [] + db.delete_api_key.return_value = False + db.get_all_api_key_hashes.return_value = [] + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +class TestCreateApiKey: + """POST /auth/apikeys""" + + def test_create_key_returns_plaintext_and_id(self, client, mock_db): + """Creating a key returns the plaintext key and metadata.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.create_api_key.return_value = { + "id": 42, + "user_id": user["id"], + "label": "my-ci-key", + "created_at": datetime(2025, 6, 1, tzinfo=timezone.utc), + } + + response = client.post( + "/auth/apikeys", + json={"label": "my-ci-key"}, + headers=_auth_header(user), + ) + + assert response.status_code == 200 + data = response.json() + assert data["id"] == 42 + assert len(data["key"]) == 64 # 32 bytes hex = 64 chars + assert data["label"] == "my-ci-key" + assert "created_at" in data + + # Verify the hash passed to DB is valid for the returned key + call_args = mock_db.create_api_key.call_args + stored_hash = call_args.kwargs.get("key_hash") or call_args[1].get("key_hash") or call_args[0][1] + assert verify_api_key(data["key"], stored_hash) + + def test_create_key_without_label(self, client, mock_db): + """Creating a key without a label should work.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.create_api_key.return_value = { + "id": 1, + "user_id": user["id"], + "label": None, + "created_at": datetime(2025, 6, 1, tzinfo=timezone.utc), + } + + response = client.post( + "/auth/apikeys", + headers=_auth_header(user), + ) + + assert response.status_code == 200 + assert response.json()["label"] is None + + def test_create_key_requires_auth(self, client): + """Creating a key without auth should fail.""" + response = client.post("/auth/apikeys") + assert response.status_code == 401 + + +class TestListApiKeys: + """GET /auth/apikeys""" + + def test_list_keys_returns_metadata_only(self, client, mock_db): + """Listing keys should return IDs and labels, not secrets.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.list_api_keys.return_value = [ + {"id": 1, "label": "key-1", "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc)}, + {"id": 2, "label": None, "created_at": datetime(2025, 2, 1, tzinfo=timezone.utc)}, + ] + + response = client.get("/auth/apikeys", headers=_auth_header(user)) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["id"] == 1 + assert data[0]["label"] == "key-1" + # Ensure no secret key is exposed + for item in data: + assert "key" not in item + assert "key_hash" not in item + + def test_list_keys_empty(self, client, mock_db): + """User with no keys gets an empty list.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.list_api_keys.return_value = [] + + response = client.get("/auth/apikeys", headers=_auth_header(user)) + + assert response.status_code == 200 + assert response.json() == [] + + +class TestRevokeApiKey: + """DELETE /auth/apikeys/{key_id}""" + + def test_revoke_existing_key(self, client, mock_db): + """Revoking an owned key should succeed.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.delete_api_key.return_value = True + + response = client.delete("/auth/apikeys/42", headers=_auth_header(user)) + + assert response.status_code == 200 + assert "revoked" in response.json()["message"].lower() + mock_db.delete_api_key.assert_called_once_with(42, user["id"]) + + def test_revoke_nonexistent_key_returns_404(self, client, mock_db): + """Revoking a key that doesn't exist (or isn't owned) returns 404.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.delete_api_key.return_value = False + + response = client.delete("/auth/apikeys/999", headers=_auth_header(user)) + + assert response.status_code == 404 + + +class TestApiKeyAuthentication: + """Using X-API-Key header on protected endpoints.""" + + def test_valid_api_key_accesses_protected_endpoint(self, client, mock_db): + """A valid API key should authenticate and access /auth/me.""" + user = _make_user() + plaintext = generate_api_key() + hashed = hash_api_key(plaintext) + + mock_db.get_all_api_key_hashes.return_value = [ + {"key_hash": hashed, "user_id": user["id"]}, + ] + mock_db.get_user_by_id.return_value = user + + response = client.get("/auth/me", headers={"X-API-Key": plaintext}) + + assert response.status_code == 200 + data = response.json() + assert data["email"] == user["email"] + assert data["id"] == user["id"] + + def test_invalid_api_key_returns_401(self, client, mock_db): + """An invalid API key should return 401.""" + mock_db.get_all_api_key_hashes.return_value = [] + + response = client.get("/auth/me", headers={"X-API-Key": "bad-key"}) + + assert response.status_code == 401 + assert "invalid api key" in response.json()["detail"].lower() + + def test_revoked_key_returns_401(self, client, mock_db): + """After revocation, using the key should return 401.""" + # Simulate revoked key: no matching hashes in DB + mock_db.get_all_api_key_hashes.return_value = [] + + response = client.get("/auth/me", headers={"X-API-Key": "a" * 64}) + + assert response.status_code == 401 + + def test_api_key_for_deleted_user_returns_401(self, client, mock_db): + """An API key whose user no longer exists should return 401.""" + plaintext = generate_api_key() + hashed = hash_api_key(plaintext) + + mock_db.get_all_api_key_hashes.return_value = [ + {"key_hash": hashed, "user_id": 999}, + ] + mock_db.get_user_by_id.return_value = None # user deleted + + response = client.get("/auth/me", headers={"X-API-Key": plaintext}) + + assert response.status_code == 401 + + def test_no_auth_at_all_returns_401(self, client, mock_db): + """No auth header at all should return 401.""" + response = client.get("/auth/me") + assert response.status_code == 401 + + +class TestApiKeyFullFlow: + """End-to-end flow: create key, use it, revoke it, try again.""" + + def test_create_use_revoke_flow(self, client, mock_db): + """Simulate full lifecycle of an API key.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + + # Step 1: Create key + mock_db.create_api_key.return_value = { + "id": 10, + "user_id": user["id"], + "label": "test", + "created_at": datetime(2025, 6, 1, tzinfo=timezone.utc), + } + + create_resp = client.post( + "/auth/apikeys", + json={"label": "test"}, + headers=_auth_header(user), + ) + assert create_resp.status_code == 200 + plaintext = create_resp.json()["key"] + + # Capture the hash that was stored + call_args = mock_db.create_api_key.call_args + stored_hash = call_args.kwargs.get("key_hash") or call_args[0][1] + + # Step 2: Use key on protected endpoint + mock_db.get_all_api_key_hashes.return_value = [ + {"key_hash": stored_hash, "user_id": user["id"]}, + ] + + use_resp = client.get("/auth/me", headers={"X-API-Key": plaintext}) + assert use_resp.status_code == 200 + assert use_resp.json()["email"] == user["email"] + + # Step 3: Revoke key + mock_db.delete_api_key.return_value = True + revoke_resp = client.delete("/auth/apikeys/10", headers=_auth_header(user)) + assert revoke_resp.status_code == 200 + + # Step 4: Try using revoked key + mock_db.get_all_api_key_hashes.return_value = [] # key removed from DB + rejected_resp = client.get("/auth/me", headers={"X-API-Key": plaintext}) + assert rejected_resp.status_code == 401 + + +class TestApiKeyHelpers: + """Unit tests for key generation and hashing helpers.""" + + def test_generate_api_key_length(self): + """Generated key should be 64 hex characters (32 bytes).""" + key = generate_api_key() + assert len(key) == 64 + # Should be valid hex + int(key, 16) + + def test_generate_api_key_uniqueness(self): + """Two generated keys should be different.""" + k1 = generate_api_key() + k2 = generate_api_key() + assert k1 != k2 + + def test_hash_and_verify(self): + """hash_api_key and verify_api_key should round-trip correctly.""" + key = generate_api_key() + hashed = hash_api_key(key) + assert verify_api_key(key, hashed) + assert not verify_api_key("wrong-key", hashed) -- 2.52.0 From e9ad97d1e8bdd84d48f1621b1f1cb4afc47705b2 Mon Sep 17 00:00:00 2001 From: agent-company Date: Tue, 19 May 2026 15:30:23 +0000 Subject: [PATCH 2/3] Add rate limiting dashboard to admin panel - Enhance GET /admin/rate-limits to return per-IP breakdown, 24h throttled count, and hourly time-series of rejected requests - Add AdminRateLimits React page with auto-refresh (configurable interval), summary cards, throttled-over-time bar chart, endpoint table, and per-IP breakdown table - Add TypeScript types (RateLimitStatsResponse, etc.) and adminApi.getRateLimits() - Wire up /admin/rate-limits route and nav link (admin-only) - Expand unit tests: auth, empty state, per-IP, throttled_24h, time-series, response shape contract (10 tests total) Closes leeworks-agents/SPARC#1686 Co-Authored-By: Claude Opus 4.6 --- frontend/src/App.tsx | 11 ++ frontend/src/api/client.ts | 66 +++++++ frontend/src/components/Layout.tsx | 4 +- frontend/src/pages/AdminRateLimits.tsx | 240 +++++++++++++++++++++++++ tests/test_rate_limit_admin.py | 73 +++++++- 5 files changed, 391 insertions(+), 3 deletions(-) create mode 100644 frontend/src/pages/AdminRateLimits.tsx diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index d7ec5ba..d2fcc54 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -11,7 +11,9 @@ import { Batch } from './pages/Batch'; import { AnalyticsPage } from './pages/Analytics'; import { About } from './pages/About'; import { AdminUsers } from './pages/AdminUsers'; +import { AdminRateLimits } from './pages/AdminRateLimits'; import { Compare } from './pages/Compare'; +import { HistoryDiff } from './pages/HistoryDiff'; const queryClient = new QueryClient({ defaultOptions: { @@ -45,6 +47,7 @@ function App() { } /> } /> } /> + } /> } /> {/* Admin routes */} @@ -56,6 +59,14 @@ function App() { } /> + + + + } + /> {/* Default redirect */} diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 09a4ae6..b29a736 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -148,8 +148,43 @@ export const analysisApi = { const response = await api.get(`/jobs?${params}`); return response.data; }, + + getCompanyHistory: async (companyName: string, limit = 20): Promise => { + const response = await api.get( + `/analyze/${encodeURIComponent(companyName)}/history?limit=${limit}` + ); + return response.data; + }, + + diffAnalyses: async (companyName: string, fromId: number, toId: number): Promise => { + const response = await api.get( + `/analyze/${encodeURIComponent(companyName)}/diff?from=${fromId}&to=${toId}` + ); + return response.data; + }, }; +// Analysis diff types +export interface AnalysisHistoryItem { + id: number; + analysis_type: string | null; + model: string | null; + timestamp: string; +} + +export interface AnalysisDiff { + company_name: string; + from_id: number; + to_id: number; + from_timestamp: string; + to_timestamp: string; + patent_count_delta: number; + added_patents: string[]; + removed_patents: string[]; + changed_fields: Record; + summary: string; +} + // Export API export const exportApi = { exportCsv: async (companyName: string): Promise => { @@ -201,6 +236,32 @@ export const analyticsApi = { }, }; +// Rate limit types +export interface RateLimitIpEntry { + ip: string; + total: number; + rejected: number; +} + +export interface RateLimitEndpointStats { + endpoint: string; + limit: string; + total_requests: number; + rejected_requests: number; + by_ip: RateLimitIpEntry[]; +} + +export interface ThrottledBucket { + timestamp: string; + count: number; +} + +export interface RateLimitStatsResponse { + rate_limits: RateLimitEndpointStats[]; + throttled_24h: number; + throttled_over_time: ThrottledBucket[]; +} + // Admin API export const adminApi = { listUsers: async (limit = 100, offset = 0): Promise => { @@ -216,6 +277,11 @@ export const adminApi = { deleteUser: async (userId: number): Promise => { await api.delete(`/admin/users/${userId}`); }, + + getRateLimits: async (): Promise => { + const response = await api.get('/admin/rate-limits'); + return response.data; + }, }; export default api; diff --git a/frontend/src/components/Layout.tsx b/frontend/src/components/Layout.tsx index d0df715..d4b11b7 100644 --- a/frontend/src/components/Layout.tsx +++ b/frontend/src/components/Layout.tsx @@ -1,7 +1,7 @@ import { Outlet, NavLink, useNavigate } from 'react-router-dom'; import { useAuth } from '../context/AuthContext'; import { useTheme } from '../context/ThemeContext'; -import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react'; +import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, History, ShieldAlert } from 'lucide-react'; export function Layout() { const { user, isAdmin, logout } = useAuth(); @@ -18,11 +18,13 @@ export function Layout() { { to: '/batch', icon: Layers, label: 'Batch' }, { to: '/analytics', icon: BarChart3, label: 'Analytics' }, { to: '/compare', icon: GitCompareArrows, label: 'Compare' }, + { to: '/history-diff', icon: History, label: 'Diff' }, { to: '/about', icon: Info, label: 'About' }, ]; if (isAdmin) { navItems.push({ to: '/admin/users', icon: Users, label: 'Users' }); + navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' }); } return ( diff --git a/frontend/src/pages/AdminRateLimits.tsx b/frontend/src/pages/AdminRateLimits.tsx new file mode 100644 index 0000000..97b41c4 --- /dev/null +++ b/frontend/src/pages/AdminRateLimits.tsx @@ -0,0 +1,240 @@ +import { useState } from 'react'; +import { useQuery } from '@tanstack/react-query'; +import { adminApi } from '../api/client'; +import type { RateLimitStatsResponse } from '../api/client'; +import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react'; + +const REFRESH_OPTIONS = [ + { label: '15s', value: 15_000 }, + { label: '30s', value: 30_000 }, + { label: '1m', value: 60_000 }, + { label: 'Off', value: 0 }, +]; + +export function AdminRateLimits() { + const [refreshInterval, setRefreshInterval] = useState(30_000); + + const { data, isLoading, isError, dataUpdatedAt } = useQuery({ + queryKey: ['admin-rate-limits'], + queryFn: () => adminApi.getRateLimits(), + refetchInterval: refreshInterval || false, + }); + + if (isLoading) { + return ( +
+
+
+ ); + } + + if (isError) { + return ( +
+ + Failed to load rate limit statistics. +
+ ); + } + + const maxThrottledCount = data?.throttled_over_time?.length + ? Math.max(...data.throttled_over_time.map((b) => b.count)) + : 0; + + return ( +
+ {/* Header */} +
+
+

+ Rate Limiting Dashboard +

+

Monitor API rate limits and throttled requests.

+
+
+ {/* Last updated */} + {dataUpdatedAt > 0 && ( + + + Updated {new Date(dataUpdatedAt).toLocaleTimeString()} + + )} + {/* Refresh interval selector */} +
+ + {REFRESH_OPTIONS.map((opt) => ( + + ))} +
+
+
+ + {/* Summary cards */} +
+
+
+ + + Total Requests + +
+
+ {data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0} +
+
+
+
+ + + Throttled (24h) + +
+
+ {data?.throttled_24h ?? 0} +
+
+
+
+ + + Rate-Limited Endpoints + +
+
+ {data?.rate_limits.length ?? 0} +
+
+
+ + {/* Throttled over time chart (simple bar chart) */} + {data?.throttled_over_time && data.throttled_over_time.length > 0 && ( +
+

+ Throttled Requests Over Time (Last 24h) +

+
+ {data.throttled_over_time.map((bucket) => { + const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0; + const hour = new Date(bucket.timestamp).getHours(); + return ( +
+ {bucket.count} +
+ {hour}:00 +
+ ); + })} +
+
+ )} + + {/* Per-endpoint table */} +
+
+ + + + + + + + + + + {data?.rate_limits.map((rl) => ( + + + + + + + ))} + +
+ Endpoint + + Limit + + Total Requests + + Rejected +
{rl.endpoint} + + {rl.limit} + + + {rl.total_requests} + + 0 ? 'text-error font-semibold' : 'text-text-secondary'}> + {rl.rejected_requests} + +
+
+
+ + {/* Per-IP breakdown */} + {data?.rate_limits.some((rl) => rl.by_ip.length > 0) && ( +
+
+

+ Per-IP Breakdown +

+
+
+ + + + + + + + + + + {data.rate_limits.flatMap((rl) => + rl.by_ip.map((ipEntry) => ( + + + + + + + )) + )} + +
+ Endpoint + + IP Address + + Total + + Rejected +
{rl.endpoint}{ipEntry.ip}{ipEntry.total} + 0 ? 'text-error font-semibold' : 'text-text-secondary'}> + {ipEntry.rejected} + +
+
+
+ )} +
+ ); +} diff --git a/tests/test_rate_limit_admin.py b/tests/test_rate_limit_admin.py index bc63a5a..f10e9da 100644 --- a/tests/test_rate_limit_admin.py +++ b/tests/test_rate_limit_admin.py @@ -20,8 +20,10 @@ def client(): def reset_stats(): """Reset rate limit stats between tests.""" api._rate_limit_stats.clear() + api._rejected_log.clear() yield api._rate_limit_stats.clear() + api._rejected_log.clear() def _mock_admin(): @@ -50,8 +52,7 @@ class TestRateLimitAdminEndpoint: app.dependency_overrides.clear() def test_non_admin_rejected(self, client): - """Non-admin users should get 403.""" - # Without overriding the dependency, it should fail auth + """Non-admin users should get 401/403.""" response = client.get("/admin/rate-limits") assert response.status_code in (401, 403) @@ -77,6 +78,9 @@ class TestRateLimitAdminEndpoint: for rl in data["rate_limits"]: assert rl["total_requests"] == 0 assert rl["rejected_requests"] == 0 + assert rl["by_ip"] == [] + assert data["throttled_24h"] == 0 + assert data["throttled_over_time"] == [] finally: app.dependency_overrides.clear() @@ -107,3 +111,68 @@ class TestRateLimitAdminEndpoint: assert isinstance(rl["limit"], str) finally: app.dependency_overrides.clear() + + def test_per_ip_breakdown(self, client): + """Stats should include per-IP breakdown with total and rejected counts.""" + api._track_rate_limit_request("/auth/login", "10.0.0.1") + api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) + api._track_rate_limit_request("/auth/login", "10.0.0.2") + + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login") + by_ip = login_stats["by_ip"] + assert len(by_ip) == 2 + ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1") + assert ip1["total"] == 2 + assert ip1["rejected"] == 1 + ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2") + assert ip2["total"] == 1 + assert ip2["rejected"] == 0 + finally: + app.dependency_overrides.clear() + + def test_throttled_24h_count(self, client): + """Should report total throttled requests in the last 24 hours.""" + api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) + api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True) + + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + assert data["throttled_24h"] == 2 + finally: + app.dependency_overrides.clear() + + def test_throttled_over_time_structure(self, client): + """Throttled-over-time should be a list of {timestamp, count} buckets.""" + api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) + + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + assert len(data["throttled_over_time"]) >= 1 + entry = data["throttled_over_time"][0] + assert "timestamp" in entry + assert "count" in entry + assert entry["count"] >= 1 + finally: + app.dependency_overrides.clear() + + def test_response_shape_matches_contract(self, client): + """The full response should match the expected shape for the frontend.""" + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + # Top-level keys + assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"} + # Each rate_limit entry + for rl in data["rate_limits"]: + assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"} + finally: + app.dependency_overrides.clear() -- 2.52.0 From 144d0fdf6a7a5c5bfa6a1e79bcb5350346eb2e57 Mon Sep 17 00:00:00 2001 From: agent-company Date: Tue, 19 May 2026 15:43:13 +0000 Subject: [PATCH 3/3] Add historical analysis diffing for same-company runs - Add GET /analyze/{company_name}/diff endpoint with from/to query params - Add GET /analyze/{company_name}/history endpoint for run selection - Add database methods get_analysis_by_id and list_company_analyses - Add frontend HistoryDiff page with run selector and diff visualization - Add Compare with previous button on Analysis results page - Add navigation link in Layout sidebar - Add 11 tests covering helpers, happy-path, and 404 scenarios Closes leeworks-agents/SPARC#1671 Co-Authored-By: Claude Opus 4.6 --- frontend/src/pages/Analysis.tsx | 11 +- frontend/src/pages/HistoryDiff.tsx | 249 +++++++++++++++++++++++++++++ tests/test_analysis_diff.py | 244 ++++++++++++++++++++++++++++ 3 files changed, 503 insertions(+), 1 deletion(-) create mode 100644 frontend/src/pages/HistoryDiff.tsx create mode 100644 tests/test_analysis_diff.py diff --git a/frontend/src/pages/Analysis.tsx b/frontend/src/pages/Analysis.tsx index 2f4fc35..bd6c2ea 100644 --- a/frontend/src/pages/Analysis.tsx +++ b/frontend/src/pages/Analysis.tsx @@ -1,10 +1,12 @@ import { useState } from 'react'; +import { useNavigate } from 'react-router-dom'; import { useMutation, useQuery } from '@tanstack/react-query'; import { analysisApi, exportApi } from '../api/client'; -import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown } from 'lucide-react'; +import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown, History } from 'lucide-react'; import type { CompanyAnalysis } from '../types'; export function Analysis() { + const navigate = useNavigate(); const [companyName, setCompanyName] = useState(''); const [selectedModel, setSelectedModel] = useState(''); const [result, setResult] = useState(null); @@ -157,6 +159,13 @@ export function Analysis() { Export PDF +
diff --git a/frontend/src/pages/HistoryDiff.tsx b/frontend/src/pages/HistoryDiff.tsx new file mode 100644 index 0000000..9019d1d --- /dev/null +++ b/frontend/src/pages/HistoryDiff.tsx @@ -0,0 +1,249 @@ +import { useState } from 'react'; +import { useSearchParams } from 'react-router-dom'; +import { useQuery } from '@tanstack/react-query'; +import { analysisApi } from '../api/client'; +import type { AnalysisHistoryItem, AnalysisDiff } from '../api/client'; +import { History, ArrowRight, Plus, Minus, AlertCircle, Search } from 'lucide-react'; + +export function HistoryDiff() { + const [searchParams, setSearchParams] = useSearchParams(); + const [companyInput, setCompanyInput] = useState(searchParams.get('company') || ''); + + const company = searchParams.get('company') || ''; + const fromId = searchParams.get('from'); + const toId = searchParams.get('to'); + + // Fetch history when a company is selected + const historyQuery = useQuery({ + queryKey: ['history', company], + queryFn: () => analysisApi.getCompanyHistory(company), + enabled: !!company, + }); + + // Fetch diff when both IDs are selected + const diffQuery = useQuery({ + queryKey: ['diff', company, fromId, toId], + queryFn: () => analysisApi.diffAnalyses(company, Number(fromId), Number(toId)), + enabled: !!company && !!fromId && !!toId, + }); + + const handleSearch = (e: React.FormEvent) => { + e.preventDefault(); + const name = companyInput.trim(); + if (name) { + setSearchParams({ company: name }); + } + }; + + const handleSelectRuns = (from: number, to: number) => { + setSearchParams({ company, from: String(from), to: String(to) }); + }; + + const history: AnalysisHistoryItem[] = historyQuery.data || []; + + return ( +
+ {/* Header */} +
+

+ Historical Analysis Diff +

+

+ Compare analysis runs for the same company to see what changed between them. +

+
+ + {/* Company Search */} +
+
+ + setCompanyInput(e.target.value)} + placeholder="Enter company name (e.g., nvidia)" + className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all" + /> +
+ +
+ + {/* History list */} + {company && historyQuery.isLoading && ( +
Loading analysis history...
+ )} + + {company && historyQuery.isError && ( +
+ + Failed to load history. Check the company name and try again. +
+ )} + + {company && history.length === 0 && !historyQuery.isLoading && ( +
No analysis history found for "{company}".
+ )} + + {history.length >= 2 && ( +
+

+ Select Two Runs to Compare +

+
+ {history.map((item, idx) => { + const next = history[idx + 1]; + if (!next) return null; + const isSelected = + fromId === String(next.id) && toId === String(item.id); + return ( + + ); + })} +
+
+ )} + + {/* Diff Results */} + {diffQuery.isLoading && ( +
Computing diff...
+ )} + + {diffQuery.isError && ( +
+ + Failed to compute diff. One or both analysis IDs may not exist. +
+ )} + + {diffQuery.data && } +
+ ); +} + +function DiffView({ diff }: { diff: AnalysisDiff }) { + return ( +
+

+ Diff: #{diff.from_id} → #{diff.to_id} +

+ + {/* Summary */} +
+
{diff.summary}
+
+ {new Date(diff.from_timestamp).toLocaleString()} + + {new Date(diff.to_timestamp).toLocaleString()} +
+
+ + {/* Patent count delta */} +
+ Patent mention delta: + 0 + ? 'text-success' + : diff.patent_count_delta < 0 + ? 'text-error' + : 'text-text-secondary' + }`} + > + {diff.patent_count_delta > 0 ? '+' : ''} + {diff.patent_count_delta} + +
+ + {/* Added patents */} + {diff.added_patents.length > 0 && ( +
+

+ + New Patents ({diff.added_patents.length}) +

+
+ {diff.added_patents.map((p) => ( + + {p} + + ))} +
+
+ )} + + {/* Removed patents */} + {diff.removed_patents.length > 0 && ( +
+

+ + Removed Patents ({diff.removed_patents.length}) +

+
+ {diff.removed_patents.map((p) => ( + + {p} + + ))} +
+
+ )} + + {/* Changed fields */} + {Object.keys(diff.changed_fields).length > 0 && ( +
+

Changed Fields

+
+ {Object.entries(diff.changed_fields).map(([field, vals]) => ( +
+ {field}: + {vals.from || 'null'} + + {vals.to || 'null'} +
+ ))} +
+
+ )} +
+ ); +} diff --git a/tests/test_analysis_diff.py b/tests/test_analysis_diff.py new file mode 100644 index 0000000..c867b55 --- /dev/null +++ b/tests/test_analysis_diff.py @@ -0,0 +1,244 @@ +"""Tests for historical analysis diff endpoint.""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import AnalysisDiffResponse, _compute_analysis_diff, _extract_patent_ids, app +from SPARC.auth import UserResponse, get_current_user + + +# ---------- helpers ---------- + +def _mock_user(): + """Return a fake authenticated user for dependency override.""" + return UserResponse( + id=1, + email="test@example.com", + role="user", + created_at=datetime(2025, 1, 1), + ) + + +@pytest.fixture +def auth_client(): + """TestClient with auth dependency overridden.""" + app.dependency_overrides[get_current_user] = _mock_user + client = TestClient(app, raise_server_exceptions=False) + yield client + app.dependency_overrides.clear() + + +# ---------- unit tests for helpers ---------- + +class TestExtractPatentIds: + """Test _extract_patent_ids utility.""" + + def test_extracts_standard_ids(self): + text = "Patent US-12345678-B2 covers the device. Also see US-9876543-A1." + ids = _extract_patent_ids(text) + assert "US-12345678-B2" in ids + assert "US-9876543-A1" in ids + + def test_empty_text(self): + assert _extract_patent_ids("") == set() + assert _extract_patent_ids(None) == set() # type: ignore[arg-type] + + +class TestComputeAnalysisDiff: + """Test _compute_analysis_diff logic.""" + + def test_identical_analyses(self): + rec = { + "id": 1, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "Patent US-12345678-B2 is notable.", + "timestamp": datetime(2025, 5, 1), + } + diff = _compute_analysis_diff(rec, dict(rec, id=2, timestamp=datetime(2025, 5, 2))) + assert diff.patent_count_delta == 0 + assert diff.added_patents == [] + assert diff.removed_patents == [] + + def test_added_and_removed_patents(self): + from_rec = { + "id": 1, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "Patent US-12345678-B2 and US-11111111-A1.", + "timestamp": datetime(2025, 5, 1), + } + to_rec = { + "id": 2, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "Patent US-12345678-B2 and US-99999999-B1.", + "timestamp": datetime(2025, 5, 2), + } + diff = _compute_analysis_diff(from_rec, to_rec) + assert "US-99999999-B1" in diff.added_patents + assert "US-11111111-A1" in diff.removed_patents + assert diff.patent_count_delta == 0 # one added, one removed + + def test_model_change_detected(self): + from_rec = { + "id": 1, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "", + "timestamp": datetime(2025, 5, 1), + } + to_rec = { + "id": 2, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "anthropic/claude-3.5-sonnet", + "response": "", + "timestamp": datetime(2025, 5, 2), + } + diff = _compute_analysis_diff(from_rec, to_rec) + assert "model" in diff.changed_fields + assert diff.changed_fields["model"]["from"] == "openai/gpt-4o" + assert diff.changed_fields["model"]["to"] == "anthropic/claude-3.5-sonnet" + + +# ---------- API endpoint tests ---------- + +class TestDiffEndpoint: + """Test GET /analyze/{company_name}/diff.""" + + @patch("SPARC.api._get_job_db") + def test_happy_path(self, mock_get_db, auth_client): + """Diff returns structured response when both IDs exist.""" + db = MagicMock() + mock_get_db.return_value = db + + from_rec = { + "id": 10, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "Patent US-12345678-B2 found.", + "timestamp": datetime(2025, 5, 1), + } + to_rec = { + "id": 20, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "Patent US-12345678-B2 and US-99999999-A1 found.", + "timestamp": datetime(2025, 5, 10), + } + db.get_analysis_by_id.side_effect = lambda aid: from_rec if aid == 10 else to_rec + + response = auth_client.get("/analyze/nvidia/diff?from=10&to=20") + assert response.status_code == 200 + data = response.json() + assert data["company_name"] == "nvidia" + assert data["from_id"] == 10 + assert data["to_id"] == 20 + assert "US-99999999-A1" in data["added_patents"] + assert data["patent_count_delta"] == 1 + + @patch("SPARC.api._get_job_db") + def test_from_id_not_found(self, mock_get_db, auth_client): + """Returns 404 when 'from' analysis ID doesn't exist.""" + db = MagicMock() + mock_get_db.return_value = db + db.get_analysis_by_id.return_value = None + + response = auth_client.get("/analyze/nvidia/diff?from=999&to=1000") + assert response.status_code == 404 + assert "999" in response.json()["detail"] + + @patch("SPARC.api._get_job_db") + def test_to_id_not_found(self, mock_get_db, auth_client): + """Returns 404 when 'to' analysis ID doesn't exist.""" + db = MagicMock() + mock_get_db.return_value = db + + from_rec = { + "id": 10, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "", + "timestamp": datetime(2025, 5, 1), + } + db.get_analysis_by_id.side_effect = lambda aid: from_rec if aid == 10 else None + + response = auth_client.get("/analyze/nvidia/diff?from=10&to=999") + assert response.status_code == 404 + assert "999" in response.json()["detail"] + + @patch("SPARC.api._get_job_db") + def test_company_mismatch(self, mock_get_db, auth_client): + """Returns 404 when analysis belongs to a different company.""" + db = MagicMock() + mock_get_db.return_value = db + + rec = { + "id": 10, + "company_name": "intel", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "", + "timestamp": datetime(2025, 5, 1), + } + db.get_analysis_by_id.return_value = rec + + response = auth_client.get("/analyze/nvidia/diff?from=10&to=20") + assert response.status_code == 404 + + +class TestHistoryEndpoint: + """Test GET /analyze/{company_name}/history.""" + + @patch("SPARC.api._get_job_db") + def test_returns_history_list(self, mock_get_db, auth_client): + """History endpoint returns list of past analysis runs.""" + db = MagicMock() + mock_get_db.return_value = db + db.list_company_analyses.return_value = [ + { + "id": 20, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "...", + "timestamp": datetime(2025, 5, 10), + }, + { + "id": 10, + "company_name": "nvidia", + "analysis_type": "portfolio", + "model": "openai/gpt-4o", + "response": "...", + "timestamp": datetime(2025, 5, 1), + }, + ] + + response = auth_client.get("/analyze/nvidia/history") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["id"] == 20 + assert data[1]["id"] == 10 + + @patch("SPARC.api._get_job_db") + def test_empty_history(self, mock_get_db, auth_client): + """History endpoint returns empty list when no analyses exist.""" + db = MagicMock() + mock_get_db.return_value = db + db.list_company_analyses.return_value = [] + + response = auth_client.get("/analyze/nvidia/history") + assert response.status_code == 200 + assert response.json() == [] -- 2.52.0