Compare commits

..

30 Commits

Author SHA1 Message Date
agent-company 144d0fdf6a Add historical analysis diffing for same-company runs
- Add GET /analyze/{company_name}/diff endpoint with from/to query params
- Add GET /analyze/{company_name}/history endpoint for run selection
- Add database methods get_analysis_by_id and list_company_analyses
- Add frontend HistoryDiff page with run selector and diff visualization
- Add Compare with previous button on Analysis results page
- Add navigation link in Layout sidebar
- Add 11 tests covering helpers, happy-path, and 404 scenarios

Closes leeworks-agents/SPARC#1671

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:43:13 +00:00
agent-company e9ad97d1e8 Add rate limiting dashboard to admin panel
- Enhance GET /admin/rate-limits to return per-IP breakdown, 24h throttled
  count, and hourly time-series of rejected requests
- Add AdminRateLimits React page with auto-refresh (configurable interval),
  summary cards, throttled-over-time bar chart, endpoint table, and per-IP
  breakdown table
- Add TypeScript types (RateLimitStatsResponse, etc.) and adminApi.getRateLimits()
- Wire up /admin/rate-limits route and nav link (admin-only)
- Expand unit tests: auth, empty state, per-IP, throttled_24h, time-series,
  response shape contract (10 tests total)

Closes leeworks-agents/SPARC#1686

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:30:23 +00:00
agent-company 3d8922366e Add user-level API key generation for programmatic access
- Add api_keys table (id, user_id, key_hash, label, created_at) to schema
- Add POST /auth/apikeys to generate 32-byte hex API keys (bcrypt-hashed)
- Add GET /auth/apikeys to list active key metadata (no secrets)
- Add DELETE /auth/apikeys/{key_id} to revoke keys
- Extend get_current_user to accept either JWT Bearer or X-API-Key header
- Plaintext key returned only at creation time
- 16 new tests covering creation, listing, revocation, auth, and full flow

Closes leeworks-agents/SPARC#1673

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:18:34 +00:00
AI-Manager 313800215c Merge pull request 'Add rate limit stats to admin panel' (#1682) from feature/1675-rate-limit-admin into main
Merge PR #1682
2026-05-19 00:12:56 +00:00
AI-Manager 222f29deb1 Merge pull request 'Add cursor-based pagination to /analyze/batch and /jobs' (#1681) from feature/1669-cursor-pagination into main
Merge PR #1681
2026-05-19 00:12:48 +00:00
AI-Manager e6d95bbf57 Merge pull request 'Add stricter input validation for company names' (#1680) from feature/1670-company-name-validation into main
Merge PR #1680
2026-05-19 00:12:42 +00:00
AI-Manager 68484ef4b1 Merge pull request 'Update ROADMAP.md: mark completed P1 and P2 items as done' (#1679) from feature/1678-update-roadmap into main
Merge PR #1679
2026-05-19 00:12:34 +00:00
agent-company a0cb9a5773 Add rate limit status and usage statistics to admin panel
Add GET /admin/rate-limits endpoint (admin-only) that returns current
rate limit configuration and request statistics for all rate-limited
endpoints (/auth/register and /auth/login). Tracks total requests and
rejection counts via in-memory counters.

Includes tests for admin access, non-admin rejection, empty state,
request tracking, and configuration display.

Closes leeworks-agents/SPARC#1675

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:53:01 +00:00
agent-company 857b3444df Add cursor-based pagination to GET /analyze/batch and update /jobs defaults
Add a new GET /analyze/batch endpoint that returns stored analysis results
with cursor-based pagination (default limit 50, max 200). Also update the
existing /jobs endpoint defaults from limit=10/max=100 to limit=50/max=200
for consistency.

The database layer gains a list_analyses() method with cursor support using
(timestamp, id) ordering, matching the existing list_jobs() pattern.

Includes tests for pagination behavior, boundary limits, cursor forwarding,
company name filtering, and empty result sets.

Closes leeworks-agents/SPARC#1669

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:49:22 +00:00
agent-company a95129904e Add stricter input validation for company names on analysis endpoints
Add a CompanyName validated type enforcing 2-100 character length and
allowing only alphanumeric characters, spaces, hyphens, ampersands, and
periods. Applied to all endpoints accepting company names: /analyze,
/analyze/patent, /analyze/batch, /admin/tracked, and /export.

Includes unit tests covering too-short, too-long, special character,
leading-character, and valid edge cases for both single and batch
endpoints.

Closes leeworks-agents/SPARC#1670

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:38:44 +00:00
agent-company 7c6eed8d72 Update ROADMAP.md to mark completed P1 and P2 items as done
Move seven completed items from the P1 and P2 sections into the
Completed section: in-memory jobs persistence, export endpoint tests,
tracked company admin tests, webhook integration tests, S3 storage
tests, auto-download path tests, and scheduler DatabaseClient refactor.

The P2 section now only lists the two genuinely open items: cursor-based
pagination (Issue #1669) and request validation (Issue #1670).

Closes leeworks-agents/SPARC#1678

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:29:14 +00:00
AI-Manager 4c411e1e0b Merge pull request 'Add tests for tracked company admin endpoints and scheduler' (#1667) from feature/1656-tracked-company-admin-tests into main
Merge: Add tests for tracked company admin endpoints and scheduler integration

Closes #1656
2026-04-20 23:05:57 +00:00
agent-company 6165d66760 Fix scheduler tests to use get_db_client after scheduler refactor
The scheduler was refactored (PR #1665) to use the pooled
get_db_client() from SPARC.auth instead of creating its own
DatabaseClient. Update test mocks accordingly and remove the
db.close() assertion since the pooled client is no longer closed
by the scheduler.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 23:05:42 +00:00
agent-company e610dea9a9 Merge remote-tracking branch 'origin/main' into feature/1656-tracked-company-admin-tests 2026-04-20 23:04:59 +00:00
AI-Manager b5f10d2032 Merge pull request 'Add API tests for export endpoints (CSV and PDF)' (#1668) from feature/1655-export-endpoint-tests into main
Merge: Add API tests for export endpoints (CSV and PDF)

Closes #1655
2026-04-20 23:04:23 +00:00
AI-Manager b5d8b0b344 Merge pull request 'Add webhook integration tests for retry logic and payloads' (#1666) from feature/1657-webhook-integration-tests into main
Merge: Add webhook integration tests for retry logic and payloads

Closes #1657
2026-04-20 23:04:19 +00:00
AI-Manager 1170356b2b Merge pull request 'Add S3/MinIO storage backend tests for storage.py' (#1663) from feature/1660-s3-storage-tests into main
Merge: Add S3/MinIO storage backend tests for storage.py

Closes #1660
2026-04-20 23:04:05 +00:00
AI-Manager 84341b3ec4 Merge pull request 'Add test coverage for analyze_single_patent auto-download path' (#1662) from feature/1661-analyze-single-patent-tests into main
Merge: Add test coverage for analyze_single_patent auto-download path

Closes #1661
2026-04-20 23:04:00 +00:00
AI-Manager 0639fb3649 Merge pull request 'Update ROADMAP.md to reflect completed work and add next-horizon items' (#1664) from feature/1659-update-roadmap into main
Merge: Update ROADMAP.md to reflect completed work and add next-horizon items

Closes #1659
2026-04-20 23:03:56 +00:00
AI-Manager b032bf0c90 Merge pull request 'Refactor scheduler.py to use pooled DatabaseClient' (#1665) from feature/1658-scheduler-pooled-db into main
Merge: Refactor scheduler.py to use pooled DatabaseClient

Closes #1658
2026-04-20 23:03:43 +00:00
agent-company a2f81b0396 Add test coverage for analyze_single_patent auto-download path
7 test cases covering:
- PDF on disk analyzed directly (no download)
- Auto-download from cached metadata link when PDF missing
- FileNotFoundError when no cached link available
- Cached patent without pdf_link raises FileNotFoundError
- Analysis pipeline failure returns error string gracefully
- Model override parameter forwarded to LLM
- FileNotFoundError during parsing re-raised (not swallowed)

Closes leeworks-agents/SPARC#1661

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:21:53 +00:00
agent-company 63ca18e9bf Add S3/MinIO storage backend tests for storage.py
21 test cases covering:
- S3StorageBackend: read, write, exists, path_for with mocked boto3
- Error handling: NoSuchKey exception, generic 404, non-404 re-raise
- Bucket auto-creation on init and graceful handling of creation failure
- Constructor credential/endpoint passthrough
- LocalStorageBackend: round-trip read/write, missing file, empty file
- get_storage_backend() factory: local/s3 selection, case-insensitivity

Closes leeworks-agents/SPARC#1660

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:20:06 +00:00
agent-company 4cb1a6ed21 Update ROADMAP.md to reflect completed work and add next-horizon items
Move all completed items (security hardening, structured logging, dark mode,
export, webhooks, scheduled analysis, multi-model, trend charts, CI, etc.)
into a new Completed section. Reorganize remaining P1/P2/P3 items to reflect
current priorities. Add new next-horizon items: historical diffing, patent
classification tagging, user API keys, batch export, and multi-tenant support.

Closes leeworks-agents/SPARC#1659

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:18:22 +00:00
agent-company 417b7ab31e Refactor scheduler.py to use the application-level pooled DatabaseClient
Replace the per-invocation DatabaseClient creation in
run_scheduled_analysis() with the shared pooled client from
SPARC.auth.get_db_client(). This avoids creating a new database
connection on every scheduler tick, which could exhaust the connection
pool under load.

Key changes:
- Import get_db_client from SPARC.auth instead of DatabaseClient
- Remove manual connect/initialize_schema/close calls
- Remove unused SPARC.config import

Closes leeworks-agents/SPARC#1658

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:16:54 +00:00
agent-company 2eabb1d704 Add webhook integration tests covering retry logic and Slack/Discord payloads
22 test cases covering:
- Slack/Discord URL detection
- Generic vs Slack payload formatting
- Exponential backoff retry logic with network/timeout error handling
- Multi-URL dispatch with format auto-detection
- notify_job_completed() and notify_alert() helpers

Closes leeworks-agents/SPARC#1657

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:15:34 +00:00
agent-company fc942b2aa4 Add tests for tracked company admin endpoints and scheduler integration
20 test cases covering:
- GET/POST/DELETE /admin/tracked endpoints with admin auth enforcement
- GET /admin/alerts with limit parameter and auth
- scheduler.run_scheduled_analysis() for multi-company analysis, alert
  triggering on significant patent count changes, graceful failure handling

Closes leeworks-agents/SPARC#1656

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:14:29 +00:00
agent-company 44a162056d Add API tests for export endpoints (CSV and PDF)
Covers GET /export/{company_name} and /export/{company_name}/pdf with
13 test cases: successful export, 404 on missing data, auth enforcement,
filename sanitization, XML-special character handling in PDF, and
multi-row output validation.

Closes leeworks-agents/SPARC#1655

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 19:11:42 +00:00
AI-Manager a07a0c7fbe Merge pull request 'Fix remaining dark mode issue in Analysis page prose block' (#1628) from feature/1605-dark-mode into main
Fix remaining dark mode issue in Analysis page prose block (#1628)
2026-04-20 06:41:59 +00:00
AI-Manager 43fd2c9575 Merge pull request 'Expand JWT auth integration tests to 33 cases' (#1627) from feature/1624-jwt-auth-tests into main
Expand JWT auth integration tests to 33 cases (#1627)
2026-04-20 06:41:47 +00:00
agent-company 2f2b6382fa Expand JWT auth integration tests from 17 to 33 cases
Add comprehensive edge-case coverage for issue #1624:

- Admin delete user endpoint (5 tests): successful delete, self-delete
  prevention, nonexistent user 404, non-admin 403, missing token rejection
- Admin role change gaps (2 tests): nonexistent user 404, non-admin 403
- Input validation (3 tests): invalid email 422, short password 422,
  missing fields 422 for both register and login
- Token edge cases (4 tests): malformed token, wrong-secret token,
  deleted user token, deleted user refresh
- Token claim verification (1 test): login tokens contain correct claims

All tests use mocked DB fixtures and require no live database.

Closes leeworks-agents/SPARC#1624

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-20 06:05:54 +00:00
22 changed files with 4094 additions and 124 deletions
+120 -85
View File
@@ -7,86 +7,124 @@ Semiconductor Patent & Analytics Report Core -- development priorities.
SPARC is a patent analysis platform with a working end-to-end pipeline:
Python/FastAPI backend, React/TypeScript frontend, PostgreSQL for persistence
and caching, Docker Compose for local development, and Gitea Actions CI/CD for
image builds. Core features (patent retrieval via SerpAPI, PDF parsing, LLM
analysis via OpenRouter/Claude, batch processing, JWT authentication, analytics
dashboard) are all implemented and functional.
image builds and testing. Core features include patent retrieval via SerpAPI,
PDF parsing, LLM analysis via OpenRouter (multi-model: Claude, GPT-4o, Gemini,
Llama), batch processing, JWT authentication, analytics dashboard with patent
trend charts, scheduled recurring analysis with alerting, webhook notifications
(Slack/Discord), CSV and PDF export, S3/MinIO storage, side-by-side company
comparison, and dark mode.
---
## Completed
Items that have been implemented and merged into main.
### Security hardening
- ~~Rotate default JWT secret.~~ Startup check refuses to start with the
default secret in non-development environments.
- ~~CORS allow-origins are hardcoded.~~ Allowed origins are now configurable
via environment variable.
- ~~Database credentials in docker-compose.yml.~~ Compose references `.env`
for sensitive values.
### Error handling and resilience
- ~~`get_db_client()` creates a new `DatabaseClient` on every call.~~ Refactored
to a shared pooled singleton initialized at startup.
- ~~No rate limiting on auth endpoints.~~ Rate limiting middleware added to
`/auth/login` and `/auth/register`.
### Test coverage
- ~~API tests bypass authentication.~~ JWT auth integration tests added (33
cases covering registration, login, protected routes, token refresh, and
admin-only endpoints).
- ~~No test stage in CI.~~ Gitea Actions workflow now runs `pytest` and gates
the build.
- ~~No linting or type checking in CI.~~ `ruff` (Python) and `tsc --noEmit`
(TypeScript) added to CI pipeline.
### Backend
- ~~Add structured logging.~~ Python `logging` module used throughout.
- ~~Make LLM model configurable.~~ `MODEL` environment variable accepted;
multi-model support with per-analysis selection (GPT-4o, Gemini, Claude,
Llama).
- ~~SERP cache TTL hardcoded.~~ `SERP_CACHE_TTL_HOURS` exposed as env var.
- ~~Patent PDF storage.~~ S3/MinIO object storage backend added alongside
local filesystem. Volume mount requirement documented.
- ~~`analyze_single_patent` assumes local file.~~ Auto-download from cached
metadata link integrated.
- ~~`Patent.patent_id` typed as `int`.~~ Fixed to `str`.
### Frontend
- ~~No loading/error states.~~ Skeleton loaders and error states added to
Batch and Analytics pages.
- ~~No dark mode.~~ Full dark mode support with theme-aware chart colors.
- ~~Missing lockfile.~~ `package-lock.json` committed.
### Features (formerly P3)
- ~~Export analysis reports.~~ CSV and PDF export endpoints implemented.
- ~~Comparison view.~~ Side-by-side company patent portfolio comparison added.
- ~~Scheduled/recurring analysis.~~ APScheduler-based periodic re-analysis
with configurable interval and change-threshold alerting.
- ~~Webhook/notification support.~~ Slack, Discord, and generic HTTP POST
webhooks with retry logic.
- ~~Multi-model support.~~ Model picker in Analysis and Batch pages; backend
allow-list validation.
- ~~Patent trend charts.~~ Filing frequency and category distribution
visualizations added to Analytics page.
- ~~OpenAPI client generation.~~ TypeScript API client auto-generated from
FastAPI spec with CI freshness check.
### Resilience
- ~~`_jobs` dict is in-memory only.~~ Database-backed job persistence
implemented using `db.list_jobs()` and `mark_stale_jobs_failed()`. The
in-memory `_jobs` dict has been removed.
### Test coverage (P1/P2)
- ~~Export endpoint tests.~~ Tests added for CSV and PDF export endpoints.
- ~~Tracked company admin endpoint tests.~~ Tests added for `/admin/tracked`
CRUD endpoints and scheduler integration.
- ~~Webhook integration tests.~~ Tests added for retry logic, Slack/Discord
payload format, and multi-URL dispatch.
- ~~S3/MinIO storage backend tests.~~ Unit tests added for the S3 backend
(read, write, exists, delete, error handling).
- ~~`analyze_single_patent` auto-download path tests.~~ Tests added for the
auto-download fallback (cache lookup, PDF download, FileNotFoundError).
### Code quality
- ~~Scheduler creates its own DatabaseClient.~~ Refactored to use the
application-level pooled `get_db_client()`.
---
## P1 -- High Priority
These items address correctness, security, and reliability gaps that should be
resolved before broader production use.
### Security hardening
- **Rotate default JWT secret.** `auth.py` ships a fallback
`sparc-secret-key-change-in-production` that will be used if `JWT_SECRET` is
unset. Add a startup check that refuses to start with the default secret in
non-development environments.
- **CORS allow-origins are hardcoded.** `api.py` only permits
`localhost:3000` and `localhost:5173`. Make the allowed origins configurable
via environment variable so the dashboard works when deployed behind a real
domain.
- **Database credentials in docker-compose.yml.** The compose file embeds
`postgres:postgres` in plain text. Reference a `.env` file or Docker secrets
instead.
### Error handling and resilience
- **`get_db_client()` in `auth.py` creates a new `DatabaseClient` on every
call.** This bypasses the connection pool and can exhaust database
connections under load. Refactor to share a single pooled client.
- **`_jobs` dict is in-memory only.** Job state is lost on API restart. Persist
job status in PostgreSQL or Redis so async batch results survive restarts.
- **No rate limiting on auth endpoints.** `/auth/login` and `/auth/register`
are unprotected against brute-force or abuse. Add rate limiting middleware.
### Test coverage for auth and admin
- The existing API tests (`tests/test_api.py`) bypass authentication entirely.
Add tests that exercise the JWT flow: registration, login, protected-route
access, token refresh, and admin-only endpoints.
No outstanding P1 items. All previously listed items have been completed and
moved to the Completed section above.
---
## P2 -- Medium Priority
Improvements to usability, performance, and developer experience.
Improvements to the API surface.
### Backend
### API improvements
- **Add structured logging.** Replace `print()` calls throughout `analyzer.py`,
`serp_api.py`, and `llm.py` with Python `logging` so log levels and
formatting are consistent.
- **Make LLM model configurable.** `llm.py` hardcodes
`anthropic/claude-3.5-sonnet`. Accept a `MODEL` environment variable to allow
switching models without code changes.
- **SERP cache TTL is hardcoded to 24 hours.** Expose `SERP_CACHE_TTL_HOURS`
as an environment variable in `config.py`.
- **Patent PDF storage.** PDFs are saved to a local `patents/` directory. For
containerized deployments, consider object storage (S3/MinIO) or at minimum
document the volume mount requirement more prominently.
- **`analyze_single_patent` assumes local file path.** The method constructs
`patents/{patent_id}.pdf` and reads from disk, but does not download the PDF
first. Either integrate the download step or document the prerequisite.
- **`Patent.patent_id` typed as `int` in `types.py` but used as `str`
everywhere.** Fix the type annotation to `str`.
### Frontend
- **No loading/error states on several pages.** The Batch and Analytics pages
would benefit from skeleton loaders and user-friendly error messages.
- **No dark mode.** Tailwind is configured but no dark variant is applied.
- **Missing `package-lock.json` or `pnpm-lock.yaml`.** The frontend has no
lockfile committed, leading to non-reproducible builds.
### CI/CD
- **No test stage in the Gitea Actions workflow.** `build.yaml` builds and
pushes images but never runs `pytest`. Add a test job that gates the build.
- **No linting or type checking.** Add `ruff` (Python) and `tsc --noEmit`
(TypeScript) to CI.
- **API pagination.** The `/analyze/batch` endpoint needs cursor-based
pagination for large result sets. The `/jobs` endpoint already has cursor
pagination. *(Issue #1669)*
- **Request validation improvements.** Add stricter input validation for
company names (disallow special characters, enforce length limits).
*(Issue #1670)*
---
@@ -94,23 +132,20 @@ Improvements to usability, performance, and developer experience.
Lower-urgency enhancements and future features.
- **Export analysis reports.** Allow users to download analysis results as PDF
or CSV from the dashboard.
- **Comparison view.** Side-by-side comparison of two companies' patent
portfolios.
- **Scheduled/recurring analysis.** Periodically re-analyze tracked companies
and alert on significant changes.
- **Webhook/notification support.** Send alerts (Slack, Discord, email) when
batch jobs complete or when a company's innovation score changes
significantly.
- **Multi-model support.** Let users choose between LLM providers per analysis
(e.g., GPT-4o, Gemini, Claude) and compare outputs.
- **Patent trend charts.** Visualize patent filing frequency and technology
category distribution over time in the Analytics page.
- **API pagination.** The `/analyze/batch` and `/jobs` endpoints could benefit
from cursor-based pagination for large result sets.
- **OpenAPI client generation.** Auto-generate the TypeScript API client from
the FastAPI OpenAPI spec to keep frontend types in sync.
- **Historical analysis diffing.** Show what changed between two analysis runs
for the same company, highlighting new patents and score shifts.
- **Patent classification tagging.** Automatically tag patents by technology
domain (AI, semiconductors, materials science) using LLM classification.
- **User-level API keys.** Allow users to generate personal API keys for
programmatic access without JWT token refresh.
- **Batch export.** Export analysis results for multiple companies at once as
a ZIP archive.
- **Rate limiting dashboard.** Surface rate limit status and usage statistics
in the admin panel.
- **Async webhook delivery.** Move webhook delivery to a background task queue
(e.g., Celery, arq) to avoid blocking the scheduler.
- **Multi-tenant support.** Scope analysis results and tracked companies per
user or organization.
---
+446 -11
View File
@@ -5,17 +5,18 @@ Provides REST API endpoints for analyzing company patent portfolios.
from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING:
from SPARC.database import DatabaseClient
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, EmailStr, Field
from pydantic import BaseModel, EmailStr, Field, StringConstraints
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
@@ -29,13 +30,25 @@ from SPARC.auth import (
close_db_client,
create_tokens,
decode_token,
generate_api_key,
get_current_admin,
get_current_user,
get_db_client,
hash_api_key,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
CompanyName = Annotated[
str,
StringConstraints(
min_length=2,
max_length=100,
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
),
]
# Pydantic models for API
class CompanyAnalysisResponse(BaseModel):
@@ -72,7 +85,7 @@ class CompanyAnalysisRequest(BaseModel):
class BatchAnalysisRequest(BaseModel):
"""Request model for batch company analysis."""
companies: list[str] = Field(
companies: list[CompanyName] = Field(
..., min_length=1, max_length=20, description="List of company names to analyze"
)
max_workers: int = Field(
@@ -96,6 +109,24 @@ class JobStatus(BaseModel):
error: str | None = None
class AnalysisRecord(BaseModel):
"""A single stored analysis result."""
id: int
company_name: str | None = None
analysis_type: str | None = None
model: str | None = None
response: str | None = None
timestamp: datetime | None = None
class PaginatedAnalysisResponse(BaseModel):
"""Paginated response for analysis result listings."""
items: list[AnalysisRecord]
next_cursor: str | None = None
class PaginatedJobsResponse(BaseModel):
"""Paginated response for job listings."""
@@ -111,6 +142,31 @@ class HealthResponse(BaseModel):
timestamp: datetime
# Historical diff models
class AnalysisDiffResponse(BaseModel):
"""Response model for diffing two analysis runs of the same company."""
company_name: str
from_id: int
to_id: int
from_timestamp: datetime
to_timestamp: datetime
patent_count_delta: int
added_patents: list[str]
removed_patents: list[str]
changed_fields: dict[str, dict]
summary: str
class CompanyAnalysisHistoryItem(BaseModel):
"""A summary item from a company's analysis history."""
id: int
analysis_type: str | None = None
model: str | None = None
timestamp: datetime
# Auth request/response models
class RegisterRequest(BaseModel):
"""User registration request."""
@@ -217,10 +273,45 @@ app = FastAPI(
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {}
# Time-series log of rejected requests (capped to last 24 h worth of entries).
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint."""
key = endpoint
if key not in _rate_limit_stats:
_rate_limit_stats[key] = {
"endpoint": endpoint,
"total_requests": 0,
"rejected_requests": 0,
"by_ip": {},
}
_rate_limit_stats[key]["total_requests"] += 1
if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1
_rejected_log.append({
"endpoint": endpoint,
"ip": ip,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0}
ip_stats[ip]["total"] += 1
if rejected:
ip_stats[ip]["rejected"] += 1
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
endpoint = request.url.path
ip = get_remote_address(request)
_track_rate_limit_request(endpoint, ip, rejected=True)
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
@@ -249,6 +340,7 @@ async def register(request: Request, body: RegisterRequest):
The first registered user automatically becomes an admin.
"""
_track_rate_limit_request("/auth/register", get_remote_address(request))
db = get_db_client()
# First user becomes admin
@@ -279,6 +371,7 @@ async def register(request: Request, body: RegisterRequest):
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens."""
_track_rate_limit_request("/auth/login", get_remote_address(request))
db = get_db_client()
user = db.authenticate_user(body.email, body.password)
@@ -321,6 +414,92 @@ async def get_me(current_user: UserResponse = Depends(get_current_user)):
return current_user
# ============== API Key Endpoints ==============
class CreateApiKeyRequest(BaseModel):
"""Request to create a new API key."""
label: str | None = Field(default=None, max_length=100, description="Optional label for the key")
class ApiKeyResponse(BaseModel):
"""Response after creating an API key (includes plaintext key)."""
id: int
key: str # plaintext key, shown only at creation time
label: str | None = None
created_at: datetime
class ApiKeyInfo(BaseModel):
"""API key metadata (no secret)."""
id: int
label: str | None = None
created_at: datetime
@app.post("/auth/apikeys", response_model=ApiKeyResponse, tags=["Auth"])
async def create_api_key_endpoint(
body: CreateApiKeyRequest | None = None,
current_user: UserResponse = Depends(get_current_user),
):
"""Generate a new API key for the authenticated user.
The plaintext key is returned **only once** in the response.
Store it securely; it cannot be retrieved again.
"""
plaintext_key = generate_api_key()
key_hash = hash_api_key(plaintext_key)
db = get_db_client()
label = body.label if body else None
row = db.create_api_key(
user_id=current_user.id,
key_hash=key_hash,
label=label,
)
return ApiKeyResponse(
id=row["id"],
key=plaintext_key,
label=row["label"],
created_at=row["created_at"],
)
@app.get("/auth/apikeys", response_model=list[ApiKeyInfo], tags=["Auth"])
async def list_api_keys_endpoint(
current_user: UserResponse = Depends(get_current_user),
):
"""List active API key IDs and labels for the authenticated user.
Does **not** return the secret keys.
"""
db = get_db_client()
keys = db.list_api_keys(current_user.id)
return [ApiKeyInfo(**k) for k in keys]
@app.delete("/auth/apikeys/{key_id}", tags=["Auth"])
async def revoke_api_key_endpoint(
key_id: int,
current_user: UserResponse = Depends(get_current_user),
):
"""Revoke (delete) an API key by its ID.
The key must belong to the authenticated user.
"""
db = get_db_client()
deleted = db.delete_api_key(key_id, current_user.id)
if not deleted:
raise HTTPException(status_code=404, detail="API key not found")
return {"message": "API key revoked"}
# ============== Admin Endpoints ==============
@@ -405,7 +584,7 @@ async def delete_user(
class TrackCompanyRequest(BaseModel):
"""Request to add a company to tracking."""
company_name: str = Field(..., min_length=1, max_length=255)
company_name: CompanyName = Field(...)
@app.get("/admin/tracked", tags=["Admin"])
@@ -432,7 +611,7 @@ async def add_tracked_company(
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin),
):
"""Remove a company from the tracked list (admin only)."""
@@ -443,6 +622,69 @@ async def remove_tracked_company(
return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/rate-limits", tags=["Admin"])
async def get_rate_limit_stats(
_: UserResponse = Depends(get_current_admin),
):
"""Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics
for all rate-limited endpoints, including per-IP breakdown and
a time-series of throttled (rejected) requests in the last 24 hours.
Returns:
Rate limit stats per endpoint, per-IP breakdown, and throttled
request history bucketed by hour.
"""
rate_limits_config = {
"/auth/register": {"limit": "5/minute"},
"/auth/login": {"limit": "10/minute"},
}
results = []
for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {})
by_ip_raw = stats.get("by_ip", {})
by_ip = [
{"ip": ip, "total": counts["total"], "rejected": counts["rejected"]}
for ip, counts in by_ip_raw.items()
]
results.append({
"endpoint": endpoint,
"limit": conf["limit"],
"total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0),
"by_ip": by_ip,
})
# Build hourly buckets of throttled requests for the last 24 hours
now = datetime.now(timezone.utc)
cutoff = now - timedelta(hours=24)
hourly_buckets: dict[str, int] = {}
throttled_24h = 0
for entry in _rejected_log:
ts_str = entry["timestamp"]
try:
ts = datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
continue
if ts >= cutoff:
throttled_24h += 1
bucket = ts.strftime("%Y-%m-%dT%H:00:00Z")
hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1
throttled_over_time = [
{"timestamp": k, "count": v}
for k, v in sorted(hourly_buckets.items())
]
return {
"rate_limits": results,
"throttled_24h": throttled_24h,
"throttled_over_time": throttled_over_time,
}
@app.get("/admin/alerts", tags=["Admin"])
async def list_alerts(
limit: int = Query(default=50, ge=1, le=200),
@@ -590,7 +832,7 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
@@ -642,7 +884,7 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
@@ -816,7 +1058,7 @@ async def health_check():
tags=["Analysis"],
)
async def analyze_company(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user),
):
@@ -840,13 +1082,154 @@ async def analyze_company(
return _convert_result(result)
def _extract_patent_ids(response_text: str) -> set[str]:
"""Extract patent IDs from an analysis response text.
Looks for patterns like US-12345678-B2, US12345678B2, etc.
"""
import re
pattern = r"US[-\s]?\d{7,8}[-\s]?[A-Z]\d?"
return set(re.findall(pattern, response_text or ""))
def _compute_analysis_diff(from_rec: dict, to_rec: dict) -> AnalysisDiffResponse:
"""Compute a structured diff between two analysis records."""
from_patents = _extract_patent_ids(from_rec.get("response", "") or "")
to_patents = _extract_patent_ids(to_rec.get("response", "") or "")
added = sorted(to_patents - from_patents)
removed = sorted(from_patents - to_patents)
patent_count_delta = len(to_patents) - len(from_patents)
changed_fields: dict[str, dict] = {}
if from_rec.get("model") != to_rec.get("model"):
changed_fields["model"] = {
"from": from_rec.get("model"),
"to": to_rec.get("model"),
}
if from_rec.get("analysis_type") != to_rec.get("analysis_type"):
changed_fields["analysis_type"] = {
"from": from_rec.get("analysis_type"),
"to": to_rec.get("analysis_type"),
}
# Build a human-readable summary
parts: list[str] = []
if added:
parts.append(f"{len(added)} new patent(s) appeared")
if removed:
parts.append(f"{len(removed)} patent(s) no longer referenced")
if patent_count_delta > 0:
parts.append(f"patent mention count increased by {patent_count_delta}")
elif patent_count_delta < 0:
parts.append(f"patent mention count decreased by {abs(patent_count_delta)}")
if changed_fields:
parts.append(f"field(s) changed: {', '.join(changed_fields.keys())}")
summary = "; ".join(parts) if parts else "No significant differences detected."
return AnalysisDiffResponse(
company_name=to_rec["company_name"],
from_id=from_rec["id"],
to_id=to_rec["id"],
from_timestamp=from_rec["timestamp"],
to_timestamp=to_rec["timestamp"],
patent_count_delta=patent_count_delta,
added_patents=added,
removed_patents=removed,
changed_fields=changed_fields,
summary=summary,
)
@app.get(
"/analyze/{company_name}/history",
response_model=list[CompanyAnalysisHistoryItem],
tags=["Analysis"],
)
async def list_company_analysis_history(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
limit: int = Query(default=20, ge=1, le=100),
_: UserResponse = Depends(get_current_user),
):
"""List previous analysis runs for a company.
Returns a list of analysis records ordered by timestamp descending,
useful for selecting which runs to compare via the diff endpoint.
Args:
company_name: Company name to look up
limit: Maximum number of results
Returns:
List of analysis history items
"""
db = _get_job_db()
rows = db.list_company_analyses(company_name, limit=limit)
return [
CompanyAnalysisHistoryItem(
id=r["id"],
analysis_type=r.get("analysis_type"),
model=r.get("model"),
timestamp=r["timestamp"],
)
for r in rows
]
@app.get(
"/analyze/{company_name}/diff",
response_model=AnalysisDiffResponse,
tags=["Analysis"],
)
async def diff_company_analyses(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
from_id: int = Query(..., alias="from", description="Analysis ID of the older run"),
to_id: int = Query(..., alias="to", description="Analysis ID of the newer run"),
_: UserResponse = Depends(get_current_user),
):
"""Compare two analysis runs for the same company.
Returns a structured diff showing added/removed patents, score delta,
and a summary narrative.
Args:
company_name: Company name (must match both analysis records)
from_id: ID of the older analysis run
to_id: ID of the newer analysis run
Returns:
AnalysisDiffResponse with added/removed/changed fields
Raises:
404: If either analysis ID does not exist or belongs to a different company
"""
db = _get_job_db()
from_rec = db.get_analysis_by_id(from_id)
if not from_rec or (from_rec["company_name"] or "").lower() != company_name.lower():
raise HTTPException(
status_code=404,
detail=f"Analysis ID {from_id} not found for company '{company_name}'",
)
to_rec = db.get_analysis_by_id(to_id)
if not to_rec or (to_rec["company_name"] or "").lower() != company_name.lower():
raise HTTPException(
status_code=404,
detail=f"Analysis ID {to_id} not found for company '{company_name}'",
)
return _compute_analysis_diff(from_rec, to_rec)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: str = Query(description="Company name for analysis context"),
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
@@ -872,6 +1255,58 @@ async def analyze_single_patent(
raise HTTPException(status_code=404, detail=str(e))
@app.get(
"/analyze/batch",
response_model=PaginatedAnalysisResponse,
tags=["Analysis"],
)
async def list_analysis_results(
company_name: Annotated[
str | None,
Query(description="Filter results by company name"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_user),
):
"""List stored analysis results with cursor-based pagination.
Returns past analysis results ordered by timestamp descending. Use
``limit`` to control page size (default 50, max 200). The response
includes a ``next_cursor`` field; pass it back as the ``cursor`` query
parameter to fetch the next page. When ``next_cursor`` is ``null``,
there are no more results.
Args:
company_name: Optional filter by company name
limit: Maximum number of results to return (default 50, max 200)
cursor: Opaque pagination cursor from a previous response
Returns:
Paginated list of analysis results
"""
db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
has_next = len(rows) > limit
if has_next:
rows = rows[:limit]
items = [AnalysisRecord(**row) for row in rows]
next_cursor = None
if has_next and rows:
last = rows[-1]
ts = last["timestamp"]
ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
next_cursor = f"{ts_str}|{last['id']}"
return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor)
@app.post(
"/analyze/batch",
response_model=BatchAnalysisResponse,
@@ -1047,7 +1482,7 @@ async def list_jobs(
str | None,
Query(description="Filter by status: pending, running, completed, failed"),
] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 10,
limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
+98 -9
View File
@@ -1,11 +1,13 @@
"""JWT authentication utilities for SPARC API."""
"""JWT and API key authentication utilities for SPARC API."""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
@@ -32,7 +34,7 @@ def check_jwt_secret() -> None:
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer()
security = HTTPBearer(auto_error=False)
class TokenPayload(BaseModel):
@@ -178,20 +180,107 @@ def get_db_client() -> DatabaseClient:
return _db_client
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> UserResponse:
"""Get the current authenticated user from JWT token.
def generate_api_key() -> str:
"""Generate a random 32-byte hex API key.
Returns:
64-character hex string
"""
return secrets.token_hex(32)
def hash_api_key(key: str) -> str:
"""Hash an API key using bcrypt.
Args:
credentials: Bearer token from request
key: Plaintext API key
Returns:
bcrypt hash string
"""
return bcrypt.hashpw(key.encode(), bcrypt.gensalt()).decode()
def verify_api_key(key: str, key_hash: str) -> bool:
"""Verify a plaintext API key against its bcrypt hash.
Args:
key: Plaintext API key
key_hash: Stored bcrypt hash
Returns:
True if key matches
"""
return bcrypt.checkpw(key.encode(), key_hash.encode())
def _authenticate_via_api_key(api_key: str) -> Optional[UserResponse]:
"""Look up a user by raw API key.
Iterates over all stored key hashes (small table) and returns the
corresponding user when a match is found.
Args:
api_key: Plaintext API key from X-API-Key header
Returns:
UserResponse if valid key, None otherwise
"""
db = get_db_client()
key_rows = db.get_all_api_key_hashes()
for row in key_rows:
if verify_api_key(api_key, row["key_hash"]):
user = db.get_user_by_id(row["user_id"])
if user:
return UserResponse(
id=user["id"],
email=user["email"],
role=user["role"],
created_at=user["created_at"],
)
return None
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> UserResponse:
"""Get the current authenticated user from JWT token or API key.
Supports two authentication methods:
1. Bearer JWT token via Authorization header
2. API key via X-API-Key header
Args:
request: The incoming request (used for X-API-Key header)
credentials: Optional Bearer token from request
Returns:
UserResponse with user details
Raises:
HTTPException: If token is invalid or expired
HTTPException: If no valid credentials are provided
"""
# Try X-API-Key header first
api_key = request.headers.get("X-API-Key")
if api_key:
user = _authenticate_via_api_key(api_key)
if user:
return user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
# Fall back to JWT Bearer token
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
payload = decode_token(token)
+198
View File
@@ -221,6 +221,27 @@ class DatabaseClient:
ON alerts(company_name)
""")
# Create API keys table for programmatic access
cursor.execute("""
CREATE TABLE IF NOT EXISTS api_keys (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_hash VARCHAR(255) NOT NULL,
label VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
ON api_keys(user_id)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_api_keys_key_hash
ON api_keys(key_hash)
""")
self.conn.commit()
@staticmethod
@@ -371,6 +392,48 @@ class DatabaseClient:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def list_analyses(
self,
company_name: Optional[str] = None,
limit: int = 50,
cursor: Optional[str] = None,
) -> List[Dict]:
"""List analysis results with cursor-based pagination.
Args:
company_name: Optional filter by company name.
limit: Maximum number of records to return.
cursor: Opaque cursor (``timestamp|id``) from a previous response.
Returns:
List of analysis dicts ordered by timestamp descending.
"""
conditions: list[str] = ["is_cached = FALSE"]
params: list = []
if company_name:
conditions.append("LOWER(company_name) = LOWER(%s)")
params.append(company_name)
if cursor:
try:
ts_str, cursor_id = cursor.rsplit("|", 1)
conditions.append("(timestamp, id) < (%s, %s)")
params.extend([ts_str, int(cursor_id)])
except (ValueError, TypeError):
pass # Ignore malformed cursors; return from start
query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages"
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY timestamp DESC, id DESC LIMIT %s"
params.append(limit)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, params)
return [dict(row) for row in cur.fetchall()]
def get_analytics(self, days: int = 30) -> Dict:
"""Get analytics on message usage.
@@ -935,3 +998,138 @@ class DatabaseClient:
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
# Historical Analysis Diff Methods
def get_analysis_by_id(self, analysis_id: int) -> Optional[Dict]:
"""Get a single analysis record by its ID.
Args:
analysis_id: The primary key of the llm_messages row.
Returns:
Dict with analysis fields, or None if not found.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE id = %s AND is_cached = FALSE
""",
(analysis_id,),
)
row = cursor.fetchone()
return dict(row) if row else None
def list_company_analyses(
self, company_name: str, limit: int = 20
) -> List[Dict]:
"""List past analysis runs for a given company.
Returns records ordered by timestamp descending so callers can
identify which previous runs are available for diffing.
Args:
company_name: Company name (case-insensitive match).
limit: Maximum number of records.
Returns:
List of analysis dicts.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
LIMIT %s
""",
(company_name, limit),
)
return [dict(row) for row in cursor.fetchall()]
# API Key Methods
def create_api_key(
self,
user_id: int,
key_hash: str,
label: Optional[str] = None,
) -> Dict:
"""Store a new API key hash for a user.
Args:
user_id: The owning user's ID
key_hash: bcrypt hash of the plaintext key
label: Optional human-readable label
Returns:
Dict with id, user_id, label, created_at
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO api_keys (user_id, key_hash, label)
VALUES (%s, %s, %s)
RETURNING id, user_id, label, created_at
""",
(user_id, key_hash, label),
)
row = cursor.fetchone()
conn.commit()
return dict(row)
def list_api_keys(self, user_id: int) -> List[Dict]:
"""List active API key metadata for a user (no secrets).
Args:
user_id: The user's ID
Returns:
List of dicts with id, label, created_at
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, label, created_at FROM api_keys WHERE user_id = %s ORDER BY created_at DESC",
(user_id,),
)
return [dict(row) for row in cursor.fetchall()]
def delete_api_key(self, key_id: int, user_id: int) -> bool:
"""Revoke an API key by ID (must belong to user).
Args:
key_id: The API key row ID
user_id: The owning user's ID
Returns:
True if a key was deleted
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM api_keys WHERE id = %s AND user_id = %s",
(key_id, user_id),
)
deleted = cursor.rowcount > 0
conn.commit()
return deleted
def get_all_api_key_hashes(self) -> List[Dict]:
"""Return all API key hashes with their associated user IDs.
Used by the auth layer to validate an incoming API key.
Returns:
List of dicts with key_hash, user_id
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT key_hash, user_id FROM api_keys")
return [dict(row) for row in cursor.fetchall()]
+12 -7
View File
@@ -2,14 +2,17 @@
Uses APScheduler to periodically re-analyze tracked companies and
detect significant changes in patent counts.
The scheduler reuses the application-level pooled DatabaseClient
(from ``SPARC.auth``) instead of creating its own connection, which
avoids exhausting the database connection pool under load.
"""
import logging
import os
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.database import DatabaseClient
from SPARC.auth import get_db_client
logger = logging.getLogger(__name__)
@@ -21,10 +24,13 @@ CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20"))
def run_scheduled_analysis() -> None:
"""Re-analyze all tracked companies and check for significant changes."""
db = DatabaseClient(config.database_url)
db.connect()
db.initialize_schema()
"""Re-analyze all tracked companies and check for significant changes.
Uses the shared pooled DatabaseClient from ``SPARC.auth.get_db_client()``
rather than creating a disposable connection, so the scheduler participates
in the same connection pool as the rest of the application.
"""
db = get_db_client()
tracked = db.list_tracked_companies()
if not tracked:
@@ -74,7 +80,6 @@ def run_scheduled_analysis() -> None:
except Exception as e:
logger.error("Error analyzing tracked company %s: %s", name, e)
db.close()
logger.info("Scheduled analysis complete")
+11
View File
@@ -11,7 +11,9 @@ import { Batch } from './pages/Batch';
import { AnalyticsPage } from './pages/Analytics';
import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers';
import { AdminRateLimits } from './pages/AdminRateLimits';
import { Compare } from './pages/Compare';
import { HistoryDiff } from './pages/HistoryDiff';
const queryClient = new QueryClient({
defaultOptions: {
@@ -45,6 +47,7 @@ function App() {
<Route path="/batch" element={<Batch />} />
<Route path="/analytics" element={<AnalyticsPage />} />
<Route path="/compare" element={<Compare />} />
<Route path="/history-diff" element={<HistoryDiff />} />
<Route path="/about" element={<About />} />
{/* Admin routes */}
@@ -56,6 +59,14 @@ function App() {
</ProtectedRoute>
}
/>
<Route
path="/admin/rate-limits"
element={
<ProtectedRoute requireAdmin>
<AdminRateLimits />
</ProtectedRoute>
}
/>
</Route>
{/* Default redirect */}
+66
View File
@@ -148,8 +148,43 @@ export const analysisApi = {
const response = await api.get<JobStatus[]>(`/jobs?${params}`);
return response.data;
},
getCompanyHistory: async (companyName: string, limit = 20): Promise<AnalysisHistoryItem[]> => {
const response = await api.get<AnalysisHistoryItem[]>(
`/analyze/${encodeURIComponent(companyName)}/history?limit=${limit}`
);
return response.data;
},
diffAnalyses: async (companyName: string, fromId: number, toId: number): Promise<AnalysisDiff> => {
const response = await api.get<AnalysisDiff>(
`/analyze/${encodeURIComponent(companyName)}/diff?from=${fromId}&to=${toId}`
);
return response.data;
},
};
// Analysis diff types
export interface AnalysisHistoryItem {
id: number;
analysis_type: string | null;
model: string | null;
timestamp: string;
}
export interface AnalysisDiff {
company_name: string;
from_id: number;
to_id: number;
from_timestamp: string;
to_timestamp: string;
patent_count_delta: number;
added_patents: string[];
removed_patents: string[];
changed_fields: Record<string, { from: string | null; to: string | null }>;
summary: string;
}
// Export API
export const exportApi = {
exportCsv: async (companyName: string): Promise<void> => {
@@ -201,6 +236,32 @@ export const analyticsApi = {
},
};
// Rate limit types
export interface RateLimitIpEntry {
ip: string;
total: number;
rejected: number;
}
export interface RateLimitEndpointStats {
endpoint: string;
limit: string;
total_requests: number;
rejected_requests: number;
by_ip: RateLimitIpEntry[];
}
export interface ThrottledBucket {
timestamp: string;
count: number;
}
export interface RateLimitStatsResponse {
rate_limits: RateLimitEndpointStats[];
throttled_24h: number;
throttled_over_time: ThrottledBucket[];
}
// Admin API
export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
@@ -216,6 +277,11 @@ export const adminApi = {
deleteUser: async (userId: number): Promise<void> => {
await api.delete(`/admin/users/${userId}`);
},
getRateLimits: async (): Promise<RateLimitStatsResponse> => {
const response = await api.get<RateLimitStatsResponse>('/admin/rate-limits');
return response.data;
},
};
export default api;
+3 -1
View File
@@ -1,7 +1,7 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext';
import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, History, ShieldAlert } from 'lucide-react';
export function Layout() {
const { user, isAdmin, logout } = useAuth();
@@ -18,11 +18,13 @@ export function Layout() {
{ to: '/batch', icon: Layers, label: 'Batch' },
{ to: '/analytics', icon: BarChart3, label: 'Analytics' },
{ to: '/compare', icon: GitCompareArrows, label: 'Compare' },
{ to: '/history-diff', icon: History, label: 'Diff' },
{ to: '/about', icon: Info, label: 'About' },
];
if (isAdmin) {
navItems.push({ to: '/admin/users', icon: Users, label: 'Users' });
navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' });
}
return (
+240
View File
@@ -0,0 +1,240 @@
import { useState } from 'react';
import { useQuery } from '@tanstack/react-query';
import { adminApi } from '../api/client';
import type { RateLimitStatsResponse } from '../api/client';
import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react';
const REFRESH_OPTIONS = [
{ label: '15s', value: 15_000 },
{ label: '30s', value: 30_000 },
{ label: '1m', value: 60_000 },
{ label: 'Off', value: 0 },
];
export function AdminRateLimits() {
const [refreshInterval, setRefreshInterval] = useState(30_000);
const { data, isLoading, isError, dataUpdatedAt } = useQuery<RateLimitStatsResponse>({
queryKey: ['admin-rate-limits'],
queryFn: () => adminApi.getRateLimits(),
refetchInterval: refreshInterval || false,
});
if (isLoading) {
return (
<div className="flex items-center justify-center min-h-[400px]">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div>
</div>
);
}
if (isError) {
return (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load rate limit statistics.</span>
</div>
);
}
const maxThrottledCount = data?.throttled_over_time?.length
? Math.max(...data.throttled_over_time.map((b) => b.count))
: 0;
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between flex-wrap gap-4">
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Rate Limiting Dashboard
</h2>
<p className="text-text-secondary">Monitor API rate limits and throttled requests.</p>
</div>
<div className="flex items-center gap-3">
{/* Last updated */}
{dataUpdatedAt > 0 && (
<span className="text-xs text-text-secondary flex items-center gap-1">
<Clock size={12} />
Updated {new Date(dataUpdatedAt).toLocaleTimeString()}
</span>
)}
{/* Refresh interval selector */}
<div className="flex items-center gap-1 bg-bg-card/60 border border-primary/15 rounded-xl p-1">
<RefreshCw size={14} className="text-text-secondary ml-2" />
{REFRESH_OPTIONS.map((opt) => (
<button
key={opt.value}
onClick={() => setRefreshInterval(opt.value)}
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
refreshInterval === opt.value
? 'bg-primary text-white'
: 'text-text-secondary hover:text-text-primary hover:bg-bg-card-hover'
}`}
>
{opt.label}
</button>
))}
</div>
</div>
</div>
{/* Summary cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<Activity size={18} className="text-primary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-error/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-error" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Throttled (24h)
</span>
</div>
<div className="text-3xl font-bold text-error">
{data?.throttled_24h ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-secondary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-secondary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rate-Limited Endpoints
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.length ?? 0}
</div>
</div>
</div>
{/* Throttled over time chart (simple bar chart) */}
{data?.throttled_over_time && data.throttled_over_time.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider mb-4">
Throttled Requests Over Time (Last 24h)
</h3>
<div className="flex items-end gap-1 h-32">
{data.throttled_over_time.map((bucket) => {
const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0;
const hour = new Date(bucket.timestamp).getHours();
return (
<div key={bucket.timestamp} className="flex-1 flex flex-col items-center gap-1">
<span className="text-xs text-text-secondary">{bucket.count}</span>
<div
className="w-full bg-error/70 rounded-t-sm min-h-[2px] transition-all"
style={{ height: `${Math.max(height, 2)}%` }}
title={`${bucket.timestamp}: ${bucket.count} throttled`}
/>
<span className="text-[10px] text-text-secondary">{hour}:00</span>
</div>
);
})}
</div>
</div>
)}
{/* Per-endpoint table */}
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Limit
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data?.rate_limits.map((rl) => (
<tr key={rl.endpoint} className="hover:bg-bg-card-hover/50 transition-colors">
<td className="px-6 py-4 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-4">
<span className="inline-flex px-2 py-0.5 rounded-full text-xs font-medium bg-primary/10 text-primary border border-primary/20">
{rl.limit}
</span>
</td>
<td className="px-6 py-4 text-right text-text-primary font-semibold">
{rl.total_requests}
</td>
<td className="px-6 py-4 text-right">
<span className={rl.rejected_requests > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{rl.rejected_requests}
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
{/* Per-IP breakdown */}
{data?.rate_limits.some((rl) => rl.by_ip.length > 0) && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="px-6 py-4 border-b border-primary/10">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Per-IP Breakdown
</h3>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
IP Address
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data.rate_limits.flatMap((rl) =>
rl.by_ip.map((ipEntry) => (
<tr
key={`${rl.endpoint}-${ipEntry.ip}`}
className="hover:bg-bg-card-hover/50 transition-colors"
>
<td className="px-6 py-3 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-3 font-mono text-sm text-text-secondary">{ipEntry.ip}</td>
<td className="px-6 py-3 text-right text-text-primary">{ipEntry.total}</td>
<td className="px-6 py-3 text-right">
<span className={ipEntry.rejected > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{ipEntry.rejected}
</span>
</td>
</tr>
))
)}
</tbody>
</table>
</div>
</div>
)}
</div>
);
}
+10 -1
View File
@@ -1,10 +1,12 @@
import { useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { useMutation, useQuery } from '@tanstack/react-query';
import { analysisApi, exportApi } from '../api/client';
import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown } from 'lucide-react';
import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown, History } from 'lucide-react';
import type { CompanyAnalysis } from '../types';
export function Analysis() {
const navigate = useNavigate();
const [companyName, setCompanyName] = useState('');
const [selectedModel, setSelectedModel] = useState('');
const [result, setResult] = useState<CompanyAnalysis | null>(null);
@@ -157,6 +159,13 @@ export function Analysis() {
<FileText size={14} />
Export PDF
</button>
<button
onClick={() => navigate(`/history-diff?company=${encodeURIComponent(result.company_name)}`)}
className="flex items-center gap-2 text-sm bg-secondary/20 hover:bg-secondary/30 text-secondary font-medium px-3 py-1.5 rounded-lg transition-colors"
>
<History size={14} />
Compare with previous
</button>
</div>
</div>
<div className="prose dark:prose-invert max-w-none">
+249
View File
@@ -0,0 +1,249 @@
import { useState } from 'react';
import { useSearchParams } from 'react-router-dom';
import { useQuery } from '@tanstack/react-query';
import { analysisApi } from '../api/client';
import type { AnalysisHistoryItem, AnalysisDiff } from '../api/client';
import { History, ArrowRight, Plus, Minus, AlertCircle, Search } from 'lucide-react';
export function HistoryDiff() {
const [searchParams, setSearchParams] = useSearchParams();
const [companyInput, setCompanyInput] = useState(searchParams.get('company') || '');
const company = searchParams.get('company') || '';
const fromId = searchParams.get('from');
const toId = searchParams.get('to');
// Fetch history when a company is selected
const historyQuery = useQuery({
queryKey: ['history', company],
queryFn: () => analysisApi.getCompanyHistory(company),
enabled: !!company,
});
// Fetch diff when both IDs are selected
const diffQuery = useQuery<AnalysisDiff>({
queryKey: ['diff', company, fromId, toId],
queryFn: () => analysisApi.diffAnalyses(company, Number(fromId), Number(toId)),
enabled: !!company && !!fromId && !!toId,
});
const handleSearch = (e: React.FormEvent) => {
e.preventDefault();
const name = companyInput.trim();
if (name) {
setSearchParams({ company: name });
}
};
const handleSelectRuns = (from: number, to: number) => {
setSearchParams({ company, from: String(from), to: String(to) });
};
const history: AnalysisHistoryItem[] = historyQuery.data || [];
return (
<div className="space-y-6">
{/* Header */}
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Historical Analysis Diff
</h2>
<p className="text-text-secondary">
Compare analysis runs for the same company to see what changed between them.
</p>
</div>
{/* Company Search */}
<form onSubmit={handleSearch} className="flex gap-4">
<div className="flex-1 relative">
<Search className="absolute left-4 top-1/2 -translate-y-1/2 text-text-secondary" size={18} />
<input
type="text"
value={companyInput}
onChange={(e) => setCompanyInput(e.target.value)}
placeholder="Enter company name (e.g., nvidia)"
className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all"
/>
</div>
<button
type="submit"
disabled={!companyInput.trim()}
className="bg-gradient-to-r from-primary to-primary-dark text-white font-semibold py-3 px-6 rounded-xl hover:shadow-lg hover:shadow-primary/30 transition-all disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
<History size={18} />
Load History
</button>
</form>
{/* History list */}
{company && historyQuery.isLoading && (
<div className="text-text-secondary animate-pulse">Loading analysis history...</div>
)}
{company && historyQuery.isError && (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load history. Check the company name and try again.</span>
</div>
)}
{company && history.length === 0 && !historyQuery.isLoading && (
<div className="text-text-secondary">No analysis history found for "{company}".</div>
)}
{history.length >= 2 && (
<div className="bg-bg-card/60 backdrop-blur-lg border border-primary/15 rounded-2xl p-6">
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-4">
Select Two Runs to Compare
</h3>
<div className="space-y-2">
{history.map((item, idx) => {
const next = history[idx + 1];
if (!next) return null;
const isSelected =
fromId === String(next.id) && toId === String(item.id);
return (
<button
key={item.id}
onClick={() => handleSelectRuns(next.id, item.id)}
className={`w-full text-left flex items-center gap-3 px-4 py-3 rounded-xl border transition-all ${
isSelected
? 'border-primary bg-primary/10'
: 'border-primary/15 hover:border-primary/40 hover:bg-primary/5'
}`}
>
<span className="text-sm text-text-secondary font-mono">
#{next.id}
</span>
<span className="text-xs text-text-secondary">
{new Date(next.timestamp).toLocaleString()}
</span>
<ArrowRight size={14} className="text-primary" />
<span className="text-sm text-text-secondary font-mono">
#{item.id}
</span>
<span className="text-xs text-text-secondary">
{new Date(item.timestamp).toLocaleString()}
</span>
{item.model && (
<span className="ml-auto text-xs bg-primary/20 text-primary px-2 py-0.5 rounded">
{item.model}
</span>
)}
</button>
);
})}
</div>
</div>
)}
{/* Diff Results */}
{diffQuery.isLoading && (
<div className="text-text-secondary animate-pulse">Computing diff...</div>
)}
{diffQuery.isError && (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to compute diff. One or both analysis IDs may not exist.</span>
</div>
)}
{diffQuery.data && <DiffView diff={diffQuery.data} />}
</div>
);
}
function DiffView({ diff }: { diff: AnalysisDiff }) {
return (
<div className="bg-bg-card/60 backdrop-blur-lg border border-primary/15 rounded-2xl p-6 space-y-6">
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2">
Diff: #{diff.from_id} &rarr; #{diff.to_id}
</h3>
{/* Summary */}
<div className="bg-primary/5 border border-primary/20 rounded-xl p-4">
<div className="text-sm font-medium text-text-primary">{diff.summary}</div>
<div className="flex items-center gap-4 mt-2 text-xs text-text-secondary">
<span>{new Date(diff.from_timestamp).toLocaleString()}</span>
<ArrowRight size={12} />
<span>{new Date(diff.to_timestamp).toLocaleString()}</span>
</div>
</div>
{/* Patent count delta */}
<div className="flex items-center gap-3">
<span className="text-sm text-text-secondary">Patent mention delta:</span>
<span
className={`text-lg font-bold ${
diff.patent_count_delta > 0
? 'text-success'
: diff.patent_count_delta < 0
? 'text-error'
: 'text-text-secondary'
}`}
>
{diff.patent_count_delta > 0 ? '+' : ''}
{diff.patent_count_delta}
</span>
</div>
{/* Added patents */}
{diff.added_patents.length > 0 && (
<div>
<h4 className="text-sm font-semibold text-success flex items-center gap-1 mb-2">
<Plus size={14} />
New Patents ({diff.added_patents.length})
</h4>
<div className="flex flex-wrap gap-2">
{diff.added_patents.map((p) => (
<span
key={p}
className="text-xs bg-success/10 border border-success/20 text-success px-2 py-1 rounded font-mono"
>
{p}
</span>
))}
</div>
</div>
)}
{/* Removed patents */}
{diff.removed_patents.length > 0 && (
<div>
<h4 className="text-sm font-semibold text-error flex items-center gap-1 mb-2">
<Minus size={14} />
Removed Patents ({diff.removed_patents.length})
</h4>
<div className="flex flex-wrap gap-2">
{diff.removed_patents.map((p) => (
<span
key={p}
className="text-xs bg-error/10 border border-error/20 text-error px-2 py-1 rounded font-mono"
>
{p}
</span>
))}
</div>
</div>
)}
{/* Changed fields */}
{Object.keys(diff.changed_fields).length > 0 && (
<div>
<h4 className="text-sm font-semibold text-text-primary mb-2">Changed Fields</h4>
<div className="space-y-1">
{Object.entries(diff.changed_fields).map(([field, vals]) => (
<div key={field} className="flex items-center gap-2 text-sm">
<span className="text-text-secondary font-mono">{field}:</span>
<span className="text-error line-through">{vals.from || 'null'}</span>
<ArrowRight size={12} className="text-text-secondary" />
<span className="text-success">{vals.to || 'null'}</span>
</div>
))}
</div>
</div>
)}
</div>
);
}
+244
View File
@@ -0,0 +1,244 @@
"""Tests for historical analysis diff endpoint."""
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import AnalysisDiffResponse, _compute_analysis_diff, _extract_patent_ids, app
from SPARC.auth import UserResponse, get_current_user
# ---------- helpers ----------
def _mock_user():
"""Return a fake authenticated user for dependency override."""
return UserResponse(
id=1,
email="test@example.com",
role="user",
created_at=datetime(2025, 1, 1),
)
@pytest.fixture
def auth_client():
"""TestClient with auth dependency overridden."""
app.dependency_overrides[get_current_user] = _mock_user
client = TestClient(app, raise_server_exceptions=False)
yield client
app.dependency_overrides.clear()
# ---------- unit tests for helpers ----------
class TestExtractPatentIds:
"""Test _extract_patent_ids utility."""
def test_extracts_standard_ids(self):
text = "Patent US-12345678-B2 covers the device. Also see US-9876543-A1."
ids = _extract_patent_ids(text)
assert "US-12345678-B2" in ids
assert "US-9876543-A1" in ids
def test_empty_text(self):
assert _extract_patent_ids("") == set()
assert _extract_patent_ids(None) == set() # type: ignore[arg-type]
class TestComputeAnalysisDiff:
"""Test _compute_analysis_diff logic."""
def test_identical_analyses(self):
rec = {
"id": 1,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 is notable.",
"timestamp": datetime(2025, 5, 1),
}
diff = _compute_analysis_diff(rec, dict(rec, id=2, timestamp=datetime(2025, 5, 2)))
assert diff.patent_count_delta == 0
assert diff.added_patents == []
assert diff.removed_patents == []
def test_added_and_removed_patents(self):
from_rec = {
"id": 1,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 and US-11111111-A1.",
"timestamp": datetime(2025, 5, 1),
}
to_rec = {
"id": 2,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 and US-99999999-B1.",
"timestamp": datetime(2025, 5, 2),
}
diff = _compute_analysis_diff(from_rec, to_rec)
assert "US-99999999-B1" in diff.added_patents
assert "US-11111111-A1" in diff.removed_patents
assert diff.patent_count_delta == 0 # one added, one removed
def test_model_change_detected(self):
from_rec = {
"id": 1,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "",
"timestamp": datetime(2025, 5, 1),
}
to_rec = {
"id": 2,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "anthropic/claude-3.5-sonnet",
"response": "",
"timestamp": datetime(2025, 5, 2),
}
diff = _compute_analysis_diff(from_rec, to_rec)
assert "model" in diff.changed_fields
assert diff.changed_fields["model"]["from"] == "openai/gpt-4o"
assert diff.changed_fields["model"]["to"] == "anthropic/claude-3.5-sonnet"
# ---------- API endpoint tests ----------
class TestDiffEndpoint:
"""Test GET /analyze/{company_name}/diff."""
@patch("SPARC.api._get_job_db")
def test_happy_path(self, mock_get_db, auth_client):
"""Diff returns structured response when both IDs exist."""
db = MagicMock()
mock_get_db.return_value = db
from_rec = {
"id": 10,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 found.",
"timestamp": datetime(2025, 5, 1),
}
to_rec = {
"id": 20,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "Patent US-12345678-B2 and US-99999999-A1 found.",
"timestamp": datetime(2025, 5, 10),
}
db.get_analysis_by_id.side_effect = lambda aid: from_rec if aid == 10 else to_rec
response = auth_client.get("/analyze/nvidia/diff?from=10&to=20")
assert response.status_code == 200
data = response.json()
assert data["company_name"] == "nvidia"
assert data["from_id"] == 10
assert data["to_id"] == 20
assert "US-99999999-A1" in data["added_patents"]
assert data["patent_count_delta"] == 1
@patch("SPARC.api._get_job_db")
def test_from_id_not_found(self, mock_get_db, auth_client):
"""Returns 404 when 'from' analysis ID doesn't exist."""
db = MagicMock()
mock_get_db.return_value = db
db.get_analysis_by_id.return_value = None
response = auth_client.get("/analyze/nvidia/diff?from=999&to=1000")
assert response.status_code == 404
assert "999" in response.json()["detail"]
@patch("SPARC.api._get_job_db")
def test_to_id_not_found(self, mock_get_db, auth_client):
"""Returns 404 when 'to' analysis ID doesn't exist."""
db = MagicMock()
mock_get_db.return_value = db
from_rec = {
"id": 10,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "",
"timestamp": datetime(2025, 5, 1),
}
db.get_analysis_by_id.side_effect = lambda aid: from_rec if aid == 10 else None
response = auth_client.get("/analyze/nvidia/diff?from=10&to=999")
assert response.status_code == 404
assert "999" in response.json()["detail"]
@patch("SPARC.api._get_job_db")
def test_company_mismatch(self, mock_get_db, auth_client):
"""Returns 404 when analysis belongs to a different company."""
db = MagicMock()
mock_get_db.return_value = db
rec = {
"id": 10,
"company_name": "intel",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "",
"timestamp": datetime(2025, 5, 1),
}
db.get_analysis_by_id.return_value = rec
response = auth_client.get("/analyze/nvidia/diff?from=10&to=20")
assert response.status_code == 404
class TestHistoryEndpoint:
"""Test GET /analyze/{company_name}/history."""
@patch("SPARC.api._get_job_db")
def test_returns_history_list(self, mock_get_db, auth_client):
"""History endpoint returns list of past analysis runs."""
db = MagicMock()
mock_get_db.return_value = db
db.list_company_analyses.return_value = [
{
"id": 20,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "...",
"timestamp": datetime(2025, 5, 10),
},
{
"id": 10,
"company_name": "nvidia",
"analysis_type": "portfolio",
"model": "openai/gpt-4o",
"response": "...",
"timestamp": datetime(2025, 5, 1),
},
]
response = auth_client.get("/analyze/nvidia/history")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["id"] == 20
assert data[1]["id"] == 10
@patch("SPARC.api._get_job_db")
def test_empty_history(self, mock_get_db, auth_client):
"""History endpoint returns empty list when no analyses exist."""
db = MagicMock()
mock_get_db.return_value = db
db.list_company_analyses.return_value = []
response = auth_client.get("/analyze/nvidia/history")
assert response.status_code == 200
assert response.json() == []
+211
View File
@@ -0,0 +1,211 @@
"""Tests for analyze_single_patent auto-download path.
Covers issue #1661:
- PDF exists on disk: direct analysis (happy path)
- PDF not on disk, cached link exists: auto-download and analyze
- PDF not on disk, no cached link: FileNotFoundError
- Analysis failure after PDF found: graceful error message
- Model override parameter passthrough
"""
import os
from unittest.mock import MagicMock, patch
import pytest
from SPARC.analyzer import CompanyAnalyzer
from SPARC.types import Patent
@pytest.fixture(autouse=True)
def mock_db(mocker):
"""Mock DatabaseClient so no real DB is needed."""
mock_db_cls = mocker.patch("SPARC.analyzer.DatabaseClient")
mock_db_instance = MagicMock()
mock_db_instance.get_cached_patent.return_value = None
mock_db_instance.get_cached_serp_query.return_value = None
mock_db_cls.return_value = mock_db_instance
return mock_db_instance
@pytest.fixture
def analyzer(mocker, mock_db):
"""Create a CompanyAnalyzer with mocked LLM and DB."""
mocker.patch("SPARC.analyzer.LLMAnalyzer")
return CompanyAnalyzer(openrouter_api_key="test-key")
class TestAnalyzeSinglePatentAutoDownload:
"""Test the auto-download logic in analyze_single_patent."""
def test_pdf_on_disk_analyzed_directly(self, analyzer, mocker, tmp_path):
"""When PDF exists on disk, it is analyzed directly without download."""
patent_id = "US-11234567-B2"
# Create the patents dir and PDF file
patents_dir = tmp_path / "patents"
patents_dir.mkdir()
pdf_path = patents_dir / f"{patent_id}.pdf"
pdf_path.write_bytes(b"fake PDF content")
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
mock_parse.return_value = {"abstract": "test", "claims": "test claims"}
mock_minimize.return_value = "minimized content"
analyzer.llm_analyzer.analyze_patent_content.return_value = "Good patent."
# Change cwd so patents/{patent_id}.pdf resolves to our tmp_path
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
result = analyzer.analyze_single_patent(patent_id, "TestCo")
finally:
os.chdir(original_cwd)
assert result == "Good patent."
# DB cache should not have been queried since file existed
analyzer.db.get_cached_patent.assert_not_called()
def test_auto_download_from_cached_link(self, analyzer, mocker, tmp_path):
"""When PDF is not on disk but link is cached, auto-download occurs."""
patent_id = "US-99887766-A1"
# No patents dir exists (PDF not on disk)
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
downloaded_patent = Patent(patent_id=patent_id, pdf_link="https://example.com/patent.pdf")
downloaded_patent.pdf_path = f"patents/{patent_id}.pdf"
mock_save.return_value = downloaded_patent
# Cached patent has a PDF link
analyzer.db.get_cached_patent.return_value = {
"patent_id": patent_id,
"pdf_link": "https://example.com/patent.pdf",
}
# Mock the rest of the analysis pipeline
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
mock_parse.return_value = {"abstract": "test abstract"}
mock_minimize.return_value = "minimized content"
analyzer.llm_analyzer.analyze_patent_content.return_value = "Strong innovation."
# Change cwd so patents/{patent_id}.pdf does NOT exist
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
result = analyzer.analyze_single_patent(patent_id, "DownloadCo")
finally:
os.chdir(original_cwd)
assert result == "Strong innovation."
analyzer.db.get_cached_patent.assert_called_once_with(patent_id)
mock_save.assert_called_once()
# Verify the Patent passed to save_patents has the correct ID and link
saved_patent = mock_save.call_args[0][0]
assert saved_patent.patent_id == patent_id
assert saved_patent.pdf_link == "https://example.com/patent.pdf"
def test_no_cached_link_raises_file_not_found(self, analyzer, mocker, tmp_path):
"""When PDF is not on disk and no cached link, FileNotFoundError raised."""
patent_id = "US-00000000-X1"
analyzer.db.get_cached_patent.return_value = None
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
with pytest.raises(FileNotFoundError, match="no download link is cached"):
analyzer.analyze_single_patent(patent_id, "MissingCo")
finally:
os.chdir(original_cwd)
def test_cached_patent_without_pdf_link_raises(self, analyzer, mocker, tmp_path):
"""When cached patent exists but has no pdf_link, FileNotFoundError raised."""
patent_id = "US-11111111-B1"
analyzer.db.get_cached_patent.return_value = {
"patent_id": patent_id,
"pdf_link": None,
}
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
with pytest.raises(FileNotFoundError, match="no download link is cached"):
analyzer.analyze_single_patent(patent_id, "NoPDFCo")
finally:
os.chdir(original_cwd)
def test_analysis_exception_returns_error_message(self, analyzer, mocker, tmp_path):
"""When analysis pipeline fails, returns error string instead of raising."""
patent_id = "US-22222222-A2"
# Create the PDF on disk so it skips download
patents_dir = tmp_path / "patents"
patents_dir.mkdir()
(patents_dir / f"{patent_id}.pdf").write_bytes(b"fake PDF")
# Parse fails
mocker.patch(
"SPARC.analyzer.SERP.parse_patent_pdf",
side_effect=ValueError("Corrupt PDF"),
)
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
result = analyzer.analyze_single_patent(patent_id, "ErrorCo")
finally:
os.chdir(original_cwd)
assert "Failed to analyze patent" in result
assert "Corrupt PDF" in result
def test_model_override_passed_to_llm(self, analyzer, mocker, tmp_path):
"""The model parameter is forwarded to the LLM analyzer."""
patent_id = "US-33333333-B2"
patents_dir = tmp_path / "patents"
patents_dir.mkdir()
(patents_dir / f"{patent_id}.pdf").write_bytes(b"fake PDF")
mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf", return_value={"abstract": "test"})
mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm", return_value="content")
analyzer.llm_analyzer.analyze_patent_content.return_value = "Analysis result."
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
result = analyzer.analyze_single_patent(
patent_id, "ModelCo", model="openai/gpt-4o"
)
finally:
os.chdir(original_cwd)
assert result == "Analysis result."
analyzer.llm_analyzer.analyze_patent_content.assert_called_once_with(
patent_content="content",
company_name="ModelCo",
model="openai/gpt-4o",
)
def test_file_not_found_during_parse_re_raised(self, analyzer, mocker, tmp_path):
"""FileNotFoundError during parsing is re-raised, not caught."""
patent_id = "US-44444444-C1"
patents_dir = tmp_path / "patents"
patents_dir.mkdir()
(patents_dir / f"{patent_id}.pdf").write_bytes(b"fake PDF")
mocker.patch(
"SPARC.analyzer.SERP.parse_patent_pdf",
side_effect=FileNotFoundError("PDF file vanished"),
)
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
with pytest.raises(FileNotFoundError, match="PDF file vanished"):
analyzer.analyze_single_patent(patent_id, "VanishCo")
finally:
os.chdir(original_cwd)
+319
View File
@@ -0,0 +1,319 @@
"""Tests for user-level API key generation, listing, revocation, and authentication.
Covers all acceptance criteria from issue #1673:
1. Users can create API keys (POST /auth/apikeys)
2. Users can list their active key IDs (GET /auth/apikeys)
3. Users can revoke keys (DELETE /auth/apikeys/{key_id})
4. API requests authenticated with a valid API key work on protected endpoints
5. Revoked keys are immediately rejected
6. Plaintext key is shown only at creation time
All tests use mocked DB fixtures and require no live database.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import (
create_access_token,
generate_api_key,
hash_api_key,
verify_api_key,
)
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_user():
return {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _auth_header(user_dict):
"""Create an Authorization header with a valid access token."""
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
return {"Authorization": f"Bearer {token}"}
@pytest.fixture(autouse=True)
def mock_db(monkeypatch):
"""Mock the database client used by auth and api endpoints."""
db = MagicMock()
db.get_user_count.return_value = 0
db.get_user_by_id.return_value = None
db.get_user_by_email.return_value = None
db.authenticate_user.return_value = None
db.create_user.return_value = None
db.get_all_users.return_value = []
db.update_user_role.return_value = None
db.delete_user.return_value = False
db.create_api_key.return_value = None
db.list_api_keys.return_value = []
db.delete_api_key.return_value = False
db.get_all_api_key_hashes.return_value = []
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
class TestCreateApiKey:
"""POST /auth/apikeys"""
def test_create_key_returns_plaintext_and_id(self, client, mock_db):
"""Creating a key returns the plaintext key and metadata."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.create_api_key.return_value = {
"id": 42,
"user_id": user["id"],
"label": "my-ci-key",
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/apikeys",
json={"label": "my-ci-key"},
headers=_auth_header(user),
)
assert response.status_code == 200
data = response.json()
assert data["id"] == 42
assert len(data["key"]) == 64 # 32 bytes hex = 64 chars
assert data["label"] == "my-ci-key"
assert "created_at" in data
# Verify the hash passed to DB is valid for the returned key
call_args = mock_db.create_api_key.call_args
stored_hash = call_args.kwargs.get("key_hash") or call_args[1].get("key_hash") or call_args[0][1]
assert verify_api_key(data["key"], stored_hash)
def test_create_key_without_label(self, client, mock_db):
"""Creating a key without a label should work."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.create_api_key.return_value = {
"id": 1,
"user_id": user["id"],
"label": None,
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/apikeys",
headers=_auth_header(user),
)
assert response.status_code == 200
assert response.json()["label"] is None
def test_create_key_requires_auth(self, client):
"""Creating a key without auth should fail."""
response = client.post("/auth/apikeys")
assert response.status_code == 401
class TestListApiKeys:
"""GET /auth/apikeys"""
def test_list_keys_returns_metadata_only(self, client, mock_db):
"""Listing keys should return IDs and labels, not secrets."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.list_api_keys.return_value = [
{"id": 1, "label": "key-1", "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc)},
{"id": 2, "label": None, "created_at": datetime(2025, 2, 1, tzinfo=timezone.utc)},
]
response = client.get("/auth/apikeys", headers=_auth_header(user))
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["id"] == 1
assert data[0]["label"] == "key-1"
# Ensure no secret key is exposed
for item in data:
assert "key" not in item
assert "key_hash" not in item
def test_list_keys_empty(self, client, mock_db):
"""User with no keys gets an empty list."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.list_api_keys.return_value = []
response = client.get("/auth/apikeys", headers=_auth_header(user))
assert response.status_code == 200
assert response.json() == []
class TestRevokeApiKey:
"""DELETE /auth/apikeys/{key_id}"""
def test_revoke_existing_key(self, client, mock_db):
"""Revoking an owned key should succeed."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.delete_api_key.return_value = True
response = client.delete("/auth/apikeys/42", headers=_auth_header(user))
assert response.status_code == 200
assert "revoked" in response.json()["message"].lower()
mock_db.delete_api_key.assert_called_once_with(42, user["id"])
def test_revoke_nonexistent_key_returns_404(self, client, mock_db):
"""Revoking a key that doesn't exist (or isn't owned) returns 404."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
mock_db.delete_api_key.return_value = False
response = client.delete("/auth/apikeys/999", headers=_auth_header(user))
assert response.status_code == 404
class TestApiKeyAuthentication:
"""Using X-API-Key header on protected endpoints."""
def test_valid_api_key_accesses_protected_endpoint(self, client, mock_db):
"""A valid API key should authenticate and access /auth/me."""
user = _make_user()
plaintext = generate_api_key()
hashed = hash_api_key(plaintext)
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": hashed, "user_id": user["id"]},
]
mock_db.get_user_by_id.return_value = user
response = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert response.status_code == 200
data = response.json()
assert data["email"] == user["email"]
assert data["id"] == user["id"]
def test_invalid_api_key_returns_401(self, client, mock_db):
"""An invalid API key should return 401."""
mock_db.get_all_api_key_hashes.return_value = []
response = client.get("/auth/me", headers={"X-API-Key": "bad-key"})
assert response.status_code == 401
assert "invalid api key" in response.json()["detail"].lower()
def test_revoked_key_returns_401(self, client, mock_db):
"""After revocation, using the key should return 401."""
# Simulate revoked key: no matching hashes in DB
mock_db.get_all_api_key_hashes.return_value = []
response = client.get("/auth/me", headers={"X-API-Key": "a" * 64})
assert response.status_code == 401
def test_api_key_for_deleted_user_returns_401(self, client, mock_db):
"""An API key whose user no longer exists should return 401."""
plaintext = generate_api_key()
hashed = hash_api_key(plaintext)
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": hashed, "user_id": 999},
]
mock_db.get_user_by_id.return_value = None # user deleted
response = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert response.status_code == 401
def test_no_auth_at_all_returns_401(self, client, mock_db):
"""No auth header at all should return 401."""
response = client.get("/auth/me")
assert response.status_code == 401
class TestApiKeyFullFlow:
"""End-to-end flow: create key, use it, revoke it, try again."""
def test_create_use_revoke_flow(self, client, mock_db):
"""Simulate full lifecycle of an API key."""
user = _make_user()
mock_db.get_user_by_id.return_value = user
# Step 1: Create key
mock_db.create_api_key.return_value = {
"id": 10,
"user_id": user["id"],
"label": "test",
"created_at": datetime(2025, 6, 1, tzinfo=timezone.utc),
}
create_resp = client.post(
"/auth/apikeys",
json={"label": "test"},
headers=_auth_header(user),
)
assert create_resp.status_code == 200
plaintext = create_resp.json()["key"]
# Capture the hash that was stored
call_args = mock_db.create_api_key.call_args
stored_hash = call_args.kwargs.get("key_hash") or call_args[0][1]
# Step 2: Use key on protected endpoint
mock_db.get_all_api_key_hashes.return_value = [
{"key_hash": stored_hash, "user_id": user["id"]},
]
use_resp = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert use_resp.status_code == 200
assert use_resp.json()["email"] == user["email"]
# Step 3: Revoke key
mock_db.delete_api_key.return_value = True
revoke_resp = client.delete("/auth/apikeys/10", headers=_auth_header(user))
assert revoke_resp.status_code == 200
# Step 4: Try using revoked key
mock_db.get_all_api_key_hashes.return_value = [] # key removed from DB
rejected_resp = client.get("/auth/me", headers={"X-API-Key": plaintext})
assert rejected_resp.status_code == 401
class TestApiKeyHelpers:
"""Unit tests for key generation and hashing helpers."""
def test_generate_api_key_length(self):
"""Generated key should be 64 hex characters (32 bytes)."""
key = generate_api_key()
assert len(key) == 64
# Should be valid hex
int(key, 16)
def test_generate_api_key_uniqueness(self):
"""Two generated keys should be different."""
k1 = generate_api_key()
k2 = generate_api_key()
assert k1 != k2
def test_hash_and_verify(self):
"""hash_api_key and verify_api_key should round-trip correctly."""
key = generate_api_key()
hashed = hash_api_key(key)
assert verify_api_key(key, hashed)
assert not verify_api_key("wrong-key", hashed)
+209 -10
View File
@@ -1,13 +1,29 @@
"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access."""
"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access.
from datetime import datetime, timezone
Covers all five scenarios required by issue #1624:
1. Registration (POST /auth/register)
2. Login (POST /auth/login)
3. Protected route access (GET /auth/me) -- valid, missing, expired, wrong-type tokens
4. Token refresh (POST /auth/refresh)
5. Admin-only endpoints (GET /admin/users, PATCH role, DELETE user)
All tests use mocked DB fixtures and require no live database.
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import jwt as pyjwt
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token, create_refresh_token
from SPARC.auth import (
JWT_ALGORITHM,
JWT_SECRET,
create_access_token,
create_refresh_token,
)
@pytest.fixture
@@ -171,13 +187,6 @@ class TestGetMe:
def test_expired_token_returns_401(self, client, mock_db):
"""An expired token should return 401."""
# Create a token that has already expired
from datetime import timedelta
import jwt as pyjwt
from SPARC.auth import JWT_ALGORITHM, JWT_SECRET
payload = {
"sub": "1",
"email": "user@test.com",
@@ -301,3 +310,193 @@ class TestAdminUsers:
assert response.status_code == 400
assert "own role" in response.json()["detail"].lower()
def test_role_change_nonexistent_user_returns_404(self, client, mock_db):
"""Changing role for a user that does not exist should return 404."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.update_user_role.return_value = None
response = client.patch(
"/admin/users/999/role",
json={"role": "admin"},
headers=_auth_header(admin),
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_regular_user_cannot_change_role(self, client, mock_db):
"""Non-admin user should receive 403 when trying to change roles."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.patch(
"/admin/users/1/role",
json={"role": "admin"},
headers=_auth_header(user),
)
assert response.status_code == 403
class TestAdminDeleteUser:
"""DELETE /admin/users/{user_id}"""
def test_admin_can_delete_user(self, client, mock_db):
"""Admin should be able to delete another user."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.delete_user.return_value = True
response = client.delete(
"/admin/users/2",
headers=_auth_header(admin),
)
assert response.status_code == 200
assert "deleted" in response.json()["message"].lower()
mock_db.delete_user.assert_called_once_with(2)
def test_admin_cannot_delete_self(self, client, mock_db):
"""Admin should not be able to delete themselves."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
response = client.delete(
"/admin/users/1",
headers=_auth_header(admin),
)
assert response.status_code == 400
assert "yourself" in response.json()["detail"].lower()
def test_delete_nonexistent_user_returns_404(self, client, mock_db):
"""Deleting a user that does not exist should return 404."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.delete_user.return_value = False
response = client.delete(
"/admin/users/999",
headers=_auth_header(admin),
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_regular_user_cannot_delete_user(self, client, mock_db):
"""Non-admin user should receive 403 when trying to delete users."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.delete(
"/admin/users/1",
headers=_auth_header(user),
)
assert response.status_code == 403
def test_no_token_cannot_delete_user(self, client):
"""Missing token should be rejected for delete endpoint."""
response = client.delete("/admin/users/1")
assert response.status_code in (401, 403)
class TestEdgeCases:
"""Additional edge-case tests for auth robustness."""
def test_register_invalid_email_returns_422(self, client, mock_db):
"""Registration with an invalid email format should return 422."""
response = client.post(
"/auth/register",
json={"email": "not-an-email", "password": "securepass123"},
)
assert response.status_code == 422
def test_register_short_password_returns_422(self, client, mock_db):
"""Registration with a password shorter than 8 chars should return 422."""
response = client.post(
"/auth/register",
json={"email": "user@test.com", "password": "short"},
)
assert response.status_code == 422
def test_register_missing_fields_returns_422(self, client, mock_db):
"""Registration with missing fields should return 422."""
response = client.post("/auth/register", json={})
assert response.status_code == 422
def test_login_missing_fields_returns_422(self, client, mock_db):
"""Login with missing fields should return 422."""
response = client.post("/auth/login", json={"email": "user@test.com"})
assert response.status_code == 422
def test_malformed_token_returns_401(self, client, mock_db):
"""A completely malformed token string should return 401."""
response = client.get(
"/auth/me",
headers={"Authorization": "Bearer not.a.valid.jwt.token"},
)
assert response.status_code == 401
def test_token_with_wrong_secret_returns_401(self, client, mock_db):
"""A token signed with a different secret should return 401."""
payload = {
"sub": "1",
"email": "user@test.com",
"role": "user",
"exp": datetime.now(timezone.utc) + timedelta(hours=1),
"type": "access",
}
wrong_secret_token = pyjwt.encode(payload, "wrong-secret", algorithm=JWT_ALGORITHM)
response = client.get(
"/auth/me",
headers={"Authorization": f"Bearer {wrong_secret_token}"},
)
assert response.status_code == 401
def test_token_for_deleted_user_returns_401(self, client, mock_db):
"""A valid token for a user no longer in the DB should return 401."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = None # user was deleted
response = client.get("/auth/me", headers=_auth_header(user))
assert response.status_code == 401
def test_refresh_for_deleted_user_returns_401(self, client, mock_db):
"""Refreshing a token for a deleted user should return 401."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = None
refresh = create_refresh_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": refresh}
)
assert response.status_code == 401
def test_login_returns_decodable_tokens(self, client, mock_db):
"""Tokens returned by login should be decodable and contain expected claims."""
user = _make_regular_user()
mock_db.authenticate_user.return_value = user
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "correctpassword"},
)
data = response.json()
access_payload = pyjwt.decode(
data["access_token"], JWT_SECRET, algorithms=[JWT_ALGORITHM]
)
assert access_payload["sub"] == str(user["id"])
assert access_payload["email"] == user["email"]
assert access_payload["type"] == "access"
refresh_payload = pyjwt.decode(
data["refresh_token"], JWT_SECRET, algorithms=[JWT_ALGORITHM]
)
assert refresh_payload["type"] == "refresh"
+157
View File
@@ -0,0 +1,157 @@
"""Tests for company name input validation on analysis endpoints."""
from datetime import datetime
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer so valid requests succeed."""
mock = Mock()
mock._analyze_company_safe.return_value = CompanyAnalysisResult(
company_name="nvidia",
analysis="Test analysis",
patent_count=1,
success=True,
timestamp=datetime.now(),
)
mocker.patch("SPARC.api._analyzer", mock)
return mock
class TestCompanyNameValidation:
"""Test that company names are validated on analysis endpoints."""
# --- Too short ---
def test_single_char_rejected(self, client, mock_analyzer):
"""A one-character company name should be rejected."""
response = client.get("/analyze/X")
assert response.status_code == 422
# --- Too long ---
def test_over_100_chars_rejected(self, client, mock_analyzer):
"""A company name longer than 100 characters should be rejected."""
long_name = "A" * 101
response = client.get(f"/analyze/{long_name}")
assert response.status_code == 422
# --- Special characters ---
@pytest.mark.parametrize(
"bad_name",
[
"nvidia!",
"intel@corp",
"test#company",
"foo$bar",
"a%b",
"x^y",
"semi;colon",
"drop'table",
'say"hello',
"path/traversal",
"back\\slash",
"pipe|char",
"star*glob",
"question?mark",
"<script>",
"curly{brace}",
"equal=sign",
"plus+plus",
"comma,separated",
],
)
def test_special_chars_rejected(self, client, mock_analyzer, bad_name):
"""Company names with disallowed special characters should be rejected."""
response = client.get(f"/analyze/{bad_name}")
assert response.status_code == 422
# --- Valid names ---
@pytest.mark.parametrize(
"valid_name",
[
"nvidia",
"Intel",
"TSMC",
"Texas Instruments",
"Johnson-Johnson",
"AT&T",
"St. Jude Medical",
"3M",
"21st Century Fox",
"ab", # minimum length
"A" * 100, # maximum length
],
)
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
"""Valid company names should be accepted (200, not 422)."""
response = client.get(f"/analyze/{valid_name}")
# Should not be a validation error; 200 or other non-422 status is fine
assert response.status_code != 422
# --- Batch endpoint validation ---
def test_batch_too_short_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too short."""
response = client.post(
"/analyze/batch",
json={"companies": ["X"]},
)
assert response.status_code == 422
def test_batch_too_long_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too long."""
response = client.post(
"/analyze/batch",
json={"companies": ["A" * 101]},
)
assert response.status_code == 422
def test_batch_special_chars_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names with special chars."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia!", "intel"]},
)
assert response.status_code == 422
def test_batch_valid_names_accepted(self, client, mock_analyzer):
"""Batch endpoint should accept valid company names."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "Intel", "AT&T"]},
)
assert response.status_code != 422
# --- Name must start with alphanumeric ---
def test_leading_space_rejected(self, client, mock_analyzer):
"""Company name starting with a space should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": [" nvidia"]},
)
assert response.status_code == 422
def test_leading_hyphen_rejected(self, client, mock_analyzer):
"""Company name starting with a hyphen should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": ["-nvidia"]},
)
assert response.status_code == 422
+224
View File
@@ -0,0 +1,224 @@
"""Tests for export endpoints: CSV and PDF export of analysis results.
Covers issue #1655:
- GET /export/{company_name} (CSV export)
- GET /export/{company_name}/pdf (PDF export)
All tests mock the database layer and use JWT auth fixtures from test_auth patterns.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by export and auth endpoints."""
db = MagicMock()
# Default: user exists for auth
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
# Mock get_conn for export queries
mock_cursor = MagicMock()
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn)
db.get_conn.return_value.__exit__ = MagicMock(return_value=False)
db._mock_cursor = mock_cursor
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
"""Create an Authorization header with a valid access token."""
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _sample_rows():
"""Return sample llm_messages rows as tuples (matching cursor.fetchall format)."""
return [
(
"NVIDIA",
"company_analysis",
"anthropic/claude-3.5-sonnet",
"Strong AI patent portfolio with focus on GPU architectures.",
datetime(2025, 6, 15, 10, 30, 0),
),
(
"NVIDIA",
"patent_analysis",
"openai/gpt-4o",
"Patent US-12345678-B2 covers novel tensor core design.",
datetime(2025, 6, 14, 9, 0, 0),
),
]
class TestCSVExport:
"""GET /export/{company_name} -- CSV export."""
def test_csv_export_success(self, client, mock_db):
"""Valid company with results returns a CSV file."""
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
response = client.get("/export/NVIDIA", headers=_auth_header())
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/csv")
assert "attachment" in response.headers.get("content-disposition", "")
assert "sparc_nvidia_export.csv" in response.headers["content-disposition"]
# Verify CSV content (CSV uses \r\n line endings)
lines = response.text.strip().split("\n")
assert len(lines) == 3 # header + 2 data rows
assert lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
assert "NVIDIA" in lines[1]
assert "company_analysis" in lines[1]
def test_csv_export_no_results_returns_404(self, client, mock_db):
"""Unknown company returns 404."""
mock_db._mock_cursor.fetchall.return_value = []
response = client.get("/export/nonexistent", headers=_auth_header())
assert response.status_code == 404
assert "No analysis results found" in response.json()["detail"]
def test_csv_export_unauthenticated_returns_401(self, client):
"""Request without token returns 401."""
response = client.get("/export/NVIDIA")
assert response.status_code == 401
def test_csv_export_invalid_token_returns_401(self, client):
"""Request with invalid token returns 401."""
response = client.get(
"/export/NVIDIA",
headers={"Authorization": "Bearer invalid.token.here"},
)
assert response.status_code == 401
def test_csv_export_filename_sanitization(self, client, mock_db):
"""Company names with spaces get sanitized in the filename."""
mock_db._mock_cursor.fetchall.return_value = [
(
"Tesla Motors",
"company_analysis",
"anthropic/claude-3.5-sonnet",
"EV patent portfolio analysis.",
datetime(2025, 6, 15, 10, 0, 0),
),
]
response = client.get("/export/Tesla Motors", headers=_auth_header())
assert response.status_code == 200
assert "tesla_motors" in response.headers["content-disposition"]
def test_csv_export_single_row(self, client, mock_db):
"""Single analysis result produces valid CSV with one data row."""
mock_db._mock_cursor.fetchall.return_value = [_sample_rows()[0]]
response = client.get("/export/NVIDIA", headers=_auth_header())
assert response.status_code == 200
lines = response.text.strip().split("\n")
assert len(lines) == 2 # header + 1 data row
class TestPDFExport:
"""GET /export/{company_name}/pdf -- PDF report export."""
def test_pdf_export_success(self, client, mock_db):
"""Valid company with results returns a PDF file."""
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
response = client.get("/export/NVIDIA/pdf", headers=_auth_header())
assert response.status_code == 200
assert response.headers["content-type"] == "application/pdf"
assert "attachment" in response.headers.get("content-disposition", "")
# PDF files start with %PDF
assert response.content[:4] == b"%PDF"
def test_pdf_export_no_results_returns_404(self, client, mock_db):
"""Unknown company returns 404."""
mock_db._mock_cursor.fetchall.return_value = []
response = client.get("/export/nonexistent/pdf", headers=_auth_header())
assert response.status_code == 404
assert "No analysis results found" in response.json()["detail"]
def test_pdf_export_unauthenticated_returns_401(self, client):
"""Request without token returns 401."""
response = client.get("/export/NVIDIA/pdf")
assert response.status_code == 401
def test_pdf_export_invalid_token_returns_401(self, client):
"""Request with invalid token returns 401."""
response = client.get(
"/export/NVIDIA/pdf",
headers={"Authorization": "Bearer invalid.token.here"},
)
assert response.status_code == 401
def test_pdf_export_filename_contains_date(self, client, mock_db):
"""PDF filename includes the analysis date."""
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
response = client.get("/export/NVIDIA/pdf", headers=_auth_header())
assert response.status_code == 200
disposition = response.headers["content-disposition"]
assert "nvidia-analysis-" in disposition
assert ".pdf" in disposition
def test_pdf_export_special_chars_in_response(self, client, mock_db):
"""Analysis text with XML-special chars (<, >, &) does not break PDF generation."""
rows = [
(
"TestCo",
"company_analysis",
"anthropic/claude-3.5-sonnet",
"Revenue > $1B & growth <20% for Q4. Test <html> escaping.",
datetime(2025, 6, 15, 10, 0, 0),
),
]
mock_db._mock_cursor.fetchall.return_value = rows
response = client.get("/export/TestCo/pdf", headers=_auth_header())
assert response.status_code == 200
assert response.content[:4] == b"%PDF"
def test_pdf_export_multiple_analyses(self, client, mock_db):
"""Multiple analysis records produce a valid PDF with content."""
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
response = client.get("/export/NVIDIA/pdf", headers=_auth_header())
assert response.status_code == 200
# PDF should have reasonable size (more than just headers)
assert len(response.content) > 500
+169
View File
@@ -0,0 +1,169 @@
"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints."""
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"):
"""Create a fake analysis row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
return {
"id": id_,
"company_name": company,
"analysis_type": "patent_portfolio",
"model": "openai/gpt-4o",
"response": f"Analysis for {company}",
"timestamp": ts,
}
def _make_job_row(job_id: str, minutes_ago: int = 0, status: str = "completed"):
"""Create a fake job row dict."""
ts = datetime.now() - timedelta(minutes=minutes_ago)
return {
"job_id": job_id,
"status": status,
"progress": 100 if status == "completed" else 0,
"total_companies": 1,
"completed_companies": 1 if status == "completed" else 0,
"result": None,
"error": None,
"created_at": ts,
}
class TestAnalyzeBatchGetPagination:
"""Test cursor-based pagination on GET /analyze/batch."""
@patch("SPARC.api._get_job_db")
def test_returns_items_and_no_cursor_when_less_than_limit(self, mock_get_db, client):
"""When fewer results than limit, next_cursor should be null."""
db = Mock()
db.list_analyses.return_value = [
_make_analysis_row(1, minutes_ago=10),
_make_analysis_row(2, minutes_ago=20),
]
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=10")
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 2
assert data["next_cursor"] is None
@patch("SPARC.api._get_job_db")
def test_returns_cursor_when_more_results_exist(self, mock_get_db, client):
"""When more results exist than limit, next_cursor should be set."""
db = Mock()
# Return limit+1 rows to simulate more data
rows = [_make_analysis_row(i, minutes_ago=i) for i in range(4)]
db.list_analyses.return_value = rows
mock_get_db.return_value = db
response = client.get("/analyze/batch?limit=3")
assert response.status_code == 200
data = response.json()
assert len(data["items"]) == 3
assert data["next_cursor"] is not None
@patch("SPARC.api._get_job_db")
def test_cursor_passed_to_db(self, mock_get_db, client):
"""The cursor query param should be forwarded to the database layer."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42")
db.list_analyses.assert_called_once()
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \
(call_kwargs[1].get("cursor") == "2025-01-01T00:00:00|42" if len(call_kwargs) > 1 else False)
@patch("SPARC.api._get_job_db")
def test_default_limit_is_50(self, mock_get_db, client):
"""Default limit should be 50."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch")
call_kwargs = db.list_analyses.call_args
# The endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/analyze/batch?limit=201")
assert response.status_code == 422
def test_limit_zero_rejected(self, client):
"""Limit < 1 should be rejected with 422."""
response = client.get("/analyze/batch?limit=0")
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
def test_company_name_filter(self, mock_get_db, client):
"""The company_name filter should be forwarded to the database."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
client.get("/analyze/batch?company_name=intel")
call_kwargs = db.list_analyses.call_args
assert call_kwargs.kwargs.get("company_name") == "intel" or \
"intel" in (call_kwargs.args if call_kwargs.args else [])
@patch("SPARC.api._get_job_db")
def test_empty_result_set(self, mock_get_db, client):
"""Empty result set returns empty items and null cursor."""
db = Mock()
db.list_analyses.return_value = []
mock_get_db.return_value = db
response = client.get("/analyze/batch")
assert response.status_code == 200
data = response.json()
assert data["items"] == []
assert data["next_cursor"] is None
class TestJobsPaginationDefaults:
"""Test that /jobs endpoint uses updated defaults."""
@patch("SPARC.api._get_job_db")
def test_default_limit_is_50(self, mock_get_db, client):
"""Default limit should now be 50."""
db = Mock()
db.list_jobs.return_value = []
mock_get_db.return_value = db
client.get("/jobs")
call_kwargs = db.list_jobs.call_args
# Endpoint requests limit+1 from DB, so 51
assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51
def test_limit_over_200_rejected(self, client):
"""Limit > 200 should be rejected with 422."""
response = client.get("/jobs?limit=201")
assert response.status_code == 422
@patch("SPARC.api._get_job_db")
def test_limit_200_accepted(self, mock_get_db, client):
"""Limit of exactly 200 should be accepted."""
db = Mock()
db.list_jobs.return_value = []
mock_get_db.return_value = db
response = client.get("/jobs?limit=200")
assert response.status_code == 200
+178
View File
@@ -0,0 +1,178 @@
"""Tests for the /admin/rate-limits endpoint."""
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from SPARC import api
from SPARC.api import app
from SPARC.auth import UserResponse
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_stats():
"""Reset rate limit stats between tests."""
api._rate_limit_stats.clear()
api._rejected_log.clear()
yield
api._rate_limit_stats.clear()
api._rejected_log.clear()
def _mock_admin():
"""Return a mock admin user."""
return UserResponse(id=1, email="admin@test.com", role="admin", created_at="2025-01-01T00:00:00")
def _mock_user():
"""Return a mock non-admin user."""
return UserResponse(id=2, email="user@test.com", role="user", created_at="2025-01-01T00:00:00")
class TestRateLimitAdminEndpoint:
"""Test GET /admin/rate-limits."""
def test_admin_can_access(self, client):
"""Admin users should be able to access the rate-limits endpoint."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
assert response.status_code == 200
data = response.json()
assert "rate_limits" in data
assert isinstance(data["rate_limits"], list)
finally:
app.dependency_overrides.clear()
def test_non_admin_rejected(self, client):
"""Non-admin users should get 401/403."""
response = client.get("/admin/rate-limits")
assert response.status_code in (401, 403)
def test_returns_configured_endpoints(self, client):
"""Should list all rate-limited endpoints."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
assert response.status_code == 200
data = response.json()
endpoints = [rl["endpoint"] for rl in data["rate_limits"]]
assert "/auth/register" in endpoints
assert "/auth/login" in endpoints
finally:
app.dependency_overrides.clear()
def test_empty_state_shows_zero_counts(self, client):
"""When no requests have been made, counts should be zero."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
for rl in data["rate_limits"]:
assert rl["total_requests"] == 0
assert rl["rejected_requests"] == 0
assert rl["by_ip"] == []
assert data["throttled_24h"] == 0
assert data["throttled_over_time"] == []
finally:
app.dependency_overrides.clear()
def test_tracks_requests(self, client):
"""After making requests, the stats should reflect them."""
api._track_rate_limit_request("/auth/login", "127.0.0.1")
api._track_rate_limit_request("/auth/login", "127.0.0.1")
api._track_rate_limit_request("/auth/login", "192.168.1.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
assert login_stats["total_requests"] == 3
assert login_stats["rejected_requests"] == 1
finally:
app.dependency_overrides.clear()
def test_includes_limit_config(self, client):
"""Each endpoint entry should include the rate limit config string."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
for rl in data["rate_limits"]:
assert "limit" in rl
assert isinstance(rl["limit"], str)
finally:
app.dependency_overrides.clear()
def test_per_ip_breakdown(self, client):
"""Stats should include per-IP breakdown with total and rejected counts."""
api._track_rate_limit_request("/auth/login", "10.0.0.1")
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/login", "10.0.0.2")
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
by_ip = login_stats["by_ip"]
assert len(by_ip) == 2
ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1")
assert ip1["total"] == 2
assert ip1["rejected"] == 1
ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2")
assert ip2["total"] == 1
assert ip2["rejected"] == 0
finally:
app.dependency_overrides.clear()
def test_throttled_24h_count(self, client):
"""Should report total throttled requests in the last 24 hours."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert data["throttled_24h"] == 2
finally:
app.dependency_overrides.clear()
def test_throttled_over_time_structure(self, client):
"""Throttled-over-time should be a list of {timestamp, count} buckets."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert len(data["throttled_over_time"]) >= 1
entry = data["throttled_over_time"][0]
assert "timestamp" in entry
assert "count" in entry
assert entry["count"] >= 1
finally:
app.dependency_overrides.clear()
def test_response_shape_matches_contract(self, client):
"""The full response should match the expected shape for the frontend."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
# Top-level keys
assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"}
# Each rate_limit entry
for rl in data["rate_limits"]:
assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"}
finally:
app.dependency_overrides.clear()
+263
View File
@@ -0,0 +1,263 @@
"""Tests for S3/MinIO storage backend in storage.py.
Covers issue #1660:
- S3StorageBackend read, write, exists, path_for
- Error handling: NoSuchKey, generic S3 errors, bucket auto-creation
- get_storage_backend() factory function
- LocalStorageBackend (basic sanity checks)
"""
from unittest.mock import MagicMock, patch
import pytest
from SPARC.storage import LocalStorageBackend, S3StorageBackend, get_storage_backend
# ---------- S3StorageBackend ----------
class TestS3StorageBackend:
"""Tests for the S3-compatible storage backend."""
@pytest.fixture
def s3_backend(self):
"""Create an S3StorageBackend with a fully mocked boto3 client."""
with patch.dict("sys.modules", {"boto3": MagicMock()}):
import boto3 as mock_boto
mock_s3 = MagicMock()
mock_boto.client.return_value = mock_s3
mock_s3.head_bucket.return_value = {}
backend = S3StorageBackend(
bucket="test-bucket",
endpoint_url="http://minio:9000",
access_key="minioadmin",
secret_key="minioadmin",
)
# Expose mock for assertions
backend._mock_s3 = mock_s3
yield backend
def test_write_puts_object(self, s3_backend):
"""write() calls put_object with correct bucket, key, and body."""
s3_backend.write("US-12345678-B2.pdf", b"PDF content here")
s3_backend._mock_s3.put_object.assert_called_once_with(
Bucket="test-bucket",
Key="US-12345678-B2.pdf",
Body=b"PDF content here",
ContentType="application/pdf",
)
def test_read_returns_body(self, s3_backend):
"""read() returns the Body content from get_object."""
mock_body = MagicMock()
mock_body.read.return_value = b"PDF data"
s3_backend._mock_s3.get_object.return_value = {"Body": mock_body}
result = s3_backend.read("US-12345678-B2.pdf")
assert result == b"PDF data"
s3_backend._mock_s3.get_object.assert_called_once_with(
Bucket="test-bucket",
Key="US-12345678-B2.pdf",
)
def test_read_nosuchkey_raises_file_not_found(self, s3_backend):
"""read() raises FileNotFoundError when object does not exist."""
# Create a NoSuchKey exception class on the mock
nosuchkey = type("NoSuchKey", (Exception,), {})
s3_backend._mock_s3.exceptions.NoSuchKey = nosuchkey
s3_backend._mock_s3.get_object.side_effect = nosuchkey("not found")
# Reassign s3 to trigger the except branch
s3_backend.s3 = s3_backend._mock_s3
with pytest.raises(FileNotFoundError, match="S3 object not found"):
s3_backend.read("missing.pdf")
def test_read_generic_404_raises_file_not_found(self, s3_backend):
"""read() handles generic 404 errors from S3-compatible APIs."""
nosuchkey = type("NoSuchKey", (Exception,), {})
s3_backend._mock_s3.exceptions.NoSuchKey = nosuchkey
s3_backend.s3 = s3_backend._mock_s3
s3_backend.s3.get_object.side_effect = Exception("An error occurred (404)")
with pytest.raises(FileNotFoundError, match="S3 object not found"):
s3_backend.read("missing.pdf")
def test_read_other_error_re_raises(self, s3_backend):
"""read() re-raises non-404 errors."""
nosuchkey = type("NoSuchKey", (Exception,), {})
s3_backend._mock_s3.exceptions.NoSuchKey = nosuchkey
s3_backend.s3 = s3_backend._mock_s3
s3_backend.s3.get_object.side_effect = Exception("Internal server error")
with pytest.raises(Exception, match="Internal server error"):
s3_backend.read("some-file.pdf")
def test_exists_returns_true_for_existing_object(self, s3_backend):
"""exists() returns True when head_object succeeds with content."""
s3_backend._mock_s3.head_object.return_value = {"ContentLength": 1024}
assert s3_backend.exists("US-12345678-B2.pdf") is True
def test_exists_returns_false_for_missing_object(self, s3_backend):
"""exists() returns False when head_object raises an exception."""
s3_backend._mock_s3.head_object.side_effect = Exception("Not Found")
assert s3_backend.exists("missing.pdf") is False
def test_exists_returns_false_for_zero_length(self, s3_backend):
"""exists() returns False when object has zero content length."""
s3_backend._mock_s3.head_object.return_value = {"ContentLength": 0}
assert s3_backend.exists("empty.pdf") is False
def test_path_for_returns_s3_uri(self, s3_backend):
"""path_for() returns an s3:// URI."""
path = s3_backend.path_for("US-12345678-B2.pdf")
assert path == "s3://test-bucket/US-12345678-B2.pdf"
def test_constructor_creates_bucket_if_missing(self):
"""Constructor creates the bucket if head_bucket fails."""
with patch.dict("sys.modules", {"boto3": MagicMock()}):
import boto3 as mock_boto
mock_s3 = MagicMock()
mock_boto.client.return_value = mock_s3
mock_s3.head_bucket.side_effect = Exception("Bucket not found")
S3StorageBackend(
bucket="new-bucket",
endpoint_url="http://minio:9000",
access_key="admin",
secret_key="admin",
)
mock_s3.create_bucket.assert_called_once_with(Bucket="new-bucket")
def test_constructor_handles_bucket_creation_failure(self):
"""Constructor logs warning but does not crash if bucket creation fails."""
with patch.dict("sys.modules", {"boto3": MagicMock()}):
import boto3 as mock_boto
mock_s3 = MagicMock()
mock_boto.client.return_value = mock_s3
mock_s3.head_bucket.side_effect = Exception("Bucket not found")
mock_s3.create_bucket.side_effect = Exception("Permission denied")
# Should not raise
backend = S3StorageBackend(
bucket="locked-bucket",
endpoint_url="http://minio:9000",
access_key="admin",
secret_key="admin",
)
assert backend.bucket == "locked-bucket"
def test_constructor_passes_endpoint_and_credentials(self):
"""Constructor passes endpoint_url and credentials to boto3.client."""
with patch.dict("sys.modules", {"boto3": MagicMock()}):
import boto3 as mock_boto
mock_s3 = MagicMock()
mock_boto.client.return_value = mock_s3
S3StorageBackend(
bucket="test",
endpoint_url="http://minio:9000",
access_key="mykey",
secret_key="mysecret",
)
mock_boto.client.assert_called_with(
"s3",
endpoint_url="http://minio:9000",
aws_access_key_id="mykey",
aws_secret_access_key="mysecret",
)
# ---------- LocalStorageBackend ----------
class TestLocalStorageBackend:
"""Basic sanity checks for the local filesystem backend."""
def test_write_and_read(self, tmp_path):
"""Write and read round-trip produces identical content."""
backend = LocalStorageBackend(base_dir=str(tmp_path))
backend.write("test.pdf", b"hello world")
result = backend.read("test.pdf")
assert result == b"hello world"
def test_read_missing_file_raises(self, tmp_path):
"""Reading a non-existent file raises FileNotFoundError."""
backend = LocalStorageBackend(base_dir=str(tmp_path))
with pytest.raises(FileNotFoundError):
backend.read("nonexistent.pdf")
def test_exists_true_for_written_file(self, tmp_path):
"""exists() returns True after writing a file."""
backend = LocalStorageBackend(base_dir=str(tmp_path))
backend.write("test.pdf", b"data")
assert backend.exists("test.pdf") is True
def test_exists_false_for_missing_file(self, tmp_path):
"""exists() returns False for non-existent file."""
backend = LocalStorageBackend(base_dir=str(tmp_path))
assert backend.exists("missing.pdf") is False
def test_exists_false_for_empty_file(self, tmp_path):
"""exists() returns False for zero-length file."""
backend = LocalStorageBackend(base_dir=str(tmp_path))
backend.write("empty.pdf", b"")
assert backend.exists("empty.pdf") is False
def test_path_for_returns_full_path(self, tmp_path):
"""path_for() returns the full filesystem path."""
backend = LocalStorageBackend(base_dir=str(tmp_path))
path = backend.path_for("test.pdf")
assert path == str(tmp_path / "test.pdf")
# ---------- get_storage_backend() factory ----------
class TestGetStorageBackend:
"""Tests for the storage backend factory function."""
@patch("SPARC.storage.config")
def test_returns_local_backend_by_default(self, mock_config):
"""Default config returns LocalStorageBackend."""
mock_config.storage_backend = "local"
backend = get_storage_backend()
assert isinstance(backend, LocalStorageBackend)
@patch("SPARC.storage.config")
def test_returns_s3_backend_when_configured(self, mock_config):
"""Setting storage_backend=s3 returns S3StorageBackend."""
mock_config.storage_backend = "s3"
mock_config.s3_bucket = "test-bucket"
mock_config.s3_endpoint_url = "http://minio:9000"
mock_config.s3_access_key = "key"
mock_config.s3_secret_key = "secret"
with patch.dict("sys.modules", {"boto3": MagicMock()}):
backend = get_storage_backend()
assert isinstance(backend, S3StorageBackend)
@patch("SPARC.storage.config")
def test_case_insensitive_backend_selection(self, mock_config):
"""Backend selection is case-insensitive."""
mock_config.storage_backend = "LOCAL"
backend = get_storage_backend()
assert isinstance(backend, LocalStorageBackend)
+387
View File
@@ -0,0 +1,387 @@
"""Tests for tracked company admin endpoints and scheduler integration.
Covers issue #1656:
- GET /admin/tracked (list tracked companies)
- POST /admin/tracked (add a tracked company)
- DELETE /admin/tracked/{company_name} (remove a tracked company)
- GET /admin/alerts (list alerts)
- scheduler.run_scheduled_analysis() integration
All tests mock the database layer and use JWT auth fixtures.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock the database client used by admin and auth endpoints."""
db = MagicMock()
# Default admin user for auth
db.get_user_by_id.return_value = {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _admin_header():
"""Create an Authorization header with a valid admin access token."""
token = create_access_token(1, "admin@test.com", "admin")
return {"Authorization": f"Bearer {token}"}
def _user_header():
"""Create an Authorization header with a regular user access token."""
token = create_access_token(2, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
# ---------- GET /admin/tracked ----------
class TestListTrackedCompanies:
"""GET /admin/tracked"""
def test_list_tracked_returns_companies(self, client, mock_db):
"""Admin can list tracked companies."""
mock_db.list_tracked_companies.return_value = [
{"company_name": "NVIDIA", "last_patent_count": 120, "last_analyzed": "2025-06-15"},
{"company_name": "AMD", "last_patent_count": 80, "last_analyzed": "2025-06-14"},
]
response = client.get("/admin/tracked", headers=_admin_header())
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["company_name"] == "NVIDIA"
def test_list_tracked_empty(self, client, mock_db):
"""Returns empty list when no companies are tracked."""
mock_db.list_tracked_companies.return_value = []
response = client.get("/admin/tracked", headers=_admin_header())
assert response.status_code == 200
assert response.json() == []
def test_list_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot access tracked companies list."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.get("/admin/tracked", headers=_user_header())
assert response.status_code == 403
def test_list_tracked_unauthenticated(self, client):
"""Unauthenticated request returns 401."""
response = client.get("/admin/tracked")
assert response.status_code == 401
# ---------- POST /admin/tracked ----------
class TestAddTrackedCompany:
"""POST /admin/tracked"""
def test_add_tracked_company_success(self, client, mock_db):
"""Admin can add a company to tracking."""
mock_db.add_tracked_company.return_value = {
"company_name": "Intel",
"last_patent_count": 0,
"last_analyzed": None,
}
response = client.post(
"/admin/tracked",
json={"company_name": "Intel"},
headers=_admin_header(),
)
assert response.status_code == 200
data = response.json()
assert data["company_name"] == "Intel"
mock_db.add_tracked_company.assert_called_once_with("Intel")
def test_add_duplicate_returns_409(self, client, mock_db):
"""Adding an already-tracked company returns 409."""
mock_db.add_tracked_company.return_value = None
response = client.post(
"/admin/tracked",
json={"company_name": "NVIDIA"},
headers=_admin_header(),
)
assert response.status_code == 409
assert "already tracked" in response.json()["detail"].lower()
def test_add_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot add tracked companies."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.post(
"/admin/tracked",
json={"company_name": "Intel"},
headers=_user_header(),
)
assert response.status_code == 403
def test_add_tracked_empty_name_rejected(self, client):
"""Empty company name is rejected by validation."""
response = client.post(
"/admin/tracked",
json={"company_name": ""},
headers=_admin_header(),
)
assert response.status_code == 422 # Pydantic validation error
# ---------- DELETE /admin/tracked/{company_name} ----------
class TestRemoveTrackedCompany:
"""DELETE /admin/tracked/{company_name}"""
def test_remove_tracked_company_success(self, client, mock_db):
"""Admin can remove a tracked company."""
mock_db.remove_tracked_company.return_value = True
response = client.delete(
"/admin/tracked/NVIDIA",
headers=_admin_header(),
)
assert response.status_code == 200
assert "Stopped tracking" in response.json()["message"]
mock_db.remove_tracked_company.assert_called_once_with("NVIDIA")
def test_remove_nonexistent_returns_404(self, client, mock_db):
"""Removing a non-tracked company returns 404."""
mock_db.remove_tracked_company.return_value = False
response = client.delete(
"/admin/tracked/UnknownCorp",
headers=_admin_header(),
)
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_remove_tracked_requires_admin(self, client, mock_db):
"""Regular user cannot remove tracked companies."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.delete(
"/admin/tracked/NVIDIA",
headers=_user_header(),
)
assert response.status_code == 403
# ---------- GET /admin/alerts ----------
class TestListAlerts:
"""GET /admin/alerts"""
def test_list_alerts_returns_data(self, client, mock_db):
"""Admin can list alerts."""
mock_db.list_alerts.return_value = [
{
"id": 1,
"company_name": "NVIDIA",
"alert_type": "patent_count_change",
"message": "Patent count increased by 25%",
"created_at": "2025-06-15T10:00:00Z",
},
]
response = client.get("/admin/alerts", headers=_admin_header())
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["alert_type"] == "patent_count_change"
def test_list_alerts_with_limit(self, client, mock_db):
"""Custom limit parameter is passed to the database."""
mock_db.list_alerts.return_value = []
response = client.get("/admin/alerts?limit=10", headers=_admin_header())
assert response.status_code == 200
mock_db.list_alerts.assert_called_once_with(limit=10)
def test_list_alerts_requires_admin(self, client, mock_db):
"""Regular user cannot access alerts."""
mock_db.get_user_by_id.return_value = {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.get("/admin/alerts", headers=_user_header())
assert response.status_code == 403
# ---------- Scheduler integration ----------
class TestSchedulerIntegration:
"""Tests for scheduler.run_scheduled_analysis()."""
def test_no_tracked_companies_skips_analysis(self):
"""Scheduler does nothing when no companies are tracked."""
mock_db = MagicMock()
mock_db.list_tracked_companies.return_value = []
with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \
patch("SPARC.scheduler.CompanyAnalyzer") as mock_analyzer_cls:
from SPARC.scheduler import run_scheduled_analysis
run_scheduled_analysis()
mock_analyzer_cls.assert_not_called()
def test_scheduler_analyzes_each_tracked_company(self):
"""Scheduler runs analysis for every tracked company."""
mock_db = MagicMock()
mock_db.list_tracked_companies.return_value = [
{"company_name": "NVIDIA", "last_patent_count": 100},
{"company_name": "AMD", "last_patent_count": 50},
]
mock_result_nvidia = MagicMock(success=True, patent_count=110)
mock_result_amd = MagicMock(success=True, patent_count=55)
mock_analyzer = MagicMock()
mock_analyzer._analyze_company_safe.side_effect = [mock_result_nvidia, mock_result_amd]
with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \
patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer):
from SPARC.scheduler import run_scheduled_analysis
run_scheduled_analysis()
assert mock_analyzer._analyze_company_safe.call_count == 2
mock_db.update_tracked_company.assert_any_call("NVIDIA", 110)
mock_db.update_tracked_company.assert_any_call("AMD", 55)
def test_scheduler_triggers_alert_on_significant_change(self):
"""Scheduler stores an alert when patent count changes significantly."""
mock_db = MagicMock()
mock_db.list_tracked_companies.return_value = [
{"company_name": "Tesla", "last_patent_count": 100},
]
mock_result = MagicMock(success=True, patent_count=130) # 30% increase
mock_analyzer = MagicMock()
mock_analyzer._analyze_company_safe.return_value = mock_result
with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \
patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer):
from SPARC.scheduler import run_scheduled_analysis
run_scheduled_analysis()
mock_db.store_alert.assert_called_once()
alert_kwargs = mock_db.store_alert.call_args
assert alert_kwargs[1]["company_name"] == "Tesla"
assert alert_kwargs[1]["alert_type"] == "patent_count_change"
assert alert_kwargs[1]["old_value"] == 100
assert alert_kwargs[1]["new_value"] == 130
def test_scheduler_no_alert_for_small_change(self):
"""Scheduler does not alert when change is below threshold."""
mock_db = MagicMock()
mock_db.list_tracked_companies.return_value = [
{"company_name": "Intel", "last_patent_count": 100},
]
mock_result = MagicMock(success=True, patent_count=105) # 5% increase
mock_analyzer = MagicMock()
mock_analyzer._analyze_company_safe.return_value = mock_result
with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \
patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer):
from SPARC.scheduler import run_scheduled_analysis
run_scheduled_analysis()
mock_db.store_alert.assert_not_called()
def test_scheduler_handles_analysis_failure(self):
"""Scheduler continues when one company fails analysis."""
mock_db = MagicMock()
mock_db.list_tracked_companies.return_value = [
{"company_name": "FailCo", "last_patent_count": 50},
{"company_name": "SuccessCo", "last_patent_count": 30},
]
mock_fail_result = MagicMock(success=False, error="API timeout")
mock_ok_result = MagicMock(success=True, patent_count=35)
mock_analyzer = MagicMock()
mock_analyzer._analyze_company_safe.side_effect = [mock_fail_result, mock_ok_result]
with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \
patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer):
from SPARC.scheduler import run_scheduled_analysis
run_scheduled_analysis()
# FailCo should not get updated, SuccessCo should
mock_db.update_tracked_company.assert_called_once_with("SuccessCo", 35)
def test_scheduler_handles_exception_in_analysis(self):
"""Scheduler continues even when analysis raises an exception."""
mock_db = MagicMock()
mock_db.list_tracked_companies.return_value = [
{"company_name": "CrashCo", "last_patent_count": 10},
{"company_name": "OKCo", "last_patent_count": 20},
]
mock_ok_result = MagicMock(success=True, patent_count=22)
mock_analyzer = MagicMock()
mock_analyzer._analyze_company_safe.side_effect = [
RuntimeError("unexpected error"),
mock_ok_result,
]
with patch("SPARC.scheduler.get_db_client", return_value=mock_db), \
patch("SPARC.scheduler.CompanyAnalyzer", return_value=mock_analyzer):
from SPARC.scheduler import run_scheduled_analysis
run_scheduled_analysis()
# OKCo should still be processed
mock_db.update_tracked_company.assert_called_once_with("OKCo", 22)
+280
View File
@@ -0,0 +1,280 @@
"""Tests for webhook notification system: retry logic and Slack/Discord payload format.
Covers issue #1657:
- Retry logic with exponential backoff in _send_with_retry
- Slack/Discord payload formatting in _build_payload
- Generic HTTP POST payload formatting
- notify() dispatching to multiple URLs
- notify_job_completed() and notify_alert() convenience helpers
"""
from datetime import datetime
from unittest.mock import MagicMock, patch, call
import pytest
import requests
from SPARC.webhooks import (
MAX_RETRIES,
_build_payload,
_is_slack_url,
_send_with_retry,
notify,
notify_alert,
notify_job_completed,
)
class TestIsSlackUrl:
"""Tests for Slack/Discord URL detection."""
def test_slack_webhook_url(self):
assert _is_slack_url("https://hooks.slack.com/services/T00/B00/xxx") is True
def test_discord_webhook_url(self):
assert _is_slack_url("https://discord.com/api/webhooks/123/abc") is True
def test_generic_url(self):
assert _is_slack_url("https://example.com/webhook") is False
def test_empty_url(self):
assert _is_slack_url("") is False
class TestBuildPayload:
"""Tests for payload construction."""
def test_generic_payload_structure(self):
"""Generic payload includes event type, timestamp, and data."""
payload = _build_payload("job_completed", {"job_id": "abc123"})
assert payload["event"] == "job_completed"
assert payload["job_id"] == "abc123"
assert "timestamp" in payload
# Timestamp should be ISO format ending with Z
assert payload["timestamp"].endswith("Z")
def test_slack_payload_wraps_in_text(self):
"""Slack payload wraps content in a 'text' field."""
payload = _build_payload("patent_alert", {"company_name": "NVIDIA"}, slack=True)
assert "text" in payload
assert "patent_alert" in payload["text"]
assert "NVIDIA" in payload["text"]
# Slack payload should NOT have the event/timestamp at top level
assert "event" not in payload
assert "timestamp" not in payload
def test_generic_payload_does_not_have_text_field(self):
"""Non-Slack payload does not wrap in text."""
payload = _build_payload("job_completed", {"status": "done"})
assert "text" not in payload
assert payload["status"] == "done"
def test_slack_payload_contains_bold_header(self):
"""Slack payload starts with bold event header using Slack markdown."""
payload = _build_payload("job_completed", {"count": 5}, slack=True)
assert payload["text"].startswith("*[SPARC] job_completed*")
def test_payload_merges_all_data_keys(self):
"""All data keys are included in the generic payload."""
data = {"key1": "val1", "key2": 42, "key3": True}
payload = _build_payload("test_event", data)
assert payload["key1"] == "val1"
assert payload["key2"] == 42
assert payload["key3"] is True
class TestSendWithRetry:
"""Tests for retry logic in _send_with_retry."""
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_success_on_first_attempt(self, mock_post, mock_sleep):
"""Successful delivery on first attempt, no retries."""
mock_post.return_value = MagicMock(status_code=200)
result = _send_with_retry("https://example.com/hook", {"event": "test"})
assert result is True
mock_post.assert_called_once()
mock_sleep.assert_not_called()
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_success_on_second_attempt(self, mock_post, mock_sleep):
"""Fails first, succeeds on retry."""
mock_post.side_effect = [
MagicMock(status_code=500),
MagicMock(status_code=200),
]
result = _send_with_retry("https://example.com/hook", {"event": "test"})
assert result is True
assert mock_post.call_count == 2
mock_sleep.assert_called_once()
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_all_retries_exhausted(self, mock_post, mock_sleep):
"""Returns False after all retries fail."""
mock_post.return_value = MagicMock(status_code=500)
result = _send_with_retry("https://example.com/hook", {"event": "test"})
assert result is False
assert mock_post.call_count == MAX_RETRIES
assert mock_sleep.call_count == MAX_RETRIES - 1
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_exponential_backoff_timing(self, mock_post, mock_sleep):
"""Backoff wait times follow exponential pattern (2^attempt)."""
mock_post.return_value = MagicMock(status_code=500)
_send_with_retry("https://example.com/hook", {"event": "test"})
# With BACKOFF_BASE=2: attempt 1 -> sleep(2), attempt 2 -> sleep(4)
expected_waits = [call(2 ** i) for i in range(1, MAX_RETRIES)]
assert mock_sleep.call_args_list == expected_waits
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_network_error_triggers_retry(self, mock_post, mock_sleep):
"""Network exceptions trigger retry, not immediate failure."""
mock_post.side_effect = [
requests.ConnectionError("Connection refused"),
MagicMock(status_code=200),
]
result = _send_with_retry("https://example.com/hook", {"event": "test"})
assert result is True
assert mock_post.call_count == 2
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_timeout_error_triggers_retry(self, mock_post, mock_sleep):
"""Timeout exceptions trigger retry."""
mock_post.side_effect = [
requests.Timeout("Request timed out"),
MagicMock(status_code=200),
]
result = _send_with_retry("https://example.com/hook", {"event": "test"})
assert result is True
assert mock_post.call_count == 2
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_2xx_status_codes_accepted(self, mock_post, mock_sleep):
"""Any 2xx status code is treated as success."""
mock_post.return_value = MagicMock(status_code=204)
result = _send_with_retry("https://example.com/hook", {"event": "test"})
assert result is True
mock_post.assert_called_once()
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_posts_json_payload(self, mock_post, mock_sleep):
"""Payload is sent as JSON with correct timeout."""
mock_post.return_value = MagicMock(status_code=200)
payload = {"event": "test", "data": "value"}
_send_with_retry("https://example.com/hook", payload)
mock_post.assert_called_once_with(
"https://example.com/hook", json=payload, timeout=10
)
class TestNotify:
"""Tests for the notify() dispatcher."""
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook1", "https://example.com/hook2"])
def test_dispatches_to_all_urls(self, mock_send):
"""notify() sends to every configured webhook URL."""
mock_send.return_value = True
notify("job_completed", {"job_id": "test123"})
assert mock_send.call_count == 2
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", [])
def test_no_urls_configured_returns_immediately(self, mock_send):
"""No-op when no webhook URLs are configured."""
notify("job_completed", {"job_id": "test123"})
mock_send.assert_not_called()
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", [
"https://hooks.slack.com/services/T00/B00/xxx",
"https://example.com/generic",
])
def test_slack_url_gets_slack_payload(self, mock_send):
"""Slack URLs receive Slack-formatted payloads, others get generic."""
mock_send.return_value = True
notify("test_event", {"key": "val"})
# First call (Slack URL) should have "text" key
slack_payload = mock_send.call_args_list[0][0][1]
assert "text" in slack_payload
# Second call (generic URL) should have "event" key
generic_payload = mock_send.call_args_list[1][0][1]
assert "event" in generic_payload
assert generic_payload["event"] == "test_event"
class TestNotifyJobCompleted:
"""Tests for notify_job_completed() convenience function."""
@patch("SPARC.webhooks.notify")
def test_sends_correct_event_and_data(self, mock_notify):
"""Job completion sends proper event type and summary."""
notify_job_completed(
job_id="batch-001",
status="completed",
total_companies=10,
successful=8,
failed=2,
)
mock_notify.assert_called_once()
event, data = mock_notify.call_args[0]
assert event == "job_completed"
assert data["job_id"] == "batch-001"
assert data["successful"] == 8
assert data["failed"] == 2
assert "8/10" in data["summary"]
class TestNotifyAlert:
"""Tests for notify_alert() convenience function."""
@patch("SPARC.webhooks.notify")
def test_sends_correct_event_and_data(self, mock_notify):
"""Alert notification sends patent_alert event type."""
notify_alert(
company_name="NVIDIA",
alert_type="patent_count_change",
message="Patent count increased by 30%",
)
mock_notify.assert_called_once()
event, data = mock_notify.call_args[0]
assert event == "patent_alert"
assert data["company_name"] == "NVIDIA"
assert data["alert_type"] == "patent_count_change"
assert "30%" in data["message"]