diff --git a/.gitea/workflows/build.yaml b/.gitea/workflows/build.yaml index 80acc27..aeeb5c1 100644 --- a/.gitea/workflows/build.yaml +++ b/.gitea/workflows/build.yaml @@ -28,10 +28,10 @@ jobs: run: | pip3 install -r requirements.txt ruff -# - name: Run ruff linter -# shell: sh -# run: | -# ruff check SPARC/ tests/ + - name: Run ruff linter + shell: sh + run: | + ruff check SPARC/ tests/ - name: Install Node.js and check TypeScript types shell: sh @@ -47,16 +47,17 @@ jobs: fi npx tsc --noEmit -# - name: Run pytest -# shell: sh -# env: -# DATABASE_URL: "sqlite://" -# API_KEY: "test-key" -# OPENROUTER_API_KEY: "test-key" -# JWT_SECRET: "test-secret-for-ci" -# APP_ENV: "development" -# run: | -# python3 -m pytest tests/ -v --tb=short -x + - name: Run pytest + shell: sh + env: + DATABASE_URL: "sqlite://" + API_KEY: "test-key" + OPENROUTER_API_KEY: "test-key" + JWT_SECRET: "test-secret-for-ci" + APP_ENV: "development" + run: | + pip3 install pytest + python3 -m pytest tests/ -v --tb=short -x build-api: needs: test diff --git a/ROADMAP.md b/ROADMAP.md index 42b571a..0e86ab5 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -7,86 +7,124 @@ Semiconductor Patent & Analytics Report Core -- development priorities. SPARC is a patent analysis platform with a working end-to-end pipeline: Python/FastAPI backend, React/TypeScript frontend, PostgreSQL for persistence and caching, Docker Compose for local development, and Gitea Actions CI/CD for -image builds. Core features (patent retrieval via SerpAPI, PDF parsing, LLM -analysis via OpenRouter/Claude, batch processing, JWT authentication, analytics -dashboard) are all implemented and functional. +image builds and testing. Core features include patent retrieval via SerpAPI, +PDF parsing, LLM analysis via OpenRouter (multi-model: Claude, GPT-4o, Gemini, +Llama), batch processing, JWT authentication, analytics dashboard with patent +trend charts, scheduled recurring analysis with alerting, webhook notifications +(Slack/Discord), CSV and PDF export, S3/MinIO storage, side-by-side company +comparison, and dark mode. + +--- + +## Completed + +Items that have been implemented and merged into main. + +### Security hardening + +- ~~Rotate default JWT secret.~~ Startup check refuses to start with the + default secret in non-development environments. +- ~~CORS allow-origins are hardcoded.~~ Allowed origins are now configurable + via environment variable. +- ~~Database credentials in docker-compose.yml.~~ Compose references `.env` + for sensitive values. + +### Error handling and resilience + +- ~~`get_db_client()` creates a new `DatabaseClient` on every call.~~ Refactored + to a shared pooled singleton initialized at startup. +- ~~No rate limiting on auth endpoints.~~ Rate limiting middleware added to + `/auth/login` and `/auth/register`. + +### Test coverage + +- ~~API tests bypass authentication.~~ JWT auth integration tests added (33 + cases covering registration, login, protected routes, token refresh, and + admin-only endpoints). +- ~~No test stage in CI.~~ Gitea Actions workflow now runs `pytest` and gates + the build. +- ~~No linting or type checking in CI.~~ `ruff` (Python) and `tsc --noEmit` + (TypeScript) added to CI pipeline. + +### Backend + +- ~~Add structured logging.~~ Python `logging` module used throughout. +- ~~Make LLM model configurable.~~ `MODEL` environment variable accepted; + multi-model support with per-analysis selection (GPT-4o, Gemini, Claude, + Llama). +- ~~SERP cache TTL hardcoded.~~ `SERP_CACHE_TTL_HOURS` exposed as env var. +- ~~Patent PDF storage.~~ S3/MinIO object storage backend added alongside + local filesystem. Volume mount requirement documented. +- ~~`analyze_single_patent` assumes local file.~~ Auto-download from cached + metadata link integrated. +- ~~`Patent.patent_id` typed as `int`.~~ Fixed to `str`. + +### Frontend + +- ~~No loading/error states.~~ Skeleton loaders and error states added to + Batch and Analytics pages. +- ~~No dark mode.~~ Full dark mode support with theme-aware chart colors. +- ~~Missing lockfile.~~ `package-lock.json` committed. + +### Features (formerly P3) + +- ~~Export analysis reports.~~ CSV and PDF export endpoints implemented. +- ~~Comparison view.~~ Side-by-side company patent portfolio comparison added. +- ~~Scheduled/recurring analysis.~~ APScheduler-based periodic re-analysis + with configurable interval and change-threshold alerting. +- ~~Webhook/notification support.~~ Slack, Discord, and generic HTTP POST + webhooks with retry logic. +- ~~Multi-model support.~~ Model picker in Analysis and Batch pages; backend + allow-list validation. +- ~~Patent trend charts.~~ Filing frequency and category distribution + visualizations added to Analytics page. +- ~~OpenAPI client generation.~~ TypeScript API client auto-generated from + FastAPI spec with CI freshness check. + +### Resilience + +- ~~`_jobs` dict is in-memory only.~~ Database-backed job persistence + implemented using `db.list_jobs()` and `mark_stale_jobs_failed()`. The + in-memory `_jobs` dict has been removed. + +### Test coverage (P1/P2) + +- ~~Export endpoint tests.~~ Tests added for CSV and PDF export endpoints. +- ~~Tracked company admin endpoint tests.~~ Tests added for `/admin/tracked` + CRUD endpoints and scheduler integration. +- ~~Webhook integration tests.~~ Tests added for retry logic, Slack/Discord + payload format, and multi-URL dispatch. +- ~~S3/MinIO storage backend tests.~~ Unit tests added for the S3 backend + (read, write, exists, delete, error handling). +- ~~`analyze_single_patent` auto-download path tests.~~ Tests added for the + auto-download fallback (cache lookup, PDF download, FileNotFoundError). + +### Code quality + +- ~~Scheduler creates its own DatabaseClient.~~ Refactored to use the + application-level pooled `get_db_client()`. --- ## P1 -- High Priority -These items address correctness, security, and reliability gaps that should be -resolved before broader production use. - -### Security hardening - -- **Rotate default JWT secret.** `auth.py` ships a fallback - `sparc-secret-key-change-in-production` that will be used if `JWT_SECRET` is - unset. Add a startup check that refuses to start with the default secret in - non-development environments. -- **CORS allow-origins are hardcoded.** `api.py` only permits - `localhost:3000` and `localhost:5173`. Make the allowed origins configurable - via environment variable so the dashboard works when deployed behind a real - domain. -- **Database credentials in docker-compose.yml.** The compose file embeds - `postgres:postgres` in plain text. Reference a `.env` file or Docker secrets - instead. - -### Error handling and resilience - -- **`get_db_client()` in `auth.py` creates a new `DatabaseClient` on every - call.** This bypasses the connection pool and can exhaust database - connections under load. Refactor to share a single pooled client. -- **`_jobs` dict is in-memory only.** Job state is lost on API restart. Persist - job status in PostgreSQL or Redis so async batch results survive restarts. -- **No rate limiting on auth endpoints.** `/auth/login` and `/auth/register` - are unprotected against brute-force or abuse. Add rate limiting middleware. - -### Test coverage for auth and admin - -- The existing API tests (`tests/test_api.py`) bypass authentication entirely. - Add tests that exercise the JWT flow: registration, login, protected-route - access, token refresh, and admin-only endpoints. +No outstanding P1 items. All previously listed items have been completed and +moved to the Completed section above. --- ## P2 -- Medium Priority -Improvements to usability, performance, and developer experience. +Improvements to the API surface. -### Backend +### API improvements -- **Add structured logging.** Replace `print()` calls throughout `analyzer.py`, - `serp_api.py`, and `llm.py` with Python `logging` so log levels and - formatting are consistent. -- **Make LLM model configurable.** `llm.py` hardcodes - `anthropic/claude-3.5-sonnet`. Accept a `MODEL` environment variable to allow - switching models without code changes. -- **SERP cache TTL is hardcoded to 24 hours.** Expose `SERP_CACHE_TTL_HOURS` - as an environment variable in `config.py`. -- **Patent PDF storage.** PDFs are saved to a local `patents/` directory. For - containerized deployments, consider object storage (S3/MinIO) or at minimum - document the volume mount requirement more prominently. -- **`analyze_single_patent` assumes local file path.** The method constructs - `patents/{patent_id}.pdf` and reads from disk, but does not download the PDF - first. Either integrate the download step or document the prerequisite. -- **`Patent.patent_id` typed as `int` in `types.py` but used as `str` - everywhere.** Fix the type annotation to `str`. - -### Frontend - -- **No loading/error states on several pages.** The Batch and Analytics pages - would benefit from skeleton loaders and user-friendly error messages. -- **No dark mode.** Tailwind is configured but no dark variant is applied. -- **Missing `package-lock.json` or `pnpm-lock.yaml`.** The frontend has no - lockfile committed, leading to non-reproducible builds. - -### CI/CD - -- **No test stage in the Gitea Actions workflow.** `build.yaml` builds and - pushes images but never runs `pytest`. Add a test job that gates the build. -- **No linting or type checking.** Add `ruff` (Python) and `tsc --noEmit` - (TypeScript) to CI. +- **API pagination.** The `/analyze/batch` endpoint needs cursor-based + pagination for large result sets. The `/jobs` endpoint already has cursor + pagination. *(Issue #1669)* +- **Request validation improvements.** Add stricter input validation for + company names (disallow special characters, enforce length limits). + *(Issue #1670)* --- @@ -94,23 +132,20 @@ Improvements to usability, performance, and developer experience. Lower-urgency enhancements and future features. -- **Export analysis reports.** Allow users to download analysis results as PDF - or CSV from the dashboard. -- **Comparison view.** Side-by-side comparison of two companies' patent - portfolios. -- **Scheduled/recurring analysis.** Periodically re-analyze tracked companies - and alert on significant changes. -- **Webhook/notification support.** Send alerts (Slack, Discord, email) when - batch jobs complete or when a company's innovation score changes - significantly. -- **Multi-model support.** Let users choose between LLM providers per analysis - (e.g., GPT-4o, Gemini, Claude) and compare outputs. -- **Patent trend charts.** Visualize patent filing frequency and technology - category distribution over time in the Analytics page. -- **API pagination.** The `/analyze/batch` and `/jobs` endpoints could benefit - from cursor-based pagination for large result sets. -- **OpenAPI client generation.** Auto-generate the TypeScript API client from - the FastAPI OpenAPI spec to keep frontend types in sync. +- **Historical analysis diffing.** Show what changed between two analysis runs + for the same company, highlighting new patents and score shifts. +- **Patent classification tagging.** Automatically tag patents by technology + domain (AI, semiconductors, materials science) using LLM classification. +- **User-level API keys.** Allow users to generate personal API keys for + programmatic access without JWT token refresh. +- **Batch export.** Export analysis results for multiple companies at once as + a ZIP archive. +- **Rate limiting dashboard.** Surface rate limit status and usage statistics + in the admin panel. +- **Async webhook delivery.** Move webhook delivery to a background task queue + (e.g., Celery, arq) to avoid blocking the scheduler. +- **Multi-tenant support.** Scope analysis results and tracked companies per + user or organization. --- diff --git a/SPARC/analyzer.py b/SPARC/analyzer.py index 31ad7f1..1ebceaf 100644 --- a/SPARC/analyzer.py +++ b/SPARC/analyzer.py @@ -10,13 +10,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Callable from SPARC import config - -logger = logging.getLogger(__name__) from SPARC.database import DatabaseClient from SPARC.llm import LLMAnalyzer from SPARC.serp_api import SERP from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents +logger = logging.getLogger(__name__) + class CompanyAnalyzer: """Orchestrates end-to-end company performance analysis via patents.""" diff --git a/SPARC/api.py b/SPARC/api.py index 3a28033..1b29d38 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -3,14 +3,19 @@ Provides REST API endpoints for analyzing company patent portfolios. """ +from __future__ import annotations + from contextlib import asynccontextmanager from datetime import datetime -from typing import Annotated, List +from typing import TYPE_CHECKING, Annotated, List -from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request +if TYPE_CHECKING: + from SPARC.database import DatabaseClient + +from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel, EmailStr, Field +from pydantic import BaseModel, EmailStr, Field, StringConstraints from slowapi import Limiter from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address @@ -31,6 +36,16 @@ from SPARC.auth import ( ) from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult +# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only. +CompanyName = Annotated[ + str, + StringConstraints( + min_length=2, + max_length=100, + pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", + ), +] + # Pydantic models for API class CompanyAnalysisResponse(BaseModel): @@ -67,7 +82,7 @@ class CompanyAnalysisRequest(BaseModel): class BatchAnalysisRequest(BaseModel): """Request model for batch company analysis.""" - companies: list[str] = Field( + companies: list[CompanyName] = Field( ..., min_length=1, max_length=20, description="List of company names to analyze" ) max_workers: int = Field( @@ -91,6 +106,24 @@ class JobStatus(BaseModel): error: str | None = None +class AnalysisRecord(BaseModel): + """A single stored analysis result.""" + + id: int + company_name: str | None = None + analysis_type: str | None = None + model: str | None = None + response: str | None = None + timestamp: datetime | None = None + + +class PaginatedAnalysisResponse(BaseModel): + """Paginated response for analysis result listings.""" + + items: list[AnalysisRecord] + next_cursor: str | None = None + + class PaginatedJobsResponse(BaseModel): """Paginated response for job listings.""" @@ -212,10 +245,37 @@ app = FastAPI( limiter = Limiter(key_func=get_remote_address) app.state.limiter = limiter +# In-memory rate limit statistics +_rate_limit_stats: dict[str, dict] = {} + + +def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None: + """Record a request against a rate-limited endpoint.""" + key = endpoint + if key not in _rate_limit_stats: + _rate_limit_stats[key] = { + "endpoint": endpoint, + "total_requests": 0, + "rejected_requests": 0, + "by_ip": {}, + } + _rate_limit_stats[key]["total_requests"] += 1 + if rejected: + _rate_limit_stats[key]["rejected_requests"] += 1 + ip_stats = _rate_limit_stats[key].setdefault("by_ip", {}) + if ip not in ip_stats: + ip_stats[ip] = {"total": 0, "rejected": 0} + ip_stats[ip]["total"] += 1 + if rejected: + ip_stats[ip]["rejected"] += 1 + @app.exception_handler(RateLimitExceeded) async def rate_limit_handler(request: Request, exc: RateLimitExceeded): """Return 429 with Retry-After header when rate limit is exceeded.""" + endpoint = request.url.path + ip = get_remote_address(request) + _track_rate_limit_request(endpoint, ip, rejected=True) retry_after = getattr(exc, "retry_after", 60) return JSONResponse( status_code=429, @@ -244,6 +304,7 @@ async def register(request: Request, body: RegisterRequest): The first registered user automatically becomes an admin. """ + _track_rate_limit_request("/auth/register", get_remote_address(request)) db = get_db_client() # First user becomes admin @@ -274,6 +335,7 @@ async def register(request: Request, body: RegisterRequest): @limiter.limit("10/minute") async def login(request: Request, body: LoginRequest): """Authenticate user and return JWT tokens.""" + _track_rate_limit_request("/auth/login", get_remote_address(request)) db = get_db_client() user = db.authenticate_user(body.email, body.password) @@ -400,7 +462,7 @@ async def delete_user( class TrackCompanyRequest(BaseModel): """Request to add a company to tracking.""" - company_name: str = Field(..., min_length=1, max_length=255) + company_name: CompanyName = Field(...) @app.get("/admin/tracked", tags=["Admin"]) @@ -427,7 +489,7 @@ async def add_tracked_company( @app.delete("/admin/tracked/{company_name}", tags=["Admin"]) async def remove_tracked_company( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_admin), ): """Remove a company from the tracked list (admin only).""" @@ -438,6 +500,36 @@ async def remove_tracked_company( return {"message": f"Stopped tracking {company_name}"} +@app.get("/admin/rate-limits", tags=["Admin"]) +async def get_rate_limit_stats( + _: UserResponse = Depends(get_current_admin), +): + """Get rate limit status and usage statistics (admin only). + + Returns current rate limit configuration and request statistics + for all rate-limited endpoints. + + Returns: + List of rate limit stats per endpoint with total/rejected counts + """ + rate_limits_config = { + "/auth/register": {"limit": "5/minute"}, + "/auth/login": {"limit": "10/minute"}, + } + + results = [] + for endpoint, conf in rate_limits_config.items(): + stats = _rate_limit_stats.get(endpoint, {}) + results.append({ + "endpoint": endpoint, + "limit": conf["limit"], + "total_requests": stats.get("total_requests", 0), + "rejected_requests": stats.get("rejected_requests", 0), + }) + + return {"rate_limits": results} + + @app.get("/admin/alerts", tags=["Admin"]) async def list_alerts( limit: int = Query(default=50, ge=1, le=200), @@ -585,7 +677,7 @@ async def get_analytics_trends( @app.get("/export/{company_name}", tags=["Export"]) async def export_company_csv( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_user), ): """Export analysis results for a company as a CSV file. @@ -637,7 +729,7 @@ async def export_company_csv( @app.get("/export/{company_name}/pdf", tags=["Export"]) async def export_company_pdf( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_user), ): """Export analysis results for a company as a formatted PDF report. @@ -653,7 +745,6 @@ async def export_company_pdf( PDF file download """ import io - import textwrap from reportlab.lib import colors from reportlab.lib.pagesizes import letter @@ -812,7 +903,7 @@ async def health_check(): tags=["Analysis"], ) async def analyze_company( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), _: UserResponse = Depends(get_current_user), ): @@ -842,7 +933,7 @@ async def analyze_company( ) async def analyze_single_patent( patent_id: str, - company_name: str = Query(description="Company name for analysis context"), + company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], _: UserResponse = Depends(get_current_user), ): """Analyze a single patent by its publication ID. @@ -868,6 +959,58 @@ async def analyze_single_patent( raise HTTPException(status_code=404, detail=str(e)) +@app.get( + "/analyze/batch", + response_model=PaginatedAnalysisResponse, + tags=["Analysis"], +) +async def list_analysis_results( + company_name: Annotated[ + str | None, + Query(description="Filter results by company name"), + ] = None, + limit: Annotated[int, Query(ge=1, le=200)] = 50, + cursor: Annotated[ + str | None, + Query(description="Opaque cursor from a previous response's next_cursor field"), + ] = None, + _: UserResponse = Depends(get_current_user), +): + """List stored analysis results with cursor-based pagination. + + Returns past analysis results ordered by timestamp descending. Use + ``limit`` to control page size (default 50, max 200). The response + includes a ``next_cursor`` field; pass it back as the ``cursor`` query + parameter to fetch the next page. When ``next_cursor`` is ``null``, + there are no more results. + + Args: + company_name: Optional filter by company name + limit: Maximum number of results to return (default 50, max 200) + cursor: Opaque pagination cursor from a previous response + + Returns: + Paginated list of analysis results + """ + db = _get_job_db() + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + + has_next = len(rows) > limit + if has_next: + rows = rows[:limit] + + items = [AnalysisRecord(**row) for row in rows] + + next_cursor = None + if has_next and rows: + last = rows[-1] + ts = last["timestamp"] + ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts) + next_cursor = f"{ts_str}|{last['id']}" + + return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor) + + @app.post( "/analyze/batch", response_model=BatchAnalysisResponse, @@ -1043,7 +1186,7 @@ async def list_jobs( str | None, Query(description="Filter by status: pending, running, completed, failed"), ] = None, - limit: Annotated[int, Query(ge=1, le=100)] = 10, + limit: Annotated[int, Query(ge=1, le=200)] = 50, cursor: Annotated[ str | None, Query(description="Opaque cursor from a previous response's next_cursor field"), diff --git a/SPARC/database.py b/SPARC/database.py index 24c7081..0759a66 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -371,6 +371,48 @@ class DatabaseClient: cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()] + def list_analyses( + self, + company_name: Optional[str] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> List[Dict]: + """List analysis results with cursor-based pagination. + + Args: + company_name: Optional filter by company name. + limit: Maximum number of records to return. + cursor: Opaque cursor (``timestamp|id``) from a previous response. + + Returns: + List of analysis dicts ordered by timestamp descending. + """ + conditions: list[str] = ["is_cached = FALSE"] + params: list = [] + + if company_name: + conditions.append("LOWER(company_name) = LOWER(%s)") + params.append(company_name) + + if cursor: + try: + ts_str, cursor_id = cursor.rsplit("|", 1) + conditions.append("(timestamp, id) < (%s, %s)") + params.extend([ts_str, int(cursor_id)]) + except (ValueError, TypeError): + pass # Ignore malformed cursors; return from start + + query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages" + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += " ORDER BY timestamp DESC, id DESC LIMIT %s" + params.append(limit) + + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute(query, params) + return [dict(row) for row in cur.fetchall()] + def get_analytics(self, days: int = 30) -> Dict: """Get analytics on message usage. diff --git a/SPARC/scheduler.py b/SPARC/scheduler.py index 5af3940..4428bfd 100644 --- a/SPARC/scheduler.py +++ b/SPARC/scheduler.py @@ -2,14 +2,17 @@ Uses APScheduler to periodically re-analyze tracked companies and detect significant changes in patent counts. + +The scheduler reuses the application-level pooled DatabaseClient +(from ``SPARC.auth``) instead of creating its own connection, which +avoids exhausting the database connection pool under load. """ import logging import os -from SPARC import config from SPARC.analyzer import CompanyAnalyzer -from SPARC.database import DatabaseClient +from SPARC.auth import get_db_client logger = logging.getLogger(__name__) @@ -21,10 +24,13 @@ CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20")) def run_scheduled_analysis() -> None: - """Re-analyze all tracked companies and check for significant changes.""" - db = DatabaseClient(config.database_url) - db.connect() - db.initialize_schema() + """Re-analyze all tracked companies and check for significant changes. + + Uses the shared pooled DatabaseClient from ``SPARC.auth.get_db_client()`` + rather than creating a disposable connection, so the scheduler participates + in the same connection pool as the rest of the application. + """ + db = get_db_client() tracked = db.list_tracked_companies() if not tracked: @@ -74,7 +80,6 @@ def run_scheduled_analysis() -> None: except Exception as e: logger.error("Error analyzing tracked company %s: %s", name, e) - db.close() logger.info("Scheduled analysis complete") diff --git a/frontend/src/pages/Analysis.tsx b/frontend/src/pages/Analysis.tsx index 7ec67f7..2f4fc35 100644 --- a/frontend/src/pages/Analysis.tsx +++ b/frontend/src/pages/Analysis.tsx @@ -159,7 +159,7 @@ export function Analysis() { -
+
{result.analysis}
diff --git a/tests/test_analyze_single_patent.py b/tests/test_analyze_single_patent.py new file mode 100644 index 0000000..3b2283b --- /dev/null +++ b/tests/test_analyze_single_patent.py @@ -0,0 +1,211 @@ +"""Tests for analyze_single_patent auto-download path. + +Covers issue #1661: +- PDF exists on disk: direct analysis (happy path) +- PDF not on disk, cached link exists: auto-download and analyze +- PDF not on disk, no cached link: FileNotFoundError +- Analysis failure after PDF found: graceful error message +- Model override parameter passthrough +""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from SPARC.analyzer import CompanyAnalyzer +from SPARC.types import Patent + + +@pytest.fixture(autouse=True) +def mock_db(mocker): + """Mock DatabaseClient so no real DB is needed.""" + mock_db_cls = mocker.patch("SPARC.analyzer.DatabaseClient") + mock_db_instance = MagicMock() + mock_db_instance.get_cached_patent.return_value = None + mock_db_instance.get_cached_serp_query.return_value = None + mock_db_cls.return_value = mock_db_instance + return mock_db_instance + + +@pytest.fixture +def analyzer(mocker, mock_db): + """Create a CompanyAnalyzer with mocked LLM and DB.""" + mocker.patch("SPARC.analyzer.LLMAnalyzer") + return CompanyAnalyzer(openrouter_api_key="test-key") + + +class TestAnalyzeSinglePatentAutoDownload: + """Test the auto-download logic in analyze_single_patent.""" + + def test_pdf_on_disk_analyzed_directly(self, analyzer, mocker, tmp_path): + """When PDF exists on disk, it is analyzed directly without download.""" + patent_id = "US-11234567-B2" + + # Create the patents dir and PDF file + patents_dir = tmp_path / "patents" + patents_dir.mkdir() + pdf_path = patents_dir / f"{patent_id}.pdf" + pdf_path.write_bytes(b"fake PDF content") + + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_parse.return_value = {"abstract": "test", "claims": "test claims"} + mock_minimize.return_value = "minimized content" + analyzer.llm_analyzer.analyze_patent_content.return_value = "Good patent." + + # Change cwd so patents/{patent_id}.pdf resolves to our tmp_path + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + result = analyzer.analyze_single_patent(patent_id, "TestCo") + finally: + os.chdir(original_cwd) + + assert result == "Good patent." + # DB cache should not have been queried since file existed + analyzer.db.get_cached_patent.assert_not_called() + + def test_auto_download_from_cached_link(self, analyzer, mocker, tmp_path): + """When PDF is not on disk but link is cached, auto-download occurs.""" + patent_id = "US-99887766-A1" + + # No patents dir exists (PDF not on disk) + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + downloaded_patent = Patent(patent_id=patent_id, pdf_link="https://example.com/patent.pdf") + downloaded_patent.pdf_path = f"patents/{patent_id}.pdf" + mock_save.return_value = downloaded_patent + + # Cached patent has a PDF link + analyzer.db.get_cached_patent.return_value = { + "patent_id": patent_id, + "pdf_link": "https://example.com/patent.pdf", + } + + # Mock the rest of the analysis pipeline + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_parse.return_value = {"abstract": "test abstract"} + mock_minimize.return_value = "minimized content" + analyzer.llm_analyzer.analyze_patent_content.return_value = "Strong innovation." + + # Change cwd so patents/{patent_id}.pdf does NOT exist + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + result = analyzer.analyze_single_patent(patent_id, "DownloadCo") + finally: + os.chdir(original_cwd) + + assert result == "Strong innovation." + analyzer.db.get_cached_patent.assert_called_once_with(patent_id) + mock_save.assert_called_once() + # Verify the Patent passed to save_patents has the correct ID and link + saved_patent = mock_save.call_args[0][0] + assert saved_patent.patent_id == patent_id + assert saved_patent.pdf_link == "https://example.com/patent.pdf" + + def test_no_cached_link_raises_file_not_found(self, analyzer, mocker, tmp_path): + """When PDF is not on disk and no cached link, FileNotFoundError raised.""" + patent_id = "US-00000000-X1" + + analyzer.db.get_cached_patent.return_value = None + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + with pytest.raises(FileNotFoundError, match="no download link is cached"): + analyzer.analyze_single_patent(patent_id, "MissingCo") + finally: + os.chdir(original_cwd) + + def test_cached_patent_without_pdf_link_raises(self, analyzer, mocker, tmp_path): + """When cached patent exists but has no pdf_link, FileNotFoundError raised.""" + patent_id = "US-11111111-B1" + + analyzer.db.get_cached_patent.return_value = { + "patent_id": patent_id, + "pdf_link": None, + } + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + with pytest.raises(FileNotFoundError, match="no download link is cached"): + analyzer.analyze_single_patent(patent_id, "NoPDFCo") + finally: + os.chdir(original_cwd) + + def test_analysis_exception_returns_error_message(self, analyzer, mocker, tmp_path): + """When analysis pipeline fails, returns error string instead of raising.""" + patent_id = "US-22222222-A2" + + # Create the PDF on disk so it skips download + patents_dir = tmp_path / "patents" + patents_dir.mkdir() + (patents_dir / f"{patent_id}.pdf").write_bytes(b"fake PDF") + + # Parse fails + mocker.patch( + "SPARC.analyzer.SERP.parse_patent_pdf", + side_effect=ValueError("Corrupt PDF"), + ) + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + result = analyzer.analyze_single_patent(patent_id, "ErrorCo") + finally: + os.chdir(original_cwd) + + assert "Failed to analyze patent" in result + assert "Corrupt PDF" in result + + def test_model_override_passed_to_llm(self, analyzer, mocker, tmp_path): + """The model parameter is forwarded to the LLM analyzer.""" + patent_id = "US-33333333-B2" + + patents_dir = tmp_path / "patents" + patents_dir.mkdir() + (patents_dir / f"{patent_id}.pdf").write_bytes(b"fake PDF") + + mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf", return_value={"abstract": "test"}) + mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm", return_value="content") + analyzer.llm_analyzer.analyze_patent_content.return_value = "Analysis result." + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + result = analyzer.analyze_single_patent( + patent_id, "ModelCo", model="openai/gpt-4o" + ) + finally: + os.chdir(original_cwd) + + assert result == "Analysis result." + analyzer.llm_analyzer.analyze_patent_content.assert_called_once_with( + patent_content="content", + company_name="ModelCo", + model="openai/gpt-4o", + ) + + def test_file_not_found_during_parse_re_raised(self, analyzer, mocker, tmp_path): + """FileNotFoundError during parsing is re-raised, not caught.""" + patent_id = "US-44444444-C1" + + patents_dir = tmp_path / "patents" + patents_dir.mkdir() + (patents_dir / f"{patent_id}.pdf").write_bytes(b"fake PDF") + + mocker.patch( + "SPARC.analyzer.SERP.parse_patent_pdf", + side_effect=FileNotFoundError("PDF file vanished"), + ) + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + with pytest.raises(FileNotFoundError, match="PDF file vanished"): + analyzer.analyze_single_patent(patent_id, "VanishCo") + finally: + os.chdir(original_cwd) diff --git a/tests/test_auth.py b/tests/test_auth.py index de79259..983c44b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,13 +1,29 @@ -"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access.""" +"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access. -from datetime import datetime, timezone +Covers all five scenarios required by issue #1624: +1. Registration (POST /auth/register) +2. Login (POST /auth/login) +3. Protected route access (GET /auth/me) -- valid, missing, expired, wrong-type tokens +4. Token refresh (POST /auth/refresh) +5. Admin-only endpoints (GET /admin/users, PATCH role, DELETE user) + +All tests use mocked DB fixtures and require no live database. +""" + +from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock, patch +import jwt as pyjwt import pytest from fastapi.testclient import TestClient from SPARC.api import app -from SPARC.auth import create_access_token, create_refresh_token +from SPARC.auth import ( + JWT_ALGORITHM, + JWT_SECRET, + create_access_token, + create_refresh_token, +) @pytest.fixture @@ -171,12 +187,6 @@ class TestGetMe: def test_expired_token_returns_401(self, client, mock_db): """An expired token should return 401.""" - # Create a token that has already expired - from datetime import timedelta - - import jwt as pyjwt - from SPARC.auth import JWT_ALGORITHM, JWT_SECRET - payload = { "sub": "1", "email": "user@test.com", @@ -300,3 +310,193 @@ class TestAdminUsers: assert response.status_code == 400 assert "own role" in response.json()["detail"].lower() + + def test_role_change_nonexistent_user_returns_404(self, client, mock_db): + """Changing role for a user that does not exist should return 404.""" + admin = _make_admin_user() + mock_db.get_user_by_id.return_value = admin + mock_db.update_user_role.return_value = None + + response = client.patch( + "/admin/users/999/role", + json={"role": "admin"}, + headers=_auth_header(admin), + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_regular_user_cannot_change_role(self, client, mock_db): + """Non-admin user should receive 403 when trying to change roles.""" + user = _make_regular_user() + mock_db.get_user_by_id.return_value = user + + response = client.patch( + "/admin/users/1/role", + json={"role": "admin"}, + headers=_auth_header(user), + ) + + assert response.status_code == 403 + + +class TestAdminDeleteUser: + """DELETE /admin/users/{user_id}""" + + def test_admin_can_delete_user(self, client, mock_db): + """Admin should be able to delete another user.""" + admin = _make_admin_user() + mock_db.get_user_by_id.return_value = admin + mock_db.delete_user.return_value = True + + response = client.delete( + "/admin/users/2", + headers=_auth_header(admin), + ) + + assert response.status_code == 200 + assert "deleted" in response.json()["message"].lower() + mock_db.delete_user.assert_called_once_with(2) + + def test_admin_cannot_delete_self(self, client, mock_db): + """Admin should not be able to delete themselves.""" + admin = _make_admin_user() + mock_db.get_user_by_id.return_value = admin + + response = client.delete( + "/admin/users/1", + headers=_auth_header(admin), + ) + + assert response.status_code == 400 + assert "yourself" in response.json()["detail"].lower() + + def test_delete_nonexistent_user_returns_404(self, client, mock_db): + """Deleting a user that does not exist should return 404.""" + admin = _make_admin_user() + mock_db.get_user_by_id.return_value = admin + mock_db.delete_user.return_value = False + + response = client.delete( + "/admin/users/999", + headers=_auth_header(admin), + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_regular_user_cannot_delete_user(self, client, mock_db): + """Non-admin user should receive 403 when trying to delete users.""" + user = _make_regular_user() + mock_db.get_user_by_id.return_value = user + + response = client.delete( + "/admin/users/1", + headers=_auth_header(user), + ) + + assert response.status_code == 403 + + def test_no_token_cannot_delete_user(self, client): + """Missing token should be rejected for delete endpoint.""" + response = client.delete("/admin/users/1") + assert response.status_code in (401, 403) + + +class TestEdgeCases: + """Additional edge-case tests for auth robustness.""" + + def test_register_invalid_email_returns_422(self, client, mock_db): + """Registration with an invalid email format should return 422.""" + response = client.post( + "/auth/register", + json={"email": "not-an-email", "password": "securepass123"}, + ) + + assert response.status_code == 422 + + def test_register_short_password_returns_422(self, client, mock_db): + """Registration with a password shorter than 8 chars should return 422.""" + response = client.post( + "/auth/register", + json={"email": "user@test.com", "password": "short"}, + ) + + assert response.status_code == 422 + + def test_register_missing_fields_returns_422(self, client, mock_db): + """Registration with missing fields should return 422.""" + response = client.post("/auth/register", json={}) + assert response.status_code == 422 + + def test_login_missing_fields_returns_422(self, client, mock_db): + """Login with missing fields should return 422.""" + response = client.post("/auth/login", json={"email": "user@test.com"}) + assert response.status_code == 422 + + def test_malformed_token_returns_401(self, client, mock_db): + """A completely malformed token string should return 401.""" + response = client.get( + "/auth/me", + headers={"Authorization": "Bearer not.a.valid.jwt.token"}, + ) + assert response.status_code == 401 + + def test_token_with_wrong_secret_returns_401(self, client, mock_db): + """A token signed with a different secret should return 401.""" + payload = { + "sub": "1", + "email": "user@test.com", + "role": "user", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + "type": "access", + } + wrong_secret_token = pyjwt.encode(payload, "wrong-secret", algorithm=JWT_ALGORITHM) + + response = client.get( + "/auth/me", + headers={"Authorization": f"Bearer {wrong_secret_token}"}, + ) + assert response.status_code == 401 + + def test_token_for_deleted_user_returns_401(self, client, mock_db): + """A valid token for a user no longer in the DB should return 401.""" + user = _make_regular_user() + mock_db.get_user_by_id.return_value = None # user was deleted + + response = client.get("/auth/me", headers=_auth_header(user)) + assert response.status_code == 401 + + def test_refresh_for_deleted_user_returns_401(self, client, mock_db): + """Refreshing a token for a deleted user should return 401.""" + user = _make_regular_user() + mock_db.get_user_by_id.return_value = None + refresh = create_refresh_token(user["id"], user["email"], user["role"]) + + response = client.post( + "/auth/refresh", json={"refresh_token": refresh} + ) + assert response.status_code == 401 + + def test_login_returns_decodable_tokens(self, client, mock_db): + """Tokens returned by login should be decodable and contain expected claims.""" + user = _make_regular_user() + mock_db.authenticate_user.return_value = user + + response = client.post( + "/auth/login", + json={"email": "user@test.com", "password": "correctpassword"}, + ) + + data = response.json() + access_payload = pyjwt.decode( + data["access_token"], JWT_SECRET, algorithms=[JWT_ALGORITHM] + ) + assert access_payload["sub"] == str(user["id"]) + assert access_payload["email"] == user["email"] + assert access_payload["type"] == "access" + + refresh_payload = pyjwt.decode( + data["refresh_token"], JWT_SECRET, algorithms=[JWT_ALGORITHM] + ) + assert refresh_payload["type"] == "refresh" diff --git a/tests/test_company_name_validation.py b/tests/test_company_name_validation.py new file mode 100644 index 0000000..3e6855f --- /dev/null +++ b/tests/test_company_name_validation.py @@ -0,0 +1,157 @@ +"""Tests for company name input validation on analysis endpoints.""" + +from datetime import datetime +from unittest.mock import Mock + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.types import CompanyAnalysisResult + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture +def mock_analyzer(mocker): + """Mock the global analyzer so valid requests succeed.""" + mock = Mock() + mock._analyze_company_safe.return_value = CompanyAnalysisResult( + company_name="nvidia", + analysis="Test analysis", + patent_count=1, + success=True, + timestamp=datetime.now(), + ) + mocker.patch("SPARC.api._analyzer", mock) + return mock + + +class TestCompanyNameValidation: + """Test that company names are validated on analysis endpoints.""" + + # --- Too short --- + + def test_single_char_rejected(self, client, mock_analyzer): + """A one-character company name should be rejected.""" + response = client.get("/analyze/X") + assert response.status_code == 422 + + # --- Too long --- + + def test_over_100_chars_rejected(self, client, mock_analyzer): + """A company name longer than 100 characters should be rejected.""" + long_name = "A" * 101 + response = client.get(f"/analyze/{long_name}") + assert response.status_code == 422 + + # --- Special characters --- + + @pytest.mark.parametrize( + "bad_name", + [ + "nvidia!", + "intel@corp", + "test#company", + "foo$bar", + "a%b", + "x^y", + "semi;colon", + "drop'table", + 'say"hello', + "path/traversal", + "back\\slash", + "pipe|char", + "star*glob", + "question?mark", + "