diff --git a/SPARC/api.py b/SPARC/api.py index 1b29d38..ea0674a 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios. from __future__ import annotations +from collections import deque from contextlib import asynccontextmanager -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Annotated, List if TYPE_CHECKING: @@ -29,9 +30,11 @@ from SPARC.auth import ( close_db_client, create_tokens, decode_token, + generate_api_key, get_current_admin, get_current_user, get_db_client, + hash_api_key, init_db_client, ) from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult @@ -139,6 +142,31 @@ class HealthResponse(BaseModel): timestamp: datetime +# Historical diff models +class AnalysisDiffResponse(BaseModel): + """Response model for diffing two analysis runs of the same company.""" + + company_name: str + from_id: int + to_id: int + from_timestamp: datetime + to_timestamp: datetime + patent_count_delta: int + added_patents: list[str] + removed_patents: list[str] + changed_fields: dict[str, dict] + summary: str + + +class CompanyAnalysisHistoryItem(BaseModel): + """A summary item from a company's analysis history.""" + + id: int + analysis_type: str | None = None + model: str | None = None + timestamp: datetime + + # Auth request/response models class RegisterRequest(BaseModel): """User registration request.""" @@ -248,6 +276,9 @@ app.state.limiter = limiter # In-memory rate limit statistics _rate_limit_stats: dict[str, dict] = {} +# Time-series log of rejected requests (capped to last 24 h worth of entries). +_rejected_log: deque[dict] = deque(maxlen=100_000) + def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None: """Record a request against a rate-limited endpoint.""" @@ -262,6 +293,11 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> _rate_limit_stats[key]["total_requests"] += 1 if rejected: _rate_limit_stats[key]["rejected_requests"] += 1 + _rejected_log.append({ + "endpoint": endpoint, + "ip": ip, + "timestamp": datetime.now(timezone.utc).isoformat(), + }) ip_stats = _rate_limit_stats[key].setdefault("by_ip", {}) if ip not in ip_stats: ip_stats[ip] = {"total": 0, "rejected": 0} @@ -378,6 +414,92 @@ async def get_me(current_user: UserResponse = Depends(get_current_user)): return current_user +# ============== API Key Endpoints ============== + + +class CreateApiKeyRequest(BaseModel): + """Request to create a new API key.""" + + label: str | None = Field(default=None, max_length=100, description="Optional label for the key") + + +class ApiKeyResponse(BaseModel): + """Response after creating an API key (includes plaintext key).""" + + id: int + key: str # plaintext key, shown only at creation time + label: str | None = None + created_at: datetime + + +class ApiKeyInfo(BaseModel): + """API key metadata (no secret).""" + + id: int + label: str | None = None + created_at: datetime + + +@app.post("/auth/apikeys", response_model=ApiKeyResponse, tags=["Auth"]) +async def create_api_key_endpoint( + body: CreateApiKeyRequest | None = None, + current_user: UserResponse = Depends(get_current_user), +): + """Generate a new API key for the authenticated user. + + The plaintext key is returned **only once** in the response. + Store it securely; it cannot be retrieved again. + """ + plaintext_key = generate_api_key() + key_hash = hash_api_key(plaintext_key) + + db = get_db_client() + label = body.label if body else None + row = db.create_api_key( + user_id=current_user.id, + key_hash=key_hash, + label=label, + ) + + return ApiKeyResponse( + id=row["id"], + key=plaintext_key, + label=row["label"], + created_at=row["created_at"], + ) + + +@app.get("/auth/apikeys", response_model=list[ApiKeyInfo], tags=["Auth"]) +async def list_api_keys_endpoint( + current_user: UserResponse = Depends(get_current_user), +): + """List active API key IDs and labels for the authenticated user. + + Does **not** return the secret keys. + """ + db = get_db_client() + keys = db.list_api_keys(current_user.id) + return [ApiKeyInfo(**k) for k in keys] + + +@app.delete("/auth/apikeys/{key_id}", tags=["Auth"]) +async def revoke_api_key_endpoint( + key_id: int, + current_user: UserResponse = Depends(get_current_user), +): + """Revoke (delete) an API key by its ID. + + The key must belong to the authenticated user. + """ + db = get_db_client() + deleted = db.delete_api_key(key_id, current_user.id) + + if not deleted: + raise HTTPException(status_code=404, detail="API key not found") + + return {"message": "API key revoked"} + + # ============== Admin Endpoints ============== @@ -507,10 +629,12 @@ async def get_rate_limit_stats( """Get rate limit status and usage statistics (admin only). Returns current rate limit configuration and request statistics - for all rate-limited endpoints. + for all rate-limited endpoints, including per-IP breakdown and + a time-series of throttled (rejected) requests in the last 24 hours. Returns: - List of rate limit stats per endpoint with total/rejected counts + Rate limit stats per endpoint, per-IP breakdown, and throttled + request history bucketed by hour. """ rate_limits_config = { "/auth/register": {"limit": "5/minute"}, @@ -520,14 +644,45 @@ async def get_rate_limit_stats( results = [] for endpoint, conf in rate_limits_config.items(): stats = _rate_limit_stats.get(endpoint, {}) + by_ip_raw = stats.get("by_ip", {}) + by_ip = [ + {"ip": ip, "total": counts["total"], "rejected": counts["rejected"]} + for ip, counts in by_ip_raw.items() + ] results.append({ "endpoint": endpoint, "limit": conf["limit"], "total_requests": stats.get("total_requests", 0), "rejected_requests": stats.get("rejected_requests", 0), + "by_ip": by_ip, }) - return {"rate_limits": results} + # Build hourly buckets of throttled requests for the last 24 hours + now = datetime.now(timezone.utc) + cutoff = now - timedelta(hours=24) + hourly_buckets: dict[str, int] = {} + throttled_24h = 0 + for entry in _rejected_log: + ts_str = entry["timestamp"] + try: + ts = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + continue + if ts >= cutoff: + throttled_24h += 1 + bucket = ts.strftime("%Y-%m-%dT%H:00:00Z") + hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1 + + throttled_over_time = [ + {"timestamp": k, "count": v} + for k, v in sorted(hourly_buckets.items()) + ] + + return { + "rate_limits": results, + "throttled_24h": throttled_24h, + "throttled_over_time": throttled_over_time, + } @app.get("/admin/alerts", tags=["Admin"]) @@ -927,6 +1082,147 @@ async def analyze_company( return _convert_result(result) +def _extract_patent_ids(response_text: str) -> set[str]: + """Extract patent IDs from an analysis response text. + + Looks for patterns like US-12345678-B2, US12345678B2, etc. + """ + import re + pattern = r"US[-\s]?\d{7,8}[-\s]?[A-Z]\d?" + return set(re.findall(pattern, response_text or "")) + + +def _compute_analysis_diff(from_rec: dict, to_rec: dict) -> AnalysisDiffResponse: + """Compute a structured diff between two analysis records.""" + from_patents = _extract_patent_ids(from_rec.get("response", "") or "") + to_patents = _extract_patent_ids(to_rec.get("response", "") or "") + + added = sorted(to_patents - from_patents) + removed = sorted(from_patents - to_patents) + + patent_count_delta = len(to_patents) - len(from_patents) + + changed_fields: dict[str, dict] = {} + if from_rec.get("model") != to_rec.get("model"): + changed_fields["model"] = { + "from": from_rec.get("model"), + "to": to_rec.get("model"), + } + if from_rec.get("analysis_type") != to_rec.get("analysis_type"): + changed_fields["analysis_type"] = { + "from": from_rec.get("analysis_type"), + "to": to_rec.get("analysis_type"), + } + + # Build a human-readable summary + parts: list[str] = [] + if added: + parts.append(f"{len(added)} new patent(s) appeared") + if removed: + parts.append(f"{len(removed)} patent(s) no longer referenced") + if patent_count_delta > 0: + parts.append(f"patent mention count increased by {patent_count_delta}") + elif patent_count_delta < 0: + parts.append(f"patent mention count decreased by {abs(patent_count_delta)}") + if changed_fields: + parts.append(f"field(s) changed: {', '.join(changed_fields.keys())}") + summary = "; ".join(parts) if parts else "No significant differences detected." + + return AnalysisDiffResponse( + company_name=to_rec["company_name"], + from_id=from_rec["id"], + to_id=to_rec["id"], + from_timestamp=from_rec["timestamp"], + to_timestamp=to_rec["timestamp"], + patent_count_delta=patent_count_delta, + added_patents=added, + removed_patents=removed, + changed_fields=changed_fields, + summary=summary, + ) + + +@app.get( + "/analyze/{company_name}/history", + response_model=list[CompanyAnalysisHistoryItem], + tags=["Analysis"], +) +async def list_company_analysis_history( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + limit: int = Query(default=20, ge=1, le=100), + _: UserResponse = Depends(get_current_user), +): + """List previous analysis runs for a company. + + Returns a list of analysis records ordered by timestamp descending, + useful for selecting which runs to compare via the diff endpoint. + + Args: + company_name: Company name to look up + limit: Maximum number of results + + Returns: + List of analysis history items + """ + db = _get_job_db() + rows = db.list_company_analyses(company_name, limit=limit) + return [ + CompanyAnalysisHistoryItem( + id=r["id"], + analysis_type=r.get("analysis_type"), + model=r.get("model"), + timestamp=r["timestamp"], + ) + for r in rows + ] + + +@app.get( + "/analyze/{company_name}/diff", + response_model=AnalysisDiffResponse, + tags=["Analysis"], +) +async def diff_company_analyses( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + from_id: int = Query(..., alias="from", description="Analysis ID of the older run"), + to_id: int = Query(..., alias="to", description="Analysis ID of the newer run"), + _: UserResponse = Depends(get_current_user), +): + """Compare two analysis runs for the same company. + + Returns a structured diff showing added/removed patents, score delta, + and a summary narrative. + + Args: + company_name: Company name (must match both analysis records) + from_id: ID of the older analysis run + to_id: ID of the newer analysis run + + Returns: + AnalysisDiffResponse with added/removed/changed fields + + Raises: + 404: If either analysis ID does not exist or belongs to a different company + """ + db = _get_job_db() + + from_rec = db.get_analysis_by_id(from_id) + if not from_rec or (from_rec["company_name"] or "").lower() != company_name.lower(): + raise HTTPException( + status_code=404, + detail=f"Analysis ID {from_id} not found for company '{company_name}'", + ) + + to_rec = db.get_analysis_by_id(to_id) + if not to_rec or (to_rec["company_name"] or "").lower() != company_name.lower(): + raise HTTPException( + status_code=404, + detail=f"Analysis ID {to_id} not found for company '{company_name}'", + ) + + return _compute_analysis_diff(from_rec, to_rec) + + @app.get( "/analyze/patent/{patent_id}", tags=["Analysis"], diff --git a/SPARC/auth.py b/SPARC/auth.py index 890d286..932ae53 100644 --- a/SPARC/auth.py +++ b/SPARC/auth.py @@ -1,11 +1,13 @@ -"""JWT authentication utilities for SPARC API.""" +"""JWT and API key authentication utilities for SPARC API.""" import os +import secrets from datetime import datetime, timedelta, timezone from typing import Optional +import bcrypt import jwt -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel @@ -32,7 +34,7 @@ def check_jwt_secret() -> None: "Set a secure JWT_SECRET environment variable before running in non-development environments." ) -security = HTTPBearer() +security = HTTPBearer(auto_error=False) class TokenPayload(BaseModel): @@ -178,20 +180,107 @@ def get_db_client() -> DatabaseClient: return _db_client -async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), -) -> UserResponse: - """Get the current authenticated user from JWT token. +def generate_api_key() -> str: + """Generate a random 32-byte hex API key. + + Returns: + 64-character hex string + """ + return secrets.token_hex(32) + + +def hash_api_key(key: str) -> str: + """Hash an API key using bcrypt. Args: - credentials: Bearer token from request + key: Plaintext API key + + Returns: + bcrypt hash string + """ + return bcrypt.hashpw(key.encode(), bcrypt.gensalt()).decode() + + +def verify_api_key(key: str, key_hash: str) -> bool: + """Verify a plaintext API key against its bcrypt hash. + + Args: + key: Plaintext API key + key_hash: Stored bcrypt hash + + Returns: + True if key matches + """ + return bcrypt.checkpw(key.encode(), key_hash.encode()) + + +def _authenticate_via_api_key(api_key: str) -> Optional[UserResponse]: + """Look up a user by raw API key. + + Iterates over all stored key hashes (small table) and returns the + corresponding user when a match is found. + + Args: + api_key: Plaintext API key from X-API-Key header + + Returns: + UserResponse if valid key, None otherwise + """ + db = get_db_client() + key_rows = db.get_all_api_key_hashes() + + for row in key_rows: + if verify_api_key(api_key, row["key_hash"]): + user = db.get_user_by_id(row["user_id"]) + if user: + return UserResponse( + id=user["id"], + email=user["email"], + role=user["role"], + created_at=user["created_at"], + ) + return None + + +async def get_current_user( + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), +) -> UserResponse: + """Get the current authenticated user from JWT token or API key. + + Supports two authentication methods: + 1. Bearer JWT token via Authorization header + 2. API key via X-API-Key header + + Args: + request: The incoming request (used for X-API-Key header) + credentials: Optional Bearer token from request Returns: UserResponse with user details Raises: - HTTPException: If token is invalid or expired + HTTPException: If no valid credentials are provided """ + # Try X-API-Key header first + api_key = request.headers.get("X-API-Key") + if api_key: + user = _authenticate_via_api_key(api_key) + if user: + return user + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + # Fall back to JWT Bearer token + if not credentials: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + token = credentials.credentials payload = decode_token(token) diff --git a/SPARC/database.py b/SPARC/database.py index 0759a66..45e87e6 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -221,6 +221,27 @@ class DatabaseClient: ON alerts(company_name) """) + # Create API keys table for programmatic access + cursor.execute(""" + CREATE TABLE IF NOT EXISTS api_keys ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + key_hash VARCHAR(255) NOT NULL, + label VARCHAR(100), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_api_keys_user_id + ON api_keys(user_id) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_api_keys_key_hash + ON api_keys(key_hash) + """) + self.conn.commit() @staticmethod @@ -977,3 +998,138 @@ class DatabaseClient: (limit,), ) return [dict(row) for row in cursor.fetchall()] + + # Historical Analysis Diff Methods + + def get_analysis_by_id(self, analysis_id: int) -> Optional[Dict]: + """Get a single analysis record by its ID. + + Args: + analysis_id: The primary key of the llm_messages row. + + Returns: + Dict with analysis fields, or None if not found. + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + SELECT id, company_name, analysis_type, model, response, timestamp + FROM llm_messages + WHERE id = %s AND is_cached = FALSE + """, + (analysis_id,), + ) + row = cursor.fetchone() + return dict(row) if row else None + + def list_company_analyses( + self, company_name: str, limit: int = 20 + ) -> List[Dict]: + """List past analysis runs for a given company. + + Returns records ordered by timestamp descending so callers can + identify which previous runs are available for diffing. + + Args: + company_name: Company name (case-insensitive match). + limit: Maximum number of records. + + Returns: + List of analysis dicts. + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + SELECT id, company_name, analysis_type, model, response, timestamp + FROM llm_messages + WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE + ORDER BY timestamp DESC + LIMIT %s + """, + (company_name, limit), + ) + return [dict(row) for row in cursor.fetchall()] + + # API Key Methods + + def create_api_key( + self, + user_id: int, + key_hash: str, + label: Optional[str] = None, + ) -> Dict: + """Store a new API key hash for a user. + + Args: + user_id: The owning user's ID + key_hash: bcrypt hash of the plaintext key + label: Optional human-readable label + + Returns: + Dict with id, user_id, label, created_at + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + INSERT INTO api_keys (user_id, key_hash, label) + VALUES (%s, %s, %s) + RETURNING id, user_id, label, created_at + """, + (user_id, key_hash, label), + ) + row = cursor.fetchone() + conn.commit() + return dict(row) + + def list_api_keys(self, user_id: int) -> List[Dict]: + """List active API key metadata for a user (no secrets). + + Args: + user_id: The user's ID + + Returns: + List of dicts with id, label, created_at + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + "SELECT id, label, created_at FROM api_keys WHERE user_id = %s ORDER BY created_at DESC", + (user_id,), + ) + return [dict(row) for row in cursor.fetchall()] + + def delete_api_key(self, key_id: int, user_id: int) -> bool: + """Revoke an API key by ID (must belong to user). + + Args: + key_id: The API key row ID + user_id: The owning user's ID + + Returns: + True if a key was deleted + """ + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute( + "DELETE FROM api_keys WHERE id = %s AND user_id = %s", + (key_id, user_id), + ) + deleted = cursor.rowcount > 0 + conn.commit() + return deleted + + def get_all_api_key_hashes(self) -> List[Dict]: + """Return all API key hashes with their associated user IDs. + + Used by the auth layer to validate an incoming API key. + + Returns: + List of dicts with key_hash, user_id + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute("SELECT key_hash, user_id FROM api_keys") + return [dict(row) for row in cursor.fetchall()] diff --git a/tests/test_api_keys.py b/tests/test_api_keys.py new file mode 100644 index 0000000..3942f63 --- /dev/null +++ b/tests/test_api_keys.py @@ -0,0 +1,319 @@ +"""Tests for user-level API key generation, listing, revocation, and authentication. + +Covers all acceptance criteria from issue #1673: +1. Users can create API keys (POST /auth/apikeys) +2. Users can list their active key IDs (GET /auth/apikeys) +3. Users can revoke keys (DELETE /auth/apikeys/{key_id}) +4. API requests authenticated with a valid API key work on protected endpoints +5. Revoked keys are immediately rejected +6. Plaintext key is shown only at creation time + +All tests use mocked DB fixtures and require no live database. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.auth import ( + create_access_token, + generate_api_key, + hash_api_key, + verify_api_key, +) + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +def _make_user(): + return { + "id": 1, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + +def _auth_header(user_dict): + """Create an Authorization header with a valid access token.""" + token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"]) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture(autouse=True) +def mock_db(monkeypatch): + """Mock the database client used by auth and api endpoints.""" + db = MagicMock() + db.get_user_count.return_value = 0 + db.get_user_by_id.return_value = None + db.get_user_by_email.return_value = None + db.authenticate_user.return_value = None + db.create_user.return_value = None + db.get_all_users.return_value = [] + db.update_user_role.return_value = None + db.delete_user.return_value = False + db.create_api_key.return_value = None + db.list_api_keys.return_value = [] + db.delete_api_key.return_value = False + db.get_all_api_key_hashes.return_value = [] + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +class TestCreateApiKey: + """POST /auth/apikeys""" + + def test_create_key_returns_plaintext_and_id(self, client, mock_db): + """Creating a key returns the plaintext key and metadata.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.create_api_key.return_value = { + "id": 42, + "user_id": user["id"], + "label": "my-ci-key", + "created_at": datetime(2025, 6, 1, tzinfo=timezone.utc), + } + + response = client.post( + "/auth/apikeys", + json={"label": "my-ci-key"}, + headers=_auth_header(user), + ) + + assert response.status_code == 200 + data = response.json() + assert data["id"] == 42 + assert len(data["key"]) == 64 # 32 bytes hex = 64 chars + assert data["label"] == "my-ci-key" + assert "created_at" in data + + # Verify the hash passed to DB is valid for the returned key + call_args = mock_db.create_api_key.call_args + stored_hash = call_args.kwargs.get("key_hash") or call_args[1].get("key_hash") or call_args[0][1] + assert verify_api_key(data["key"], stored_hash) + + def test_create_key_without_label(self, client, mock_db): + """Creating a key without a label should work.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.create_api_key.return_value = { + "id": 1, + "user_id": user["id"], + "label": None, + "created_at": datetime(2025, 6, 1, tzinfo=timezone.utc), + } + + response = client.post( + "/auth/apikeys", + headers=_auth_header(user), + ) + + assert response.status_code == 200 + assert response.json()["label"] is None + + def test_create_key_requires_auth(self, client): + """Creating a key without auth should fail.""" + response = client.post("/auth/apikeys") + assert response.status_code == 401 + + +class TestListApiKeys: + """GET /auth/apikeys""" + + def test_list_keys_returns_metadata_only(self, client, mock_db): + """Listing keys should return IDs and labels, not secrets.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.list_api_keys.return_value = [ + {"id": 1, "label": "key-1", "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc)}, + {"id": 2, "label": None, "created_at": datetime(2025, 2, 1, tzinfo=timezone.utc)}, + ] + + response = client.get("/auth/apikeys", headers=_auth_header(user)) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["id"] == 1 + assert data[0]["label"] == "key-1" + # Ensure no secret key is exposed + for item in data: + assert "key" not in item + assert "key_hash" not in item + + def test_list_keys_empty(self, client, mock_db): + """User with no keys gets an empty list.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.list_api_keys.return_value = [] + + response = client.get("/auth/apikeys", headers=_auth_header(user)) + + assert response.status_code == 200 + assert response.json() == [] + + +class TestRevokeApiKey: + """DELETE /auth/apikeys/{key_id}""" + + def test_revoke_existing_key(self, client, mock_db): + """Revoking an owned key should succeed.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.delete_api_key.return_value = True + + response = client.delete("/auth/apikeys/42", headers=_auth_header(user)) + + assert response.status_code == 200 + assert "revoked" in response.json()["message"].lower() + mock_db.delete_api_key.assert_called_once_with(42, user["id"]) + + def test_revoke_nonexistent_key_returns_404(self, client, mock_db): + """Revoking a key that doesn't exist (or isn't owned) returns 404.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + mock_db.delete_api_key.return_value = False + + response = client.delete("/auth/apikeys/999", headers=_auth_header(user)) + + assert response.status_code == 404 + + +class TestApiKeyAuthentication: + """Using X-API-Key header on protected endpoints.""" + + def test_valid_api_key_accesses_protected_endpoint(self, client, mock_db): + """A valid API key should authenticate and access /auth/me.""" + user = _make_user() + plaintext = generate_api_key() + hashed = hash_api_key(plaintext) + + mock_db.get_all_api_key_hashes.return_value = [ + {"key_hash": hashed, "user_id": user["id"]}, + ] + mock_db.get_user_by_id.return_value = user + + response = client.get("/auth/me", headers={"X-API-Key": plaintext}) + + assert response.status_code == 200 + data = response.json() + assert data["email"] == user["email"] + assert data["id"] == user["id"] + + def test_invalid_api_key_returns_401(self, client, mock_db): + """An invalid API key should return 401.""" + mock_db.get_all_api_key_hashes.return_value = [] + + response = client.get("/auth/me", headers={"X-API-Key": "bad-key"}) + + assert response.status_code == 401 + assert "invalid api key" in response.json()["detail"].lower() + + def test_revoked_key_returns_401(self, client, mock_db): + """After revocation, using the key should return 401.""" + # Simulate revoked key: no matching hashes in DB + mock_db.get_all_api_key_hashes.return_value = [] + + response = client.get("/auth/me", headers={"X-API-Key": "a" * 64}) + + assert response.status_code == 401 + + def test_api_key_for_deleted_user_returns_401(self, client, mock_db): + """An API key whose user no longer exists should return 401.""" + plaintext = generate_api_key() + hashed = hash_api_key(plaintext) + + mock_db.get_all_api_key_hashes.return_value = [ + {"key_hash": hashed, "user_id": 999}, + ] + mock_db.get_user_by_id.return_value = None # user deleted + + response = client.get("/auth/me", headers={"X-API-Key": plaintext}) + + assert response.status_code == 401 + + def test_no_auth_at_all_returns_401(self, client, mock_db): + """No auth header at all should return 401.""" + response = client.get("/auth/me") + assert response.status_code == 401 + + +class TestApiKeyFullFlow: + """End-to-end flow: create key, use it, revoke it, try again.""" + + def test_create_use_revoke_flow(self, client, mock_db): + """Simulate full lifecycle of an API key.""" + user = _make_user() + mock_db.get_user_by_id.return_value = user + + # Step 1: Create key + mock_db.create_api_key.return_value = { + "id": 10, + "user_id": user["id"], + "label": "test", + "created_at": datetime(2025, 6, 1, tzinfo=timezone.utc), + } + + create_resp = client.post( + "/auth/apikeys", + json={"label": "test"}, + headers=_auth_header(user), + ) + assert create_resp.status_code == 200 + plaintext = create_resp.json()["key"] + + # Capture the hash that was stored + call_args = mock_db.create_api_key.call_args + stored_hash = call_args.kwargs.get("key_hash") or call_args[0][1] + + # Step 2: Use key on protected endpoint + mock_db.get_all_api_key_hashes.return_value = [ + {"key_hash": stored_hash, "user_id": user["id"]}, + ] + + use_resp = client.get("/auth/me", headers={"X-API-Key": plaintext}) + assert use_resp.status_code == 200 + assert use_resp.json()["email"] == user["email"] + + # Step 3: Revoke key + mock_db.delete_api_key.return_value = True + revoke_resp = client.delete("/auth/apikeys/10", headers=_auth_header(user)) + assert revoke_resp.status_code == 200 + + # Step 4: Try using revoked key + mock_db.get_all_api_key_hashes.return_value = [] # key removed from DB + rejected_resp = client.get("/auth/me", headers={"X-API-Key": plaintext}) + assert rejected_resp.status_code == 401 + + +class TestApiKeyHelpers: + """Unit tests for key generation and hashing helpers.""" + + def test_generate_api_key_length(self): + """Generated key should be 64 hex characters (32 bytes).""" + key = generate_api_key() + assert len(key) == 64 + # Should be valid hex + int(key, 16) + + def test_generate_api_key_uniqueness(self): + """Two generated keys should be different.""" + k1 = generate_api_key() + k2 = generate_api_key() + assert k1 != k2 + + def test_hash_and_verify(self): + """hash_api_key and verify_api_key should round-trip correctly.""" + key = generate_api_key() + hashed = hash_api_key(key) + assert verify_api_key(key, hashed) + assert not verify_api_key("wrong-key", hashed)