Add historical analysis diffing between runs for same company #1695

Open
AI-Manager wants to merge 3 commits from feature/historical-analysis-diff into main
4 changed files with 873 additions and 13 deletions
Showing only changes of commit 3d8922366e - Show all commits
+300 -4
View File
@@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios.
from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING:
@@ -29,9 +30,11 @@ from SPARC.auth import (
close_db_client,
create_tokens,
decode_token,
generate_api_key,
get_current_admin,
get_current_user,
get_db_client,
hash_api_key,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -139,6 +142,31 @@ class HealthResponse(BaseModel):
timestamp: datetime
# Historical diff models
class AnalysisDiffResponse(BaseModel):
"""Response model for diffing two analysis runs of the same company."""
company_name: str
from_id: int
to_id: int
from_timestamp: datetime
to_timestamp: datetime
patent_count_delta: int
added_patents: list[str]
removed_patents: list[str]
changed_fields: dict[str, dict]
summary: str
class CompanyAnalysisHistoryItem(BaseModel):
"""A summary item from a company's analysis history."""
id: int
analysis_type: str | None = None
model: str | None = None
timestamp: datetime
# Auth request/response models
class RegisterRequest(BaseModel):
"""User registration request."""
@@ -248,6 +276,9 @@ app.state.limiter = limiter
# In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {}
# Time-series log of rejected requests (capped to last 24 h worth of entries).
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint."""
@@ -262,6 +293,11 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) ->
_rate_limit_stats[key]["total_requests"] += 1
if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1
_rejected_log.append({
"endpoint": endpoint,
"ip": ip,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0}
@@ -378,6 +414,92 @@ async def get_me(current_user: UserResponse = Depends(get_current_user)):
return current_user
# ============== API Key Endpoints ==============
class CreateApiKeyRequest(BaseModel):
"""Request to create a new API key."""
label: str | None = Field(default=None, max_length=100, description="Optional label for the key")
class ApiKeyResponse(BaseModel):
"""Response after creating an API key (includes plaintext key)."""
id: int
key: str # plaintext key, shown only at creation time
label: str | None = None
created_at: datetime
class ApiKeyInfo(BaseModel):
"""API key metadata (no secret)."""
id: int
label: str | None = None
created_at: datetime
@app.post("/auth/apikeys", response_model=ApiKeyResponse, tags=["Auth"])
async def create_api_key_endpoint(
body: CreateApiKeyRequest | None = None,
current_user: UserResponse = Depends(get_current_user),
):
"""Generate a new API key for the authenticated user.
The plaintext key is returned **only once** in the response.
Store it securely; it cannot be retrieved again.
"""
plaintext_key = generate_api_key()
key_hash = hash_api_key(plaintext_key)
db = get_db_client()
label = body.label if body else None
row = db.create_api_key(
user_id=current_user.id,
key_hash=key_hash,
label=label,
)
return ApiKeyResponse(
id=row["id"],
key=plaintext_key,
label=row["label"],
created_at=row["created_at"],
)
@app.get("/auth/apikeys", response_model=list[ApiKeyInfo], tags=["Auth"])
async def list_api_keys_endpoint(
current_user: UserResponse = Depends(get_current_user),
):
"""List active API key IDs and labels for the authenticated user.
Does **not** return the secret keys.
"""
db = get_db_client()
keys = db.list_api_keys(current_user.id)
return [ApiKeyInfo(**k) for k in keys]
@app.delete("/auth/apikeys/{key_id}", tags=["Auth"])
async def revoke_api_key_endpoint(
key_id: int,
current_user: UserResponse = Depends(get_current_user),
):
"""Revoke (delete) an API key by its ID.
The key must belong to the authenticated user.
"""
db = get_db_client()
deleted = db.delete_api_key(key_id, current_user.id)
if not deleted:
raise HTTPException(status_code=404, detail="API key not found")
return {"message": "API key revoked"}
# ============== Admin Endpoints ==============
@@ -507,10 +629,12 @@ async def get_rate_limit_stats(
"""Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics
for all rate-limited endpoints.
for all rate-limited endpoints, including per-IP breakdown and
a time-series of throttled (rejected) requests in the last 24 hours.
Returns:
List of rate limit stats per endpoint with total/rejected counts
Rate limit stats per endpoint, per-IP breakdown, and throttled
request history bucketed by hour.
"""
rate_limits_config = {
"/auth/register": {"limit": "5/minute"},
@@ -520,14 +644,45 @@ async def get_rate_limit_stats(
results = []
for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {})
by_ip_raw = stats.get("by_ip", {})
by_ip = [
{"ip": ip, "total": counts["total"], "rejected": counts["rejected"]}
for ip, counts in by_ip_raw.items()
]
results.append({
"endpoint": endpoint,
"limit": conf["limit"],
"total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0),
"by_ip": by_ip,
})
return {"rate_limits": results}
# Build hourly buckets of throttled requests for the last 24 hours
now = datetime.now(timezone.utc)
cutoff = now - timedelta(hours=24)
hourly_buckets: dict[str, int] = {}
throttled_24h = 0
for entry in _rejected_log:
ts_str = entry["timestamp"]
try:
ts = datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
continue
if ts >= cutoff:
throttled_24h += 1
bucket = ts.strftime("%Y-%m-%dT%H:00:00Z")
hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1
throttled_over_time = [
{"timestamp": k, "count": v}
for k, v in sorted(hourly_buckets.items())
]
return {
"rate_limits": results,
"throttled_24h": throttled_24h,
"throttled_over_time": throttled_over_time,
}
@app.get("/admin/alerts", tags=["Admin"])
@@ -927,6 +1082,147 @@ async def analyze_company(
return _convert_result(result)
def _extract_patent_ids(response_text: str) -> set[str]:
"""Extract patent IDs from an analysis response text.
Looks for patterns like US-12345678-B2, US12345678B2, etc.
"""
import re
pattern = r"US[-\s]?\d{7,8}[-\s]?[A-Z]\d?"
return set(re.findall(pattern, response_text or ""))
def _compute_analysis_diff(from_rec: dict, to_rec: dict) -> AnalysisDiffResponse:
"""Compute a structured diff between two analysis records."""
from_patents = _extract_patent_ids(from_rec.get("response", "") or "")
to_patents = _extract_patent_ids(to_rec.get("response", "") or "")
added = sorted(to_patents - from_patents)
removed = sorted(from_patents - to_patents)
patent_count_delta = len(to_patents) - len(from_patents)
changed_fields: dict[str, dict] = {}
if from_rec.get("model") != to_rec.get("model"):
changed_fields["model"] = {
"from": from_rec.get("model"),
"to": to_rec.get("model"),
}
if from_rec.get("analysis_type") != to_rec.get("analysis_type"):
changed_fields["analysis_type"] = {
"from": from_rec.get("analysis_type"),
"to": to_rec.get("analysis_type"),
}
# Build a human-readable summary
parts: list[str] = []
if added:
parts.append(f"{len(added)} new patent(s) appeared")
if removed:
parts.append(f"{len(removed)} patent(s) no longer referenced")
if patent_count_delta > 0:
parts.append(f"patent mention count increased by {patent_count_delta}")
elif patent_count_delta < 0:
parts.append(f"patent mention count decreased by {abs(patent_count_delta)}")
if changed_fields:
parts.append(f"field(s) changed: {', '.join(changed_fields.keys())}")
summary = "; ".join(parts) if parts else "No significant differences detected."
return AnalysisDiffResponse(
company_name=to_rec["company_name"],
from_id=from_rec["id"],
to_id=to_rec["id"],
from_timestamp=from_rec["timestamp"],
to_timestamp=to_rec["timestamp"],
patent_count_delta=patent_count_delta,
added_patents=added,
removed_patents=removed,
changed_fields=changed_fields,
summary=summary,
)
@app.get(
"/analyze/{company_name}/history",
response_model=list[CompanyAnalysisHistoryItem],
tags=["Analysis"],
)
async def list_company_analysis_history(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
limit: int = Query(default=20, ge=1, le=100),
_: UserResponse = Depends(get_current_user),
):
"""List previous analysis runs for a company.
Returns a list of analysis records ordered by timestamp descending,
useful for selecting which runs to compare via the diff endpoint.
Args:
company_name: Company name to look up
limit: Maximum number of results
Returns:
List of analysis history items
"""
db = _get_job_db()
rows = db.list_company_analyses(company_name, limit=limit)
return [
CompanyAnalysisHistoryItem(
id=r["id"],
analysis_type=r.get("analysis_type"),
model=r.get("model"),
timestamp=r["timestamp"],
)
for r in rows
]
@app.get(
"/analyze/{company_name}/diff",
response_model=AnalysisDiffResponse,
tags=["Analysis"],
)
async def diff_company_analyses(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
from_id: int = Query(..., alias="from", description="Analysis ID of the older run"),
to_id: int = Query(..., alias="to", description="Analysis ID of the newer run"),
_: UserResponse = Depends(get_current_user),
):
"""Compare two analysis runs for the same company.
Returns a structured diff showing added/removed patents, score delta,
and a summary narrative.
Args:
company_name: Company name (must match both analysis records)
from_id: ID of the older analysis run
to_id: ID of the newer analysis run
Returns:
AnalysisDiffResponse with added/removed/changed fields
Raises:
404: If either analysis ID does not exist or belongs to a different company
"""
db = _get_job_db()
from_rec = db.get_analysis_by_id(from_id)
if not from_rec or (from_rec["company_name"] or "").lower() != company_name.lower():
raise HTTPException(
status_code=404,
detail=f"Analysis ID {from_id} not found for company '{company_name}'",
)
to_rec = db.get_analysis_by_id(to_id)
if not to_rec or (to_rec["company_name"] or "").lower() != company_name.lower():
raise HTTPException(
status_code=404,
detail=f"Analysis ID {to_id} not found for company '{company_name}'",
)
return _compute_analysis_diff(from_rec, to_rec)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
+98 -9
View File
@@ -1,11 +1,13 @@
"""JWT authentication utilities for SPARC API."""
"""JWT and API key authentication utilities for SPARC API."""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
@@ -32,7 +34,7 @@ def check_jwt_secret() -> None:
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer()
security = HTTPBearer(auto_error=False)
class TokenPayload(BaseModel):
@@ -178,20 +180,107 @@ def get_db_client() -> DatabaseClient:
return _db_client
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> UserResponse:
"""Get the current authenticated user from JWT token.
def generate_api_key() -> str:
"""Generate a random 32-byte hex API key.
Returns:
64-character hex string
"""
return secrets.token_hex(32)
def hash_api_key(key: str) -> str:
"""Hash an API key using bcrypt.
Args:
credentials: Bearer token from request
key: Plaintext API key
Returns:
bcrypt hash string
"""
return bcrypt.hashpw(key.encode(), bcrypt.gensalt()).decode()
def verify_api_key(key: str, key_hash: str) -> bool:
"""Verify a plaintext API key against its bcrypt hash.
Args:
key: Plaintext API key
key_hash: Stored bcrypt hash
Returns:
True if key matches
"""
return bcrypt.checkpw(key.encode(), key_hash.encode())
def _authenticate_via_api_key(api_key: str) -> Optional[UserResponse]:
"""Look up a user by raw API key.
Iterates over all stored key hashes (small table) and returns the
corresponding user when a match is found.
Args:
api_key: Plaintext API key from X-API-Key header
Returns:
UserResponse if valid key, None otherwise
"""
db = get_db_client()
key_rows = db.get_all_api_key_hashes()
for row in key_rows:
if verify_api_key(api_key, row["key_hash"]):
user = db.get_user_by_id(row["user_id"])
if user:
return UserResponse(
id=user["id"],
email=user["email"],
role=user["role"],
created_at=user["created_at"],
)
return None
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> UserResponse:
"""Get the current authenticated user from JWT token or API key.
Supports two authentication methods:
1. Bearer JWT token via Authorization header
2. API key via X-API-Key header
Args:
request: The incoming request (used for X-API-Key header)
credentials: Optional Bearer token from request
Returns:
UserResponse with user details
Raises:
HTTPException: If token is invalid or expired
HTTPException: If no valid credentials are provided
"""
# Try X-API-Key header first
api_key = request.headers.get("X-API-Key")
if api_key:
user = _authenticate_via_api_key(api_key)
if user:
return user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
# Fall back to JWT Bearer token
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
payload = decode_token(token)
+156
View File
@@ -221,6 +221,27 @@ class DatabaseClient:
ON alerts(company_name)
""")
# Create API keys table for programmatic access
cursor.execute("""
CREATE TABLE IF NOT EXISTS api_keys (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_hash VARCHAR(255) NOT NULL,
label VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
ON api_keys(user_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_api_keys_key_hash
ON api_keys(key_hash)
""")
self.conn.commit()
@staticmethod
@@ -977,3 +998,138 @@ class DatabaseClient:
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
# Historical Analysis Diff Methods
def get_analysis_by_id(self, analysis_id: int) -> Optional[Dict]:
"""Get a single analysis record by its ID.
Args:
analysis_id: The primary key of the llm_messages row.
Returns:
Dict with analysis fields, or None if not found.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE id = %s AND is_cached = FALSE
""",
(analysis_id,),
)
row = cursor.fetchone()
return dict(row) if row else None
def list_company_analyses(
self, company_name: str, limit: int = 20
) -> List[Dict]:
"""List past analysis runs for a given company.
Returns records ordered by timestamp descending so callers can
identify which previous runs are available for diffing.
Args:
company_name: Company name (case-insensitive match).
limit: Maximum number of records.
Returns:
List of analysis dicts.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
LIMIT %s
""",
(company_name, limit),
)
return [dict(row) for row in cursor.fetchall()]
# API Key Methods
def create_api_key(
self,
user_id: int,
key_hash: str,
label: Optional[str] = None,
) -> Dict:
"""Store a new API key hash for a user.
Args:
user_id: The owning user's ID
key_hash: bcrypt hash of the plaintext key
label: Optional human-readable label
Returns:
Dict with id, user_id, label, created_at
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO api_keys (user_id, key_hash, label)
VALUES (%s, %s, %s)
RETURNING id, user_id, label, created_at
""",
(user_id, key_hash, label),
)
row = cursor.fetchone()
conn.commit()
return dict(row)
def list_api_keys(self, user_id: int) -> List[Dict]:
"""List active API key metadata for a user (no secrets).
Args:
user_id: The user's ID
Returns:
List of dicts with id, label, created_at
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, label, created_at FROM api_keys WHERE user_id = %s ORDER BY created_at DESC",
(user_id,),
)
return [dict(row) for row in cursor.fetchall()]
def delete_api_key(self, key_id: int, user_id: int) -> bool:
"""Revoke an API key by ID (must belong to user).
Args:
key_id: The API key row ID
user_id: The owning user's ID
Returns:
True if a key was deleted
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM api_keys WHERE id = %s AND user_id = %s",
(key_id, user_id),
)
deleted = cursor.rowcount > 0
conn.commit()
return deleted
def get_all_api_key_hashes(self) -> List[Dict]:
"""Return all API key hashes with their associated user IDs.
Used by the auth layer to validate an incoming API key.
Returns:
List of dicts with key_hash, user_id
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT key_hash, user_id FROM api_keys")
return [dict(row) for row in cursor.fetchall()]
+319
View File
@@ -0,0 +1,319 @@
"""Tests for user-level API key generation, listing, revocation, and authentication.
Covers all acceptance criteria from issue #1673:
1. Users can create API keys (POST /auth/apikeys)
2. Users can list their active key IDs (GET /auth/apikeys)
3. Users can revoke keys (DELETE /auth/apikeys/{key_id})
4. API requests authenticated with a valid API key work on protected endpoints
5. Revoked keys are immediately rejected
6. Plaintext key is shown only at creation time
All tests use mocked DB fixtures and require no live database.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import (
create_access_token,
generate_api_key,
hash_api_key,
verify_api_key,
)
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user():
return {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _auth_header(user_dict):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db(monkeypatch):
"""Mock the database client used by auth and api endpoints."""
db = MagicMock()
db.get_user_count.return_value = 0
db.get_user_by_id.return_value = None
db.get_user_by_email.return_value = None
db.authenticate_user.return_value = None
db.create_user.return_value = None
db.get_all_users.return_value = []
db.update_user_role.return_value = None
db.delete_user.return_value = False
db.create_api_key.return_value = None
db.list_api_keys.return_value = []
db.delete_api_key.return_value = False
db.get_all_api_key_hashes.return_value = []
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
class TestCreateApiKey:
"""POST /auth/apikeys"""
def test_create_key_returns_plaintext_and_id(self, client, mock_db):
"""Creating a key returns the plaintext key and metadata."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.create_api_key.return_value = {
"id": 42,
"user_id": user["id"],
"label": "my-ci-key",
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/apikeys",
json={"label": "my-ci-key"},
headers=_auth_header(user),
)
assert response.status_code == 200
data = response.json()
assert data["id"] == 42
assert len(data["key"]) == 64 # 32 bytes hex = 64 chars
assert data["label"] == "my-ci-key"
assert "created_at" in data
# Verify the hash passed to DB is valid for the returned key
call_args = mock_db.create_api_key.call_args
stored_hash = call_args.kwargs.get("key_hash") or call_args[1].get("key_hash") or call_args[0][1]
assert verify_api_key(data["key"], stored_hash)
def test_create_key_without_label(self, client, mock_db):
"""Creating a key without a label should work."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.create_api_key.return_value = {
"id": 1,
"user_id": user["id"],
"label": None,
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/apikeys",
headers=_auth_header(user),
)
assert response.status_code == 200
assert response.json()["label"] is None
def test_create_key_requires_auth(self, client):
"""Creating a key without auth should fail."""
response = client.post("/auth/apikeys")
assert response.status_code == 401
class TestListApiKeys:
"""GET /auth/apikeys"""
def test_list_keys_returns_metadata_only(self, client, mock_db):
"""Listing keys should return IDs and labels, not secrets."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.list_api_keys.return_value = [
{"id": 1, "label": "key-1", "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc)},
{"id": 2, "label": None, "created_at": datetime(2025, 2, 1, tzinfo=timezone.utc)},
]
response = client.get("/auth/apikeys", headers=_auth_header(user))
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["id"] == 1
assert data[0]["label"] == "key-1"
# Ensure no secret key is exposed
for item in data:
assert "key" not in item
assert "key_hash" not in item
def test_list_keys_empty(self, client, mock_db):
"""User with no keys gets an empty list."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.list_api_keys.return_value = []
response = client.get("/auth/apikeys", headers=_auth_header(user))
assert response.status_code == 200
assert response.json() == []
class TestRevokeApiKey:
"""DELETE /auth/apikeys/{key_id}"""
def test_revoke_existing_key(self, client, mock_db):
"""Revoking an owned key should succeed."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.delete_api_key.return_value = True
response = client.delete("/auth/apikeys/42", headers=_auth_header(user))
assert response.status_code == 200
assert "revoked" in response.json()["message"].lower()
mock_db.delete_api_key.assert_called_once_with(42, user["id"])
def test_revoke_nonexistent_key_returns_404(self, client, mock_db):
"""Revoking a key that doesn't exist (or isn't owned) returns 404."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.delete_api_key.return_value = False
response = client.delete("/auth/apikeys/999", headers=_auth_header(user))
assert response.status_code == 404
class TestApiKeyAuthentication:
"""Using X-API-Key header on protected endpoints."""
def test_valid_api_key_accesses_protected_endpoint(self, client, mock_db):
"""A valid API key should authenticate and access /auth/me."""
user = _make_user()
plaintext = generate_api_key()
hashed = hash_api_key(plaintext)
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": hashed, "user_id": user["id"]},
]
mock_db.get_user_by_id.return_value = user
response = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert response.status_code == 200
data = response.json()
assert data["email"] == user["email"]
assert data["id"] == user["id"]
def test_invalid_api_key_returns_401(self, client, mock_db):
"""An invalid API key should return 401."""
mock_db.get_all_api_key_hashes.return_value = []
response = client.get("/auth/me", headers={"X-API-Key": "bad-key"})
assert response.status_code == 401
assert "invalid api key" in response.json()["detail"].lower()
def test_revoked_key_returns_401(self, client, mock_db):
"""After revocation, using the key should return 401."""
# Simulate revoked key: no matching hashes in DB
mock_db.get_all_api_key_hashes.return_value = []
response = client.get("/auth/me", headers={"X-API-Key": "a" * 64})
assert response.status_code == 401
def test_api_key_for_deleted_user_returns_401(self, client, mock_db):
"""An API key whose user no longer exists should return 401."""
plaintext = generate_api_key()
hashed = hash_api_key(plaintext)
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": hashed, "user_id": 999},
]
mock_db.get_user_by_id.return_value = None # user deleted
response = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert response.status_code == 401
def test_no_auth_at_all_returns_401(self, client, mock_db):
"""No auth header at all should return 401."""
response = client.get("/auth/me")
assert response.status_code == 401
class TestApiKeyFullFlow:
"""End-to-end flow: create key, use it, revoke it, try again."""
def test_create_use_revoke_flow(self, client, mock_db):
"""Simulate full lifecycle of an API key."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
# Step 1: Create key
mock_db.create_api_key.return_value = {
"id": 10,
"user_id": user["id"],
"label": "test",
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
create_resp = client.post(
"/auth/apikeys",
json={"label": "test"},
headers=_auth_header(user),
)
assert create_resp.status_code == 200
plaintext = create_resp.json()["key"]
# Capture the hash that was stored
call_args = mock_db.create_api_key.call_args
stored_hash = call_args.kwargs.get("key_hash") or call_args[0][1]
# Step 2: Use key on protected endpoint
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": stored_hash, "user_id": user["id"]},
]
use_resp = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert use_resp.status_code == 200
assert use_resp.json()["email"] == user["email"]
# Step 3: Revoke key
mock_db.delete_api_key.return_value = True
revoke_resp = client.delete("/auth/apikeys/10", headers=_auth_header(user))
assert revoke_resp.status_code == 200
# Step 4: Try using revoked key
mock_db.get_all_api_key_hashes.return_value = [] # key removed from DB
rejected_resp = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert rejected_resp.status_code == 401
class TestApiKeyHelpers:
"""Unit tests for key generation and hashing helpers."""
def test_generate_api_key_length(self):
"""Generated key should be 64 hex characters (32 bytes)."""
key = generate_api_key()
assert len(key) == 64
# Should be valid hex
int(key, 16)
def test_generate_api_key_uniqueness(self):
"""Two generated keys should be different."""
k1 = generate_api_key()
k2 = generate_api_key()
assert k1 != k2
def test_hash_and_verify(self):
"""hash_api_key and verify_api_key should round-trip correctly."""
key = generate_api_key()
hashed = hash_api_key(key)
assert verify_api_key(key, hashed)
assert not verify_api_key("wrong-key", hashed)