ci: enable ruff linting and pytest in CI pipeline #1568

Merged
AI-Manager merged 14 commits from feature/1559-1560-enable-ci-linting-and-tests into main 2026-04-19 23:08:09 +00:00
49 changed files with 10324 additions and 448 deletions
Showing only changes of commit 7e66d0e7e0 - Show all commits
+63 -9
View File
@@ -1,21 +1,75 @@
# SPARC Configuration # SPARC Configuration
# ---- Application Environment ----
# Set to "production" or "staging" in deployed environments.
# The API will refuse to start with the default JWT secret unless APP_ENV=development.
APP_ENV=development
# ---- API Keys ----
# SerpAPI key for patent search # SerpAPI key for patent search
API_KEY=your_serpapi_key_here API_KEY=your_serpapi_key_here
# OpenRouter API key for LLM analysis # OpenRouter API key for LLM analysis
OPENROUTER_API_KEY=your_openrouter_key_here OPENROUTER_API_KEY=your_openrouter_key_here
# Database configuration # ---- Database ----
# All messages are stored in the database for persistence and caching
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/sparc
# Cache configuration # PostgreSQL credentials (used by docker-compose)
# When USE_CACHE=true: check database for cached responses before making API calls POSTGRES_USER=postgres
# When USE_CACHE=false: always make fresh API calls (still stores results in database) POSTGRES_PASSWORD=change-me-to-a-secure-password
# Default: true POSTGRES_DB=sparc
USE_CACHE=true
# JWT Secret for authentication # Full database URL (must match the credentials above)
DATABASE_URL=postgresql://postgres:change-me-to-a-secure-password@localhost:5432/sparc
# ---- Authentication ----
# JWT Secret for signing tokens
# IMPORTANT: Change this to a secure random string in production # IMPORTANT: Change this to a secure random string in production
JWT_SECRET=your-secure-jwt-secret-change-in-production JWT_SECRET=your-secure-jwt-secret-change-in-production
# ---- CORS ----
# Comma-separated list of allowed origins for CORS
# Defaults to http://localhost:3000,http://localhost:5173 when unset
# CORS_ORIGINS=https://sparc.example.com,https://app.example.com
# ---- Storage ----
# Backend for patent PDF storage: "local" (default) or "s3"
STORAGE_BACKEND=local
# S3/MinIO settings (only used when STORAGE_BACKEND=s3)
# S3_BUCKET=sparc-patents
# S3_ENDPOINT_URL=http://localhost:9000
# AWS_ACCESS_KEY_ID=minioadmin
# AWS_SECRET_ACCESS_KEY=minioadmin
# To start MinIO locally: docker compose --profile s3 up -d minio
# ---- LLM ----
# LLM model to use via OpenRouter
# Supported: anthropic/claude-3.5-sonnet, openai/gpt-4o, openai/gpt-4o-mini,
# google/gemini-pro-1.5, meta-llama/llama-3.1-70b-instruct
# MODEL=anthropic/claude-3.5-sonnet
# ---- Cache ----
# When USE_CACHE=true: check database for cached responses before making API calls
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
USE_CACHE=true
# SERP API cache TTL in hours (how long cached search results are considered fresh)
# SERP_CACHE_TTL_HOURS=24
# ---- Logging ----
# Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
# LOG_LEVEL=INFO
# ---- Webhooks ----
# Comma-separated list of webhook URLs for job completion and alert notifications
# Supports generic HTTP POST and Slack/Discord incoming webhooks
# WEBHOOK_URLS=https://hooks.slack.com/services/XXX,https://example.com/webhook
+51
View File
@@ -9,7 +9,57 @@ on:
workflow_dispatch: workflow_dispatch:
jobs: jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Install system dependencies
shell: sh
run: |
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
- name: Checkout code
shell: sh
run: |
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
git checkout ${{ gitea.sha }}
- name: Install Python dependencies
shell: sh
run: |
pip3 install --break-system-packages -r requirements.txt ruff
- name: Run ruff linter
shell: sh
run: |
ruff check SPARC/ tests/
- name: Install Node.js and check TypeScript types
shell: sh
run: |
apk add --no-cache nodejs npm
cd frontend
npm ci
npm run generate:local
if ! git diff --quiet src/api/schema.d.ts; then
echo "ERROR: src/api/schema.d.ts is out of date. Run 'npm run generate:local' and commit the result."
git diff src/api/schema.d.ts
exit 1
fi
npx tsc --noEmit
- name: Run pytest
shell: sh
env:
DATABASE_URL: "sqlite://"
API_KEY: "test-key"
OPENROUTER_API_KEY: "test-key"
JWT_SECRET: "test-secret-for-ci"
APP_ENV: "development"
run: |
python3 -m pytest tests/ -v --tb=short -x
build-api: build-api:
needs: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install dependencies - name: Install dependencies
@@ -81,6 +131,7 @@ jobs:
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}" echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
build-frontend: build-frontend:
needs: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install dependencies - name: Install dependencies
+67
View File
@@ -0,0 +1,67 @@
name: Test and Lint
on:
push:
branches:
- main
pull_request:
branches:
- main
workflow_dispatch:
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Install system dependencies
shell: sh
run: |
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
- name: Checkout code
shell: sh
run: |
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
git checkout ${{ gitea.sha }}
- name: Install Python dependencies
shell: sh
run: |
pip3 install --break-system-packages -r requirements.txt ruff
- name: Run ruff linter
shell: sh
run: |
ruff check SPARC/ tests/
- name: Install Node.js and frontend dependencies
shell: sh
run: |
apk add --no-cache nodejs npm
cd frontend && npm ci
- name: Verify generated API types are up to date
shell: sh
run: |
cd frontend && npm run generate:local
if ! git diff --quiet src/api/schema.d.ts; then
echo "ERROR: src/api/schema.d.ts is out of date. Run 'npm run generate:local' and commit the result."
git diff src/api/schema.d.ts
exit 1
fi
- name: Run TypeScript type check
shell: sh
run: |
cd frontend && npx tsc --noEmit
- name: Run pytest
shell: sh
env:
DATABASE_URL: "sqlite://"
API_KEY: "test-key"
OPENROUTER_API_KEY: "test-key"
JWT_SECRET: "test-secret-for-ci"
APP_ENV: "development"
run: |
python3 -m pytest tests/ -v --tb=short -x
+15
View File
@@ -54,6 +54,21 @@ docker-compose up -d
# - API Docs: http://localhost:8000/docs # - API Docs: http://localhost:8000/docs
``` ```
#### Patent PDF Storage
The API stores downloaded patent PDFs in a `patents/` directory. In Docker,
this is mounted as a bind mount (`./patents:/app/patents`) so that PDFs persist
across container restarts.
If you deploy to a different environment, ensure the `patents/` directory is a
persistent volume. Without it, PDFs will be re-downloaded on every analysis.
```yaml
# docker-compose.yml excerpt
volumes:
- ./patents:/app/patents
```
### NixOS ### NixOS
```bash ```bash
+122
View File
@@ -0,0 +1,122 @@
# SPARC Roadmap
Semiconductor Patent & Analytics Report Core -- development priorities.
## Current State
SPARC is a patent analysis platform with a working end-to-end pipeline:
Python/FastAPI backend, React/TypeScript frontend, PostgreSQL for persistence
and caching, Docker Compose for local development, and Gitea Actions CI/CD for
image builds. Core features (patent retrieval via SerpAPI, PDF parsing, LLM
analysis via OpenRouter/Claude, batch processing, JWT authentication, analytics
dashboard) are all implemented and functional.
---
## P1 -- High Priority
These items address correctness, security, and reliability gaps that should be
resolved before broader production use.
### Security hardening
- **Rotate default JWT secret.** `auth.py` ships a fallback
`sparc-secret-key-change-in-production` that will be used if `JWT_SECRET` is
unset. Add a startup check that refuses to start with the default secret in
non-development environments.
- **CORS allow-origins are hardcoded.** `api.py` only permits
`localhost:3000` and `localhost:5173`. Make the allowed origins configurable
via environment variable so the dashboard works when deployed behind a real
domain.
- **Database credentials in docker-compose.yml.** The compose file embeds
`postgres:postgres` in plain text. Reference a `.env` file or Docker secrets
instead.
### Error handling and resilience
- **`get_db_client()` in `auth.py` creates a new `DatabaseClient` on every
call.** This bypasses the connection pool and can exhaust database
connections under load. Refactor to share a single pooled client.
- **`_jobs` dict is in-memory only.** Job state is lost on API restart. Persist
job status in PostgreSQL or Redis so async batch results survive restarts.
- **No rate limiting on auth endpoints.** `/auth/login` and `/auth/register`
are unprotected against brute-force or abuse. Add rate limiting middleware.
### Test coverage for auth and admin
- The existing API tests (`tests/test_api.py`) bypass authentication entirely.
Add tests that exercise the JWT flow: registration, login, protected-route
access, token refresh, and admin-only endpoints.
---
## P2 -- Medium Priority
Improvements to usability, performance, and developer experience.
### Backend
- **Add structured logging.** Replace `print()` calls throughout `analyzer.py`,
`serp_api.py`, and `llm.py` with Python `logging` so log levels and
formatting are consistent.
- **Make LLM model configurable.** `llm.py` hardcodes
`anthropic/claude-3.5-sonnet`. Accept a `MODEL` environment variable to allow
switching models without code changes.
- **SERP cache TTL is hardcoded to 24 hours.** Expose `SERP_CACHE_TTL_HOURS`
as an environment variable in `config.py`.
- **Patent PDF storage.** PDFs are saved to a local `patents/` directory. For
containerized deployments, consider object storage (S3/MinIO) or at minimum
document the volume mount requirement more prominently.
- **`analyze_single_patent` assumes local file path.** The method constructs
`patents/{patent_id}.pdf` and reads from disk, but does not download the PDF
first. Either integrate the download step or document the prerequisite.
- **`Patent.patent_id` typed as `int` in `types.py` but used as `str`
everywhere.** Fix the type annotation to `str`.
### Frontend
- **No loading/error states on several pages.** The Batch and Analytics pages
would benefit from skeleton loaders and user-friendly error messages.
- **No dark mode.** Tailwind is configured but no dark variant is applied.
- **Missing `package-lock.json` or `pnpm-lock.yaml`.** The frontend has no
lockfile committed, leading to non-reproducible builds.
### CI/CD
- **No test stage in the Gitea Actions workflow.** `build.yaml` builds and
pushes images but never runs `pytest`. Add a test job that gates the build.
- **No linting or type checking.** Add `ruff` (Python) and `tsc --noEmit`
(TypeScript) to CI.
---
## P3 -- Nice to Have
Lower-urgency enhancements and future features.
- **Export analysis reports.** Allow users to download analysis results as PDF
or CSV from the dashboard.
- **Comparison view.** Side-by-side comparison of two companies' patent
portfolios.
- **Scheduled/recurring analysis.** Periodically re-analyze tracked companies
and alert on significant changes.
- **Webhook/notification support.** Send alerts (Slack, Discord, email) when
batch jobs complete or when a company's innovation score changes
significantly.
- **Multi-model support.** Let users choose between LLM providers per analysis
(e.g., GPT-4o, Gemini, Claude) and compare outputs.
- **Patent trend charts.** Visualize patent filing frequency and technology
category distribution over time in the Analytics page.
- **API pagination.** The `/analyze/batch` and `/jobs` endpoints could benefit
from cursor-based pagination for large result sets.
- **OpenAPI client generation.** Auto-generate the TypeScript API client from
the FastAPI OpenAPI spec to keep frontend types in sync.
---
## Infrastructure and Deployment
Kubernetes manifests, Helm charts, and cluster-level concerns (MetalLB,
storage, FluxCD sync) are tracked in the
[Talos](https://10.0.1.10/leeworks-agents/Talos) repository. File
infrastructure-related issues there, not here.
+3 -2
View File
@@ -1,3 +1,4 @@
from .types import Patents, Patent from .types import Patent as Patent
from .types import Patents as Patents
all = ["Patents", "Patent"] __all__ = ["Patents", "Patent"]
+64 -30
View File
@@ -5,14 +5,17 @@ to provide company performance estimation based on patent portfolios.
""" """
import hashlib import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable from typing import Callable
from SPARC import config from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from SPARC.serp_api import SERP
from SPARC.llm import LLMAnalyzer from SPARC.llm import LLMAnalyzer
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult from SPARC.serp_api import SERP
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents
class CompanyAnalyzer: class CompanyAnalyzer:
@@ -30,7 +33,7 @@ class CompanyAnalyzer:
self.db.connect() self.db.connect()
self.db.initialize_schema() self.db.initialize_schema()
def analyze_company(self, company_name: str, patents: "Patents | None" = None) -> str: def analyze_company(self, company_name: str, patents: "Patents | None" = None, model: str | None = None) -> str:
"""Analyze a company's performance based on their patent portfolio. """Analyze a company's performance based on their patent portfolio.
This is the main entry point that orchestrates the full pipeline: This is the main entry point that orchestrates the full pipeline:
@@ -43,6 +46,7 @@ class CompanyAnalyzer:
Args: Args:
company_name: Name of the company to analyze company_name: Name of the company to analyze
patents: Optional pre-fetched Patents result to avoid duplicate API calls patents: Optional pre-fetched Patents result to avoid duplicate API calls
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
Comprehensive analysis of company's innovation and performance outlook Comprehensive analysis of company's innovation and performance outlook
@@ -52,13 +56,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash) cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None: if cached_ids is not None:
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)") logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
patents = Patents(patents=[ patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="") Patent(patent_id=pid, pdf_link="")
for pid in cached_ids for pid in cached_ids
]) ])
else: else:
print(f"Retrieving patents for {company_name}...") logger.info("Retrieving patents for %s...", company_name)
patents = SERP.query(company_name) patents = SERP.query(company_name)
# Cache the SERP results # Cache the SERP results
if patents.patents: if patents.patents:
@@ -66,12 +70,13 @@ class CompanyAnalyzer:
company_name=company_name, company_name=company_name,
query_hash=query_hash, query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents], patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
) )
if not patents.patents: if not patents.patents:
return f"No patents found for {company_name}" return f"No patents found for {company_name}"
print(f"Found {len(patents.patents)} patents. Processing...") logger.info("Found %d patents. Processing...", len(patents.patents))
# Download, parse, and minimize patents in parallel # Download, parse, and minimize patents in parallel
processed_patents = [] processed_patents = []
@@ -87,48 +92,74 @@ class CompanyAnalyzer:
if result: if result:
processed_patents.append(result) processed_patents.append(result)
except Exception as e: except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}") logger.warning("Failed to process %s: %s", patent.patent_id, e)
if not processed_patents: if not processed_patents:
return f"Failed to process any patents for {company_name}" return f"Failed to process any patents for {company_name}"
print(f"Analyzing portfolio with LLM...") logger.info("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM # Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio( analysis = self.llm_analyzer.analyze_patent_portfolio(
patents_data=processed_patents, company_name=company_name patents_data=processed_patents, company_name=company_name, model=model
) )
return analysis return analysis
def analyze_single_patent(self, patent_id: str, company_name: str) -> str: def analyze_single_patent(self, patent_id: str, company_name: str, model: str | None = None) -> str:
"""Analyze a single patent by ID. """Analyze a single patent by ID.
Useful for focused analysis of specific innovations. If the patent PDF is not already on disk, this method attempts to
download it automatically by looking up the PDF link in the database
cache. If the link is not cached either, a ``FileNotFoundError`` is
raised with instructions on how to obtain the PDF.
Args: Args:
patent_id: Publication ID of the patent patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
company_name: Name of the company (for context) company_name: Name of the company (for context)
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
Analysis of the specific patent's innovation quality Analysis of the specific patent's innovation quality
Raises:
FileNotFoundError: If the patent PDF cannot be found or downloaded.
""" """
# Note: This simplified version assumes the patent PDF is already downloaded import os
# A more complete implementation would support direct patent ID lookup logger.info("Analyzing patent %s for %s...", patent_id, company_name)
print(f"Analyzing patent {patent_id} for {company_name}...")
patent_path = f"patents/{patent_id}.pdf" patent_path = f"patents/{patent_id}.pdf"
if not os.path.exists(patent_path):
# Attempt to download the PDF automatically from cached metadata
cached = self.db.get_cached_patent(patent_id)
pdf_link = cached.get("pdf_link") if cached else None
if pdf_link:
logger.info("PDF not on disk; downloading %s from cached link", patent_id)
patent = SERP.save_patents(
Patent(patent_id=patent_id, pdf_link=pdf_link)
)
patent_path = patent.pdf_path
else:
raise FileNotFoundError(
f"Patent PDF not found at '{patent_path}' and no download link is "
f"cached for '{patent_id}'. Run a company analysis first to populate "
f"the cache, or call SERP.save_patents() with the patent's PDF link."
)
try: try:
sections = SERP.parse_patent_pdf(patent_path) sections = SERP.parse_patent_pdf(patent_path)
minimized_content = SERP.minimize_patent_for_llm(sections) minimized_content = SERP.minimize_patent_for_llm(sections)
analysis = self.llm_analyzer.analyze_patent_content( analysis = self.llm_analyzer.analyze_patent_content(
patent_content=minimized_content, company_name=company_name patent_content=minimized_content, company_name=company_name, model=model
) )
return analysis return analysis
except FileNotFoundError:
raise
except Exception as e: except Exception as e:
return f"Failed to analyze patent {patent_id}: {e}" return f"Failed to analyze patent {patent_id}: {e}"
@@ -169,21 +200,22 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content} return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e: except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}") logger.warning("Failed to process %s: %s", patent.patent_id, e)
return None return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: def _analyze_company_safe(self, company_name: str, model: str | None = None) -> CompanyAnalysisResult:
"""Internal wrapper that catches exceptions and returns structured result. """Internal wrapper that catches exceptions and returns structured result.
Args: Args:
company_name: Name of the company to analyze company_name: Name of the company to analyze
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
CompanyAnalysisResult with success/failure status CompanyAnalysisResult with success/failure status
""" """
try: try:
# Delegate to analyze_company which handles SERP/patent caching # Delegate to analyze_company which handles SERP/patent caching
analysis = self.analyze_company(company_name) analysis = self.analyze_company(company_name, model=model)
# Determine patent count from cached SERP query # Determine patent count from cached SERP query
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
@@ -223,6 +255,7 @@ class CompanyAnalyzer:
companies: list[str], companies: list[str],
max_workers: int = 3, max_workers: int = 3,
progress_callback: Callable[[str, int, int], None] | None = None, progress_callback: Callable[[str, int, int], None] | None = None,
model: str | None = None,
) -> BatchAnalysisResult: ) -> BatchAnalysisResult:
"""Analyze multiple companies' patent portfolios in batch. """Analyze multiple companies' patent portfolios in batch.
@@ -233,6 +266,7 @@ class CompanyAnalyzer:
companies: List of company names to analyze companies: List of company names to analyze
max_workers: Maximum concurrent analyses (default 3 to avoid rate limits) max_workers: Maximum concurrent analyses (default 3 to avoid rate limits)
progress_callback: Optional callback(company_name, completed, total) progress_callback: Optional callback(company_name, completed, total)
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
BatchAnalysisResult containing all individual results and summary stats BatchAnalysisResult containing all individual results and summary stats
@@ -240,11 +274,11 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
print(f"Starting batch analysis of {total} companies...") logger.info("Starting batch analysis of %d companies...", total)
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = { future_to_company = {
executor.submit(self._analyze_company_safe, company): company executor.submit(self._analyze_company_safe, company, model): company
for company in companies for company in companies
} }
@@ -257,8 +291,8 @@ class CompanyAnalyzer:
result = future.result() result = future.result()
results.append(result) results.append(result)
status = "" if result.success else "" status = "OK" if result.success else "FAIL"
print(f"[{completed}/{total}] {status} {company}") logger.info("[%d/%d] %s %s", completed, total, status, company)
if progress_callback: if progress_callback:
progress_callback(company, completed, total) progress_callback(company, completed, total)
@@ -273,12 +307,12 @@ class CompanyAnalyzer:
error=str(e), error=str(e),
) )
) )
print(f"[{completed}/{total}] ✗ {company}: {e}") logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed") logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
@@ -304,20 +338,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
print(f"Starting sequential analysis of {total} companies...") logger.info("Starting sequential analysis of %d companies...", total)
for idx, company in enumerate(companies, 1): for idx, company in enumerate(companies, 1):
print(f"\n[{idx}/{total}] Analyzing {company}...") logger.info("[%d/%d] Analyzing %s...", idx, total, company)
result = self._analyze_company_safe(company) result = self._analyze_company_safe(company)
results.append(result) results.append(result)
status = "" if result.success else "" status = "OK" if result.success else "FAIL"
print(f"[{idx}/{total}] {status} {company}") logger.info("[%d/%d] %s %s", idx, total, status, company)
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed") logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
+595 -44
View File
@@ -7,20 +7,27 @@ from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from typing import Annotated, List from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config from SPARC import config
from SPARC.analyzer import CompanyAnalyzer from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import ( from SPARC.auth import (
TokenResponse, TokenResponse,
UserResponse, UserResponse,
check_jwt_secret,
close_db_client,
create_tokens, create_tokens,
decode_token, decode_token,
get_current_admin, get_current_admin,
get_current_user, get_current_user,
get_db_client, get_db_client,
init_db_client,
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -34,6 +41,7 @@ class CompanyAnalysisResponse(BaseModel):
patent_count: int patent_count: int
success: bool success: bool
error: str | None = None error: str | None = None
model: str | None = None
timestamp: datetime timestamp: datetime
@@ -47,6 +55,15 @@ class BatchAnalysisResponse(BaseModel):
timestamp: datetime timestamp: datetime
class CompanyAnalysisRequest(BaseModel):
"""Request model for single company analysis with optional model selection."""
model: str | None = Field(
default=None,
description="LLM model to use (e.g. 'anthropic/claude-3.5-sonnet', 'openai/gpt-4o'). Defaults to server config.",
)
class BatchAnalysisRequest(BaseModel): class BatchAnalysisRequest(BaseModel):
"""Request model for batch company analysis.""" """Request model for batch company analysis."""
@@ -56,6 +73,10 @@ class BatchAnalysisRequest(BaseModel):
max_workers: int = Field( max_workers: int = Field(
default=3, ge=1, le=5, description="Max concurrent analyses" default=3, ge=1, le=5, description="Max concurrent analyses"
) )
model: str | None = Field(
default=None,
description="LLM model to use for all analyses in this batch. Defaults to server config.",
)
class JobStatus(BaseModel): class JobStatus(BaseModel):
@@ -70,6 +91,13 @@ class JobStatus(BaseModel):
error: str | None = None error: str | None = None
class PaginatedJobsResponse(BaseModel):
"""Paginated response for job listings."""
items: list["JobStatus"]
next_cursor: str | None = None
class HealthResponse(BaseModel): class HealthResponse(BaseModel):
"""Health check response.""" """Health check response."""
@@ -114,8 +142,7 @@ class AnalyticsResponse(BaseModel):
period_days: int period_days: int
# In-memory job storage (for demo; production would use Redis/DB) # Job counter for generating unique IDs (the actual state is in PostgreSQL)
_jobs: dict[str, JobStatus] = {}
_job_counter = 0 _job_counter = 0
@@ -127,6 +154,7 @@ def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
patent_count=result.patent_count, patent_count=result.patent_count,
success=result.success, success=result.success,
error=result.error, error=result.error,
model=result.model,
timestamp=result.timestamp, timestamp=result.timestamp,
) )
@@ -148,12 +176,28 @@ _analyzer: CompanyAnalyzer | None = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Initialize resources on startup.""" """Initialize resources on startup, clean up on shutdown."""
global _analyzer global _analyzer
check_jwt_secret()
init_db_client()
_analyzer = CompanyAnalyzer() _analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient
_db = DatabaseClient(config.database_url)
_db.connect()
_db.initialize_schema()
stale = _db.mark_stale_jobs_failed()
if stale:
import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
# Start scheduled analysis if tracked companies are configured
from SPARC.scheduler import start_scheduler
start_scheduler()
yield yield
# Cleanup if needed # Cleanup
_analyzer = None _analyzer = None
close_db_client()
app = FastAPI( app = FastAPI(
@@ -164,10 +208,26 @@ app = FastAPI(
root_path=config.root_path, root_path=config.root_path,
) )
# Rate limiter (in-memory storage, suitable for single-instance deployments)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."},
headers={"Retry-After": str(retry_after)},
)
# Add CORS middleware for React frontend # Add CORS middleware for React frontend
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"], allow_origins=config.cors_origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
@@ -178,7 +238,8 @@ app.add_middleware(
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"]) @app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
async def register(request: RegisterRequest): @limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
"""Register a new user. """Register a new user.
The first registered user automatically becomes an admin. The first registered user automatically becomes an admin.
@@ -190,8 +251,8 @@ async def register(request: RegisterRequest):
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
user = db.create_user( user = db.create_user(
email=request.email, email=body.email,
password=request.password, password=body.password,
role=role, role=role,
) )
@@ -210,11 +271,12 @@ async def register(request: RegisterRequest):
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"]) @app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
async def login(request: LoginRequest): @limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens.""" """Authenticate user and return JWT tokens."""
db = get_db_client() db = get_db_client()
user = db.authenticate_user(request.email, request.password) user = db.authenticate_user(body.email, body.password)
if not user: if not user:
raise HTTPException( raise HTTPException(
@@ -332,6 +394,60 @@ async def delete_user(
return {"message": "User deleted"} return {"message": "User deleted"}
# ============== Tracked Companies Endpoints ==============
class TrackCompanyRequest(BaseModel):
"""Request to add a company to tracking."""
company_name: str = Field(..., min_length=1, max_length=255)
@app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies(
_: UserResponse = Depends(get_current_admin),
):
"""List all tracked companies (admin only)."""
db = get_db_client()
return db.list_tracked_companies()
@app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company(
request: TrackCompanyRequest,
_: UserResponse = Depends(get_current_admin),
):
"""Add a company to the tracked list (admin only)."""
db = get_db_client()
result = db.add_tracked_company(request.company_name)
if not result:
raise HTTPException(status_code=409, detail="Company already tracked")
return result
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company(
company_name: str,
_: UserResponse = Depends(get_current_admin),
):
"""Remove a company from the tracked list (admin only)."""
db = get_db_client()
removed = db.remove_tracked_company(company_name)
if not removed:
raise HTTPException(status_code=404, detail="Company not found in tracking list")
return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/alerts", tags=["Admin"])
async def list_alerts(
limit: int = Query(default=50, ge=1, le=200),
_: UserResponse = Depends(get_current_admin),
):
"""List recent alerts from scheduled analysis (admin only)."""
db = get_db_client()
return db.list_alerts(limit=limit)
# ============== Analytics Endpoint ============== # ============== Analytics Endpoint ==============
@@ -352,6 +468,331 @@ async def get_analytics(
) )
# ============== Model Selection Endpoints ==============
# Supported models via OpenRouter
SUPPORTED_MODELS = [
{"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
{"id": "openai/gpt-4o", "name": "GPT-4o", "provider": "OpenAI"},
{"id": "openai/gpt-4o-mini", "name": "GPT-4o Mini", "provider": "OpenAI"},
{"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
{"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
]
_SUPPORTED_MODEL_IDS = {m["id"] for m in SUPPORTED_MODELS}
def _validate_model(model: str | None) -> None:
"""Raise HTTP 400 if *model* is not in the supported allow-list."""
if model is not None and model not in _SUPPORTED_MODEL_IDS:
raise HTTPException(
status_code=400,
detail=(
f"Unsupported model '{model}'. "
f"Supported models: {', '.join(sorted(_SUPPORTED_MODEL_IDS))}"
),
)
@app.get("/models", tags=["System"])
async def list_models():
"""List supported LLM models for analysis.
Returns the available models that can be passed as the `model` field
in analysis requests. The default model is determined by the `MODEL`
environment variable on the server.
"""
return {
"models": SUPPORTED_MODELS,
"default": config.model,
}
@app.get("/analytics/trends", tags=["Analytics"])
async def get_analytics_trends(
days: int = Query(default=90, ge=7, le=365),
_: UserResponse = Depends(get_current_user),
):
"""Get trend data for patent analysis over time.
Returns two datasets:
- ``by_month``: analysis count per company per month
- ``by_type_over_time``: analysis type distribution per month
Args:
days: Number of days to look back (default 90)
Returns:
Trend data suitable for time-series and distribution charts
"""
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
# Analyses per company per month
cur.execute(
"""
SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month,
company_name,
COUNT(*) AS count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE
AND company_name IS NOT NULL
GROUP BY month, company_name
ORDER BY month
""",
(days,),
)
by_month_rows = cur.fetchall()
# Analysis type distribution per month
cur.execute(
"""
SELECT
TO_CHAR(timestamp, 'YYYY-MM') AS month,
analysis_type,
COUNT(*) AS count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
AND is_cached = FALSE
GROUP BY month, analysis_type
ORDER BY month
""",
(days,),
)
by_type_rows = cur.fetchall()
by_month = [
{"month": row[0], "company_name": row[1], "count": row[2]}
for row in by_month_rows
]
by_type_over_time = [
{"month": row[0], "analysis_type": row[1], "count": row[2]}
for row in by_type_rows
]
return {
"by_month": by_month,
"by_type_over_time": by_type_over_time,
"period_days": days,
}
# ============== Export Endpoints ==============
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: str,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
Returns:
CSV file download
"""
import csv
import io
db = get_db_client()
# Query all non-cached analysis results for this company
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
""",
(company_name,),
)
rows = cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
for row in rows:
writer.writerow(row)
output.seek(0)
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: str,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
import io
import textwrap
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import inch
from reportlab.platypus import (
Paragraph,
SimpleDocTemplate,
Spacer,
Table,
TableStyle,
)
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
""",
(company_name,),
)
rows = cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=letter,
rightMargin=0.75 * inch,
leftMargin=0.75 * inch,
topMargin=0.75 * inch,
bottomMargin=0.75 * inch,
)
styles = getSampleStyleSheet()
title_style = ParagraphStyle(
"CustomTitle",
parent=styles["Title"],
fontSize=20,
spaceAfter=6,
)
subtitle_style = ParagraphStyle(
"Subtitle",
parent=styles["Normal"],
fontSize=11,
textColor=colors.grey,
spaceAfter=20,
)
heading_style = ParagraphStyle(
"SectionHeading",
parent=styles["Heading2"],
fontSize=13,
spaceBefore=16,
spaceAfter=8,
textColor=colors.HexColor("#1a1a2e"),
)
body_style = ParagraphStyle(
"BodyText",
parent=styles["Normal"],
fontSize=9,
leading=13,
spaceAfter=10,
)
elements = []
# Title and date
display_name = rows[0][0] # Use the casing from the database
analysis_date = datetime.now().strftime("%Y-%m-%d")
elements.append(Paragraph(f"SPARC Analysis Report: {display_name}", title_style))
elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))
# Summary table
summary_data = [
["Total Analyses", str(len(rows))],
["Analysis Types", ", ".join(sorted(set(r[1] for r in rows)))],
["Models Used", ", ".join(sorted(set(r[2] for r in rows)))],
]
summary_table = Table(summary_data, colWidths=[2 * inch, 4.5 * inch])
summary_table.setStyle(
TableStyle(
[
("BACKGROUND", (0, 0), (0, -1), colors.HexColor("#f0f0f5")),
("FONTNAME", (0, 0), (0, -1), "Helvetica-Bold"),
("FONTSIZE", (0, 0), (-1, -1), 9),
("PADDING", (0, 0), (-1, -1), 6),
("GRID", (0, 0), (-1, -1), 0.5, colors.HexColor("#cccccc")),
("VALIGN", (0, 0), (-1, -1), "TOP"),
]
)
)
elements.append(summary_table)
elements.append(Spacer(1, 16))
# Individual analysis sections
for i, row in enumerate(rows, 1):
_, analysis_type, model, response, timestamp = row
ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)
elements.append(
Paragraph(f"Analysis {i}: {analysis_type} (via {model})", heading_style)
)
elements.append(
Paragraph(f"<i>Performed: {ts_str}</i>", body_style)
)
# Wrap long response text into paragraphs, escaping XML special chars
safe_response = (
response.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
# Split into manageable paragraphs to avoid overflow
for line in safe_response.split("\n"):
if line.strip():
elements.append(Paragraph(line, body_style))
else:
elements.append(Spacer(1, 4))
elements.append(Spacer(1, 10))
doc.build(elements)
buffer.seek(0)
safe_name = company_name.replace(" ", "_").lower()
filename = f"{safe_name}-analysis-{analysis_date}.pdf"
return StreamingResponse(
iter([buffer.getvalue()]),
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
# ============== System Endpoints ============== # ============== System Endpoints ==============
@@ -372,6 +813,7 @@ async def health_check():
) )
async def analyze_company( async def analyze_company(
company_name: str, company_name: str,
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Analyze a single company's patent portfolio. """Analyze a single company's patent portfolio.
@@ -381,17 +823,51 @@ async def analyze_company(
Args: Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel") company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns: Returns:
Analysis results including patent count, AI insights, and success status Analysis results including patent count, AI insights, and success status
""" """
_validate_model(model)
if not _analyzer: if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized") raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name) result = _analyzer._analyze_company_safe(company_name, model=model)
return _convert_result(result) return _convert_result(result)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: str = Query(description="Company name for analysis context"),
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
If the patent PDF is not already cached locally, the system will attempt
to download it automatically from a previously cached link. If no link
is available, a 404 error is returned.
Args:
patent_id: Patent publication ID (e.g. "US-11234567-B2")
company_name: Company name for analysis context
Returns:
Analysis text for the patent
"""
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
try:
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@app.post( @app.post(
"/analyze/batch", "/analyze/batch",
response_model=BatchAnalysisResponse, response_model=BatchAnalysisResponse,
@@ -412,43 +888,98 @@ async def analyze_companies_batch(
Returns: Returns:
Batch results with individual company analyses and summary statistics Batch results with individual company analyses and summary statistics
""" """
_validate_model(request.model)
if not _analyzer: if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized") raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer.analyze_companies( result = _analyzer.analyze_companies(
companies=request.companies, companies=request.companies,
max_workers=request.max_workers, max_workers=request.max_workers,
model=request.model,
) )
return _convert_batch_result(result) return _convert_batch_result(result)
def _run_batch_job(job_id: str, companies: list[str], max_workers: int): def _get_job_db() -> "DatabaseClient":
"""Get a DatabaseClient for job persistence."""
from SPARC.database import DatabaseClient
db = DatabaseClient(config.database_url)
return db
def _job_row_to_status(row: dict) -> JobStatus:
"""Convert a database job row to a JobStatus model."""
import json as _json
result = None
if row.get("result_json"):
result_data = row["result_json"]
if isinstance(result_data, str):
result_data = _json.loads(result_data)
result = BatchAnalysisResponse(**result_data)
return JobStatus(
job_id=row["job_id"],
status=row["status"],
progress=row["progress"],
total_companies=row["total_companies"],
completed_companies=row["completed_companies"],
result=result,
error=row.get("error"),
)
def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: str | None = None):
"""Background task for batch analysis.""" """Background task for batch analysis."""
global _jobs, _analyzer import json as _json
global _analyzer
db = _get_job_db()
if not _analyzer: if not _analyzer:
_jobs[job_id].status = "failed" db.update_job(job_id, status="failed", error="Analyzer not initialized")
_jobs[job_id].error = "Analyzer not initialized"
return return
_jobs[job_id].status = "running" db.update_job(job_id, status="running")
def progress_callback(company: str, completed: int, total: int): def progress_callback(company: str, completed: int, total: int):
_jobs[job_id].completed_companies = completed db.update_job(
_jobs[job_id].progress = int((completed / total) * 100) job_id,
completed_companies=completed,
progress=int((completed / total) * 100),
)
try: try:
result = _analyzer.analyze_companies( result = _analyzer.analyze_companies(
companies=companies, companies=companies,
max_workers=max_workers, max_workers=max_workers,
progress_callback=progress_callback, progress_callback=progress_callback,
model=model,
)
batch_response = _convert_batch_result(result)
db.update_job(
job_id,
status="completed",
progress=100,
result_json=_json.dumps(batch_response.model_dump(), default=str),
)
# Fire webhook notification
from SPARC.webhooks import notify_job_completed
notify_job_completed(
job_id=job_id,
status="completed",
total_companies=result.total_companies,
successful=result.successful,
failed=result.failed,
) )
_jobs[job_id].status = "completed"
_jobs[job_id].progress = 100
_jobs[job_id].result = _convert_batch_result(result)
except Exception as e: except Exception as e:
_jobs[job_id].status = "failed" db.update_job(job_id, status="failed", error=str(e))
_jobs[job_id].error = str(e) from SPARC.webhooks import notify_job_completed
notify_job_completed(
job_id=job_id,
status="failed",
total_companies=len(companies),
successful=0,
failed=len(companies),
)
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"]) @app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
@@ -468,24 +999,20 @@ async def analyze_companies_async(
Returns: Returns:
Job status with job_id for polling Job status with job_id for polling
""" """
_validate_model(request.model)
global _job_counter global _job_counter
_job_counter += 1 _job_counter += 1
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}" job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
_jobs[job_id] = JobStatus( db = _get_job_db()
job_id=job_id, job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
status="pending",
progress=0,
total_companies=len(request.companies),
completed_companies=0,
)
background_tasks.add_task( background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers _run_batch_job, job_id, request.companies, request.max_workers, request.model
) )
return _jobs[job_id] return _job_row_to_status(job_row)
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
@@ -501,36 +1028,60 @@ async def get_job_status(
Returns: Returns:
Current job status including progress and results when complete Current job status including progress and results when complete
""" """
if job_id not in _jobs: db = _get_job_db()
job_row = db.get_job(job_id)
if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found") raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
return _jobs[job_id] return _job_row_to_status(job_row)
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"]) @app.get("/jobs", response_model=PaginatedJobsResponse, tags=["Jobs"])
async def list_jobs( async def list_jobs(
status: Annotated[ status: Annotated[
str | None, str | None,
Query(description="Filter by status: pending, running, completed, failed"), Query(description="Filter by status: pending, running, completed, failed"),
] = None, ] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 10, limit: Annotated[int, Query(ge=1, le=100)] = 10,
cursor: Annotated[
str | None,
Query(description="Opaque cursor from a previous response's next_cursor field"),
] = None,
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""List all analysis jobs. """List analysis jobs with cursor-based pagination.
Pass ``limit`` to control page size. The response includes a ``next_cursor``
field; pass it back as the ``cursor`` query parameter to fetch the next page.
When ``next_cursor`` is ``null``, there are no more results.
Existing clients that use only ``limit`` (without ``cursor``) continue to
work without modification.
Args: Args:
status: Optional filter by job status status: Optional filter by job status
limit: Maximum number of jobs to return (default 10, max 100) limit: Maximum number of jobs to return (default 10, max 100)
cursor: Opaque pagination cursor from a previous response
Returns: Returns:
List of job statuses Paginated list of job statuses
""" """
jobs = list(_jobs.values()) db = _get_job_db()
# Fetch one extra to determine if there is a next page
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
if status: has_next = len(job_rows) > limit
jobs = [j for j in jobs if j.status == status] if has_next:
job_rows = job_rows[:limit]
# Return most recent first items = [_job_row_to_status(row) for row in job_rows]
jobs.sort(key=lambda j: j.job_id, reverse=True)
return jobs[:limit] next_cursor = None
if has_next and job_rows:
last = job_rows[-1]
created = last["created_at"]
ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
next_cursor = f"{ts}|{last['job_id']}"
return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
+44 -5
View File
@@ -13,11 +13,25 @@ from SPARC import config
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
# JWT Configuration # JWT Configuration
JWT_SECRET = os.getenv("JWT_SECRET", "sparc-secret-key-change-in-production") _DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256" JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7 REFRESH_TOKEN_EXPIRE_DAYS = 7
def check_jwt_secret() -> None:
"""Refuse to start with the default JWT secret in non-development environments.
Raises:
RuntimeError: If JWT_SECRET is the default value and APP_ENV is not 'development'.
"""
if JWT_SECRET == _DEFAULT_JWT_SECRET and config.app_env != "development":
raise RuntimeError(
f"FATAL: JWT_SECRET is set to the default value and APP_ENV={config.app_env!r}. "
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer() security = HTTPBearer()
@@ -132,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient: def get_db_client() -> DatabaseClient:
"""Get database client for auth operations.""" """Get the shared pooled database client for auth operations.
client = DatabaseClient(config.database_url)
client.connect() Returns the module-level singleton DatabaseClient. If not yet initialized
return client (e.g., during tests), creates a new instance as a fallback.
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
async def get_current_user( async def get_current_user(
+36 -1
View File
@@ -2,11 +2,20 @@
Loads environment variables from .env file for API keys and other secrets. Loads environment variables from .env file for API keys and other secrets.
""" """
from dotenv import load_dotenv import logging
import os import os
from dotenv import load_dotenv
load_dotenv() load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search # SerpAPI key for patent search
api_key = os.getenv("API_KEY") api_key = os.getenv("API_KEY")
@@ -30,6 +39,32 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90")) patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5")) patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/) # Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy # This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "") root_path = os.getenv("ROOT_PATH", "")
# Application environment: "development", "staging", or "production"
# Used for safety checks (e.g., refusing default JWT secret in production)
app_env = os.getenv("APP_ENV", "development")
# Storage backend: "local" (default) or "s3" for S3/MinIO object storage
storage_backend = os.getenv("STORAGE_BACKEND", "local")
s3_bucket = os.getenv("S3_BUCKET", "sparc-patents")
s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "")
s3_access_key = os.getenv("AWS_ACCESS_KEY_ID", "")
s3_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "")
# CORS allowed origins (comma-separated)
# Defaults to localhost dev origins when unset
_cors_origins_raw = os.getenv("CORS_ORIGINS", "")
cors_origins: list[str] = (
[o.strip() for o in _cors_origins_raw.split(",") if o.strip()]
if _cors_origins_raw
else ["http://localhost:3000", "http://localhost:5173"]
)
+312 -47
View File
@@ -1,14 +1,15 @@
"""Database client for storing and retrieving LLM messages and user authentication.""" """Database client for storing and retrieving LLM messages and user authentication."""
import contextlib import contextlib
import psycopg2
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extras import RealDictCursor
from typing import Dict, List, Optional
from datetime import datetime, timedelta
import json
import hashlib import hashlib
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import bcrypt import bcrypt
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool
class DatabaseClient: class DatabaseClient:
@@ -171,6 +172,55 @@ class DatabaseClient:
ON serp_queries(query_hash) ON serp_queries(query_hash)
""") """)
# Create jobs table for persisting async batch job state
cursor.execute("""
CREATE TABLE IF NOT EXISTS jobs (
job_id VARCHAR(128) PRIMARY KEY,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
progress INTEGER NOT NULL DEFAULT 0,
total_companies INTEGER NOT NULL DEFAULT 0,
completed_companies INTEGER NOT NULL DEFAULT 0,
result_json JSONB,
error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_status
ON jobs(status)
""")
# Create tracked companies table for scheduled analysis
cursor.execute("""
CREATE TABLE IF NOT EXISTS tracked_companies (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255) UNIQUE NOT NULL,
last_patent_count INTEGER DEFAULT 0,
last_analysis_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create alerts table for significant changes
cursor.execute("""
CREATE TABLE IF NOT EXISTS alerts (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255) NOT NULL,
alert_type VARCHAR(50) NOT NULL,
message TEXT NOT NULL,
old_value NUMERIC,
new_value NUMERIC,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_alerts_company
ON alerts(company_name)
""")
self.conn.commit() self.conn.commit()
@staticmethod @staticmethod
@@ -201,8 +251,6 @@ class DatabaseClient:
Returns: Returns:
Cached message dict if found, None otherwise Cached message dict if found, None otherwise
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
query = """ query = """
@@ -225,7 +273,8 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1" query += " ORDER BY timestamp DESC LIMIT 1"
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params) cursor.execute(query, params)
result = cursor.fetchone() result = cursor.fetchone()
return dict(result) if result else None return dict(result) if result else None
@@ -256,11 +305,10 @@ class DatabaseClient:
Returns: Returns:
The ID of the inserted record The ID of the inserted record
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
with self.conn.cursor() as cursor: with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO llm_messages INSERT INTO llm_messages
@@ -282,7 +330,7 @@ class DatabaseClient:
) )
message_id = cursor.fetchone()[0] message_id = cursor.fetchone()[0]
self.conn.commit() conn.commit()
return message_id return message_id
@@ -304,8 +352,6 @@ class DatabaseClient:
Returns: Returns:
List of message dictionaries List of message dictionaries
""" """
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1" query = "SELECT * FROM llm_messages WHERE 1=1"
params = [] params = []
@@ -320,7 +366,8 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s" query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset]) params.extend([limit, offset])
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params) cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
@@ -333,9 +380,8 @@ class DatabaseClient:
Returns: Returns:
Dictionary with analytics data Dictionary with analytics data
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages # Total messages
cursor.execute( cursor.execute(
""" """
@@ -462,6 +508,156 @@ class DatabaseClient:
) )
conn.commit() conn.commit()
# Job Persistence Methods
def create_job(
self,
job_id: str,
total_companies: int,
) -> Dict:
"""Create a new job record.
Args:
job_id: Unique job identifier
total_companies: Number of companies in the batch
Returns:
Job dict
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
VALUES (%s, 'pending', 0, %s, 0)
RETURNING *
""",
(job_id, total_companies),
)
job = cursor.fetchone()
conn.commit()
return dict(job)
def update_job(
self,
job_id: str,
status: Optional[str] = None,
progress: Optional[int] = None,
completed_companies: Optional[int] = None,
result_json: Optional[str] = None,
error: Optional[str] = None,
) -> Optional[Dict]:
"""Update a job's state.
Only non-None fields are updated.
"""
updates = []
params = []
if status is not None:
updates.append("status = %s")
params.append(status)
if progress is not None:
updates.append("progress = %s")
params.append(progress)
if completed_companies is not None:
updates.append("completed_companies = %s")
params.append(completed_companies)
if result_json is not None:
updates.append("result_json = %s")
params.append(result_json)
if error is not None:
updates.append("error = %s")
params.append(error)
if not updates:
return self.get_job(job_id)
updates.append("updated_at = CURRENT_TIMESTAMP")
params.append(job_id)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
f"UPDATE jobs SET {', '.join(updates)} WHERE job_id = %s RETURNING *",
params,
)
job = cursor.fetchone()
conn.commit()
return dict(job) if job else None
def get_job(self, job_id: str) -> Optional[Dict]:
"""Get a job by ID."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
job = cursor.fetchone()
return dict(job) if job else None
def list_jobs(
self,
status: Optional[str] = None,
limit: int = 10,
cursor: Optional[str] = None,
) -> List[Dict]:
"""List jobs with optional status filter and cursor-based pagination.
Args:
status: Optional status filter (pending, running, completed, failed).
limit: Maximum number of jobs to return.
cursor: Opaque cursor (``created_at|job_id``) from a previous
response. When provided, only jobs older than the cursor are
returned.
Returns:
List of job dicts ordered by created_at descending.
"""
conditions: list[str] = []
params: list = []
if status:
conditions.append("status = %s")
params.append(status)
if cursor:
try:
ts_str, cursor_job_id = cursor.rsplit("|", 1)
conditions.append("(created_at, job_id) < (%s, %s)")
params.extend([ts_str, cursor_job_id])
except ValueError:
pass # Ignore malformed cursors; return from start
query = "SELECT * FROM jobs"
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY created_at DESC, job_id DESC LIMIT %s"
params.append(limit)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, params)
return [dict(row) for row in cur.fetchall()]
def mark_stale_jobs_failed(self) -> int:
"""Mark any jobs in 'running' or 'pending' state as 'failed'.
Called at startup to clean up jobs that were interrupted by a restart.
Returns:
Number of jobs marked as failed.
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart',
updated_at = CURRENT_TIMESTAMP
WHERE status IN ('running', 'pending')
"""
)
count = cursor.rowcount
conn.commit()
return count
# User Authentication Methods # User Authentication Methods
@staticmethod @staticmethod
@@ -505,12 +701,11 @@ class DatabaseClient:
Returns: Returns:
Created user dict or None if email exists Created user dict or None if email exists
""" """
self.connect()
password_hash = self.hash_password(password) password_hash = self.hash_password(password)
try: try:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO users (email, password_hash, role) INSERT INTO users (email, password_hash, role)
@@ -520,10 +715,9 @@ class DatabaseClient:
(email, password_hash, role), (email, password_hash, role),
) )
user = cursor.fetchone() user = cursor.fetchone()
self.conn.commit() conn.commit()
return dict(user) if user else None return dict(user) if user else None
except psycopg2.errors.UniqueViolation: except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]: def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -536,9 +730,8 @@ class DatabaseClient:
Returns: Returns:
User dict if authenticated, None otherwise User dict if authenticated, None otherwise
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
"SELECT * FROM users WHERE email = %s", "SELECT * FROM users WHERE email = %s",
(email,), (email,),
@@ -563,9 +756,8 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE id = %s", "SELECT id, email, role, created_at FROM users WHERE id = %s",
(user_id,), (user_id,),
@@ -582,9 +774,8 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE email = %s", "SELECT id, email, role, created_at FROM users WHERE email = %s",
(email,), (email,),
@@ -602,9 +793,8 @@ class DatabaseClient:
Returns: Returns:
List of user dicts List of user dicts
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
SELECT id, email, role, created_at SELECT id, email, role, created_at
@@ -626,9 +816,8 @@ class DatabaseClient:
Returns: Returns:
Updated user dict or None Updated user dict or None
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
UPDATE users UPDATE users
@@ -639,7 +828,7 @@ class DatabaseClient:
(role, user_id), (role, user_id),
) )
user = cursor.fetchone() user = cursor.fetchone()
self.conn.commit() conn.commit()
return dict(user) if user else None return dict(user) if user else None
def delete_user(self, user_id: int) -> bool: def delete_user(self, user_id: int) -> bool:
@@ -651,12 +840,11 @@ class DatabaseClient:
Returns: Returns:
True if deleted True if deleted
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor() as cursor:
with self.conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,)) cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
deleted = cursor.rowcount > 0 deleted = cursor.rowcount > 0
self.conn.commit() conn.commit()
return deleted return deleted
def get_user_count(self) -> int: def get_user_count(self) -> int:
@@ -665,8 +853,85 @@ class DatabaseClient:
Returns: Returns:
Number of users Number of users
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor() as cursor:
with self.conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users") cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0] return cursor.fetchone()[0]
# Tracked Companies Methods
def add_tracked_company(self, company_name: str) -> Optional[Dict]:
"""Add a company to the tracking list."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
try:
cursor.execute(
"INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
(company_name,),
)
row = cursor.fetchone()
conn.commit()
return dict(row) if row else None
except Exception:
conn.rollback()
return None
def remove_tracked_company(self, company_name: str) -> bool:
"""Remove a company from the tracking list."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
(company_name,),
)
conn.commit()
return cursor.rowcount > 0
def list_tracked_companies(self) -> List[Dict]:
"""List all tracked companies."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name")
return [dict(row) for row in cursor.fetchall()]
def update_tracked_company(
self, company_name: str, patent_count: int
) -> None:
"""Update the last analysis stats for a tracked company."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""UPDATE tracked_companies
SET last_patent_count = %s, last_analysis_at = CURRENT_TIMESTAMP
WHERE LOWER(company_name) = LOWER(%s)""",
(patent_count, company_name),
)
conn.commit()
def store_alert(
self,
company_name: str,
alert_type: str,
message: str,
old_value: float | None = None,
new_value: float | None = None,
) -> None:
"""Record an alert for a significant change."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""INSERT INTO alerts (company_name, alert_type, message, old_value, new_value)
VALUES (%s, %s, %s, %s, %s)""",
(company_name, alert_type, message, old_value, new_value),
)
conn.commit()
def list_alerts(self, limit: int = 50) -> List[Dict]:
"""List recent alerts."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM alerts ORDER BY created_at DESC LIMIT %s",
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
+26 -19
View File
@@ -1,9 +1,14 @@
"""LLM integration for patent analysis using OpenRouter.""" """LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict
from openai import OpenAI from openai import OpenAI
from SPARC import config from SPARC import config
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from typing import Dict
logger = logging.getLogger(__name__)
class LLMAnalyzer: class LLMAnalyzer:
@@ -20,7 +25,7 @@ class LLMAnalyzer:
""" """
self.test_mode = test_mode self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = "anthropic/claude-3.5-sonnet" self.model = config.model
# Always initialize database client for storage and caching # Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url) self.db_client = DatabaseClient(config.database_url)
@@ -35,12 +40,13 @@ class LLMAnalyzer:
else: else:
self.client = None self.client = None
def analyze_patent_content(self, patent_content: str, company_name: str) -> str: def analyze_patent_content(self, patent_content: str, company_name: str, model: str | None = None) -> str:
"""Analyze patent content to estimate company innovation and performance. """Analyze patent content to estimate company innovation and performance.
Args: Args:
patent_content: Minimized patent text (abstract, claims, summary) patent_content: Minimized patent text (abstract, claims, summary)
company_name: Name of the company for context company_name: Name of the company for context
model: Optional model override (e.g. "openai/gpt-4o"). Defaults to config.
Returns: Returns:
Analysis text describing innovation quality and potential impact Analysis text describing innovation quality and potential impact
@@ -58,12 +64,10 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage.""" Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
effective_model = model or self.model
if self.test_mode: if self.test_mode:
print("=" * 80) logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
return "[TEST MODE - No API call made]" return "[TEST MODE - No API call made]"
# Check cache first # Check cache first
@@ -80,7 +84,7 @@ Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals
response=cached["response"], response=cached["response"],
company_name=company_name, company_name=company_name,
analysis_type="single_patent", analysis_type="single_patent",
model=self.model, model=effective_model,
metadata={ metadata={
"patent_content_length": len(patent_content), "patent_content_length": len(patent_content),
"cache_hit": True, "cache_hit": True,
@@ -93,7 +97,7 @@ Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals
# Call API if no cache hit and client is available # Call API if no cache hit and client is available
if self.client: if self.client:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=effective_model,
max_tokens=1024, max_tokens=1024,
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
) )
@@ -105,7 +109,7 @@ Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals
response=response_text, response=response_text,
company_name=company_name, company_name=company_name,
analysis_type="single_patent", analysis_type="single_patent",
model=self.model, model=effective_model,
metadata={"patent_content_length": len(patent_content)}, metadata={"patent_content_length": len(patent_content)},
token_usage={ token_usage={
"prompt_tokens": response.usage.prompt_tokens, "prompt_tokens": response.usage.prompt_tokens,
@@ -123,13 +127,13 @@ Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals
response=placeholder, response=placeholder,
company_name=company_name, company_name=company_name,
analysis_type="single_patent", analysis_type="single_patent",
model=self.model, model=effective_model,
metadata={"patent_content_length": len(patent_content), "pending": True} metadata={"patent_content_length": len(patent_content), "pending": True}
) )
return placeholder return placeholder
def analyze_patent_portfolio( def analyze_patent_portfolio(
self, patents_data: list[Dict[str, str]], company_name: str self, patents_data: list[Dict[str, str]], company_name: str, model: str | None = None
) -> str: ) -> str:
"""Analyze multiple patents to estimate overall company performance. """Analyze multiple patents to estimate overall company performance.
@@ -164,13 +168,16 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook.""" Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
effective_model = model or self.model
if self.test_mode: if self.test_mode:
print(prompt) logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
return "[TEST MODE]" return "[TEST MODE]"
metadata = { metadata = {
"patent_count": len(patents_data), "patent_count": len(patents_data),
"patent_ids": [p['patent_id'] for p in patents_data] "patent_ids": [p['patent_id'] for p in patents_data],
"model": effective_model,
} }
# Check cache first # Check cache first
@@ -187,7 +194,7 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
response=cached["response"], response=cached["response"],
company_name=company_name, company_name=company_name,
analysis_type="portfolio", analysis_type="portfolio",
model=self.model, model=effective_model,
metadata={ metadata={
**metadata, **metadata,
"cache_hit": True, "cache_hit": True,
@@ -201,7 +208,7 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
if self.client: if self.client:
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=effective_model,
max_tokens=2048, max_tokens=2048,
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
) )
@@ -214,7 +221,7 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
response=response_text, response=response_text,
company_name=company_name, company_name=company_name,
analysis_type="portfolio", analysis_type="portfolio",
model=self.model, model=effective_model,
metadata=metadata, metadata=metadata,
token_usage={ token_usage={
"prompt_tokens": response.usage.prompt_tokens, "prompt_tokens": response.usage.prompt_tokens,
@@ -234,7 +241,7 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
response=placeholder, response=placeholder,
company_name=company_name, company_name=company_name,
analysis_type="portfolio", analysis_type="portfolio",
model=self.model, model=effective_model,
metadata={**metadata, "pending": True} metadata={**metadata, "pending": True}
) )
return placeholder return placeholder
+109
View File
@@ -0,0 +1,109 @@
"""Scheduled patent analysis for tracked companies.
Uses APScheduler to periodically re-analyze tracked companies and
detect significant changes in patent counts.
"""
import logging
import os
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.database import DatabaseClient
logger = logging.getLogger(__name__)
# Configurable via environment variable (in hours, default 24)
SCHEDULE_INTERVAL_HOURS = int(os.getenv("SCHEDULE_INTERVAL_HOURS", "24"))
# Patent count change threshold (percentage) to trigger an alert
CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20"))
def run_scheduled_analysis() -> None:
"""Re-analyze all tracked companies and check for significant changes."""
db = DatabaseClient(config.database_url)
db.connect()
db.initialize_schema()
tracked = db.list_tracked_companies()
if not tracked:
logger.info("No tracked companies configured; skipping scheduled analysis")
return
logger.info("Running scheduled analysis for %d tracked companies", len(tracked))
analyzer = CompanyAnalyzer(db_client=db)
for company_row in tracked:
name = company_row["company_name"]
old_count = company_row.get("last_patent_count", 0) or 0
try:
result = analyzer._analyze_company_safe(name)
if result.success:
new_count = result.patent_count
# Update tracking record
db.update_tracked_company(name, new_count)
# Check for significant change
if old_count > 0:
delta_pct = abs(new_count - old_count) / old_count * 100
if delta_pct >= CHANGE_THRESHOLD_PERCENT:
direction = "increased" if new_count > old_count else "decreased"
message = (
f"Patent count for {name} {direction} by {delta_pct:.0f}% "
f"({old_count} -> {new_count})"
)
logger.warning("ALERT: %s", message)
db.store_alert(
company_name=name,
alert_type="patent_count_change",
message=message,
old_value=old_count,
new_value=new_count,
)
elif new_count > 0:
# First analysis -- record baseline
logger.info("Baseline for %s: %d patents", name, new_count)
else:
logger.warning("Scheduled analysis failed for %s: %s", name, result.error)
except Exception as e:
logger.error("Error analyzing tracked company %s: %s", name, e)
db.close()
logger.info("Scheduled analysis complete")
def start_scheduler() -> None:
"""Start the APScheduler background scheduler.
Safe to call at application startup. If apscheduler is not installed,
the function logs a warning and returns without starting anything.
"""
try:
from apscheduler.schedulers.background import BackgroundScheduler
except ImportError:
logger.warning(
"apscheduler not installed; scheduled analysis disabled. "
"Install with: pip install apscheduler"
)
return
scheduler = BackgroundScheduler()
scheduler.add_job(
run_scheduled_analysis,
"interval",
hours=SCHEDULE_INTERVAL_HOURS,
id="scheduled_patent_analysis",
replace_existing=True,
)
scheduler.start()
logger.info(
"Scheduled patent analysis started (every %d hours, threshold %d%%)",
SCHEDULE_INTERVAL_HOURS,
CHANGE_THRESHOLD_PERCENT,
)
+55 -18
View File
@@ -1,12 +1,29 @@
import os import io
import serpapi import logging
from SPARC import config
import re import re
import pdfplumber # pip install pdfplumber
import requests
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict from typing import Dict
from SPARC.types import Patents, Patent
import pdfplumber # pip install pdfplumber
import requests
import serpapi
from SPARC import config
from SPARC.storage import StorageBackend, get_storage_backend
from SPARC.types import Patent, Patents
logger = logging.getLogger(__name__)
# Module-level storage instance (lazy-initialized)
_storage: StorageBackend | None = None
def _get_storage() -> StorageBackend:
global _storage
if _storage is None:
_storage = get_storage_backend()
return _storage
class SERP: class SERP:
def query(company: str, days_back: int = None) -> Patents: def query(company: str, days_back: int = None) -> Patents:
@@ -41,6 +58,7 @@ class SERP:
"tbs": date_filter, "tbs": date_filter,
"api_key": config.api_key, "api_key": config.api_key,
} }
logger.info("Querying Google Patents for '%s' (last %d days)", company, days_back)
search = serpapi.search(params) search = serpapi.search(params)
# Convert results to Patent objects, skipping any without PDF links # Convert results to Patent objects, skipping any without PDF links
patent_ids = [] patent_ids = []
@@ -49,13 +67,16 @@ class SERP:
pdf_link = patent.get("pdf") pdf_link = patent.get("pdf")
if pdf_link: if pdf_link:
patent_ids.append(Patent(patent_id=patent["publication_number"], pdf_link=pdf_link, summary=None)) patent_ids.append(Patent(patent_id=patent["publication_number"], pdf_link=pdf_link, summary=None))
# Patents without PDF links are skipped (see docstring for details) else:
logger.debug("Skipping patent %s (no PDF link)", patent.get("publication_number", "unknown"))
logger.info("Found %d patents with PDF links for '%s'", len(patent_ids), company)
return Patents(patents=patent_ids) return Patents(patents=patent_ids)
def save_patents(patent: Patent) -> Patent: def save_patents(patent: Patent) -> Patent:
""" """Save the patent PDF to storage, skipping download if already cached.
Save the patent PDF to the patents folder, skipping download if already cached.
Uses the configured storage backend (local filesystem or S3).
Args: Args:
patent: Patent object patent: Patent object
@@ -63,35 +84,51 @@ class SERP:
Returns: Returns:
Patent object with updated PDF path Patent object with updated PDF path
""" """
pdf_path = f"patents/{patent.patent_id}.pdf" storage = _get_storage()
os.makedirs("patents", exist_ok=True) key = f"{patent.patent_id}.pdf"
if not (os.path.exists(pdf_path) and os.path.getsize(pdf_path) > 0): if not storage.exists(key):
logger.info("Downloading PDF for %s", patent.patent_id)
response = requests.get(patent.pdf_link) response = requests.get(patent.pdf_link)
with open(pdf_path, "wb") as f: storage.write(key, response.content)
f.write(response.content) logger.debug("Saved %d bytes for %s", len(response.content), patent.patent_id)
else:
logger.debug("Using cached PDF for %s", patent.patent_id)
patent.pdf_path = pdf_path patent.pdf_path = storage.path_for(key)
return patent return patent
def parse_patent_pdf(pdf_path: str) -> Dict: def parse_patent_pdf(pdf_path: str) -> Dict:
"""Extract structured sections from patent PDF. """Extract structured sections from patent PDF.
Extracts all major sections from a patent PDF including abstract, Extracts all major sections from a patent PDF including abstract,
claims, summary, and detailed description. claims, summary, and detailed description. Supports both local file
paths and S3 URIs (s3://bucket/key).
Args: Args:
pdf_path: Path to the patent PDF file pdf_path: Local path or S3 URI to the patent PDF file
Returns: Returns:
Dictionary containing all extracted sections Dictionary containing all extracted sections
""" """
logger.debug("Parsing patent PDF: %s", pdf_path)
with pdfplumber.open(pdf_path) as pdf: if pdf_path.startswith("s3://"):
# Read from S3 via storage backend
storage = _get_storage()
# Extract key from "s3://bucket/key"
key = pdf_path.split("/", 3)[-1]
data = storage.read(key)
pdf_file: io.BytesIO | str = io.BytesIO(data)
else:
pdf_file = pdf_path
with pdfplumber.open(pdf_file) as pdf:
# Extract all text # Extract all text
full_text = "" full_text = ""
for page in pdf.pages: for page in pdf.pages:
full_text += page.extract_text() + "\n" full_text += page.extract_text() + "\n"
logger.debug("Extracted text from %d pages (%d chars)", len(pdf.pages), len(full_text))
# Define section patterns (common in patents) # Define section patterns (common in patents)
sections = { sections = {
+171
View File
@@ -0,0 +1,171 @@
"""Patent PDF storage abstraction.
Provides a unified interface for reading and writing patent PDF files,
with pluggable backends for local filesystem and S3-compatible object
storage (e.g., MinIO, AWS S3).
"""
import logging
import os
from abc import ABC, abstractmethod
from SPARC import config
logger = logging.getLogger(__name__)
class StorageBackend(ABC):
"""Abstract base class for patent PDF storage."""
@abstractmethod
def read(self, key: str) -> bytes:
"""Read a file by key.
Args:
key: Storage key (e.g., "US-12345678-B2.pdf")
Returns:
File contents as bytes.
Raises:
FileNotFoundError: If the file does not exist.
"""
@abstractmethod
def write(self, key: str, data: bytes) -> None:
"""Write data to storage.
Args:
key: Storage key (e.g., "US-12345678-B2.pdf")
data: File contents as bytes.
"""
@abstractmethod
def exists(self, key: str) -> bool:
"""Check if a file exists in storage.
Args:
key: Storage key.
Returns:
True if the file exists and has non-zero size.
"""
@abstractmethod
def path_for(self, key: str) -> str:
"""Return a path or URI suitable for downstream consumers.
For local storage this is a filesystem path; for S3 it is the
object key (callers that need a local file should use read()
and write to a temporary location).
"""
class LocalStorageBackend(StorageBackend):
"""Store patent PDFs on the local filesystem under a directory."""
def __init__(self, base_dir: str = "patents"):
self.base_dir = base_dir
os.makedirs(self.base_dir, exist_ok=True)
def _full_path(self, key: str) -> str:
return os.path.join(self.base_dir, key)
def read(self, key: str) -> bytes:
path = self._full_path(key)
if not os.path.exists(path):
raise FileNotFoundError(f"File not found: {path}")
with open(path, "rb") as f:
return f.read()
def write(self, key: str, data: bytes) -> None:
path = self._full_path(key)
os.makedirs(os.path.dirname(path) or self.base_dir, exist_ok=True)
with open(path, "wb") as f:
f.write(data)
logger.debug("Wrote %d bytes to %s", len(data), path)
def exists(self, key: str) -> bool:
path = self._full_path(key)
return os.path.exists(path) and os.path.getsize(path) > 0
def path_for(self, key: str) -> str:
return self._full_path(key)
class S3StorageBackend(StorageBackend):
"""Store patent PDFs in an S3-compatible bucket."""
def __init__(
self,
bucket: str,
endpoint_url: str = "",
access_key: str = "",
secret_key: str = "",
):
import boto3
kwargs: dict = {}
if endpoint_url:
kwargs["endpoint_url"] = endpoint_url
if access_key and secret_key:
kwargs["aws_access_key_id"] = access_key
kwargs["aws_secret_access_key"] = secret_key
self.s3 = boto3.client("s3", **kwargs)
self.bucket = bucket
# Ensure bucket exists (useful for MinIO local dev)
try:
self.s3.head_bucket(Bucket=self.bucket)
except Exception:
try:
self.s3.create_bucket(Bucket=self.bucket)
logger.info("Created S3 bucket: %s", self.bucket)
except Exception as e:
logger.warning("Could not create bucket %s: %s", self.bucket, e)
def read(self, key: str) -> bytes:
try:
response = self.s3.get_object(Bucket=self.bucket, Key=key)
return response["Body"].read()
except self.s3.exceptions.NoSuchKey:
raise FileNotFoundError(f"S3 object not found: s3://{self.bucket}/{key}")
except Exception as e:
if "NoSuchKey" in str(e) or "404" in str(e):
raise FileNotFoundError(f"S3 object not found: s3://{self.bucket}/{key}")
raise
def write(self, key: str, data: bytes) -> None:
self.s3.put_object(
Bucket=self.bucket,
Key=key,
Body=data,
ContentType="application/pdf",
)
logger.debug("Wrote %d bytes to s3://%s/%s", len(data), self.bucket, key)
def exists(self, key: str) -> bool:
try:
response = self.s3.head_object(Bucket=self.bucket, Key=key)
return response["ContentLength"] > 0
except Exception:
return False
def path_for(self, key: str) -> str:
return f"s3://{self.bucket}/{key}"
def get_storage_backend() -> StorageBackend:
"""Factory: return the configured storage backend instance."""
backend = config.storage_backend.lower()
if backend == "s3":
logger.info("Using S3 storage backend (bucket=%s)", config.s3_bucket)
return S3StorageBackend(
bucket=config.s3_bucket,
endpoint_url=config.s3_endpoint_url,
access_key=config.s3_access_key,
secret_key=config.s3_secret_key,
)
logger.info("Using local storage backend")
return LocalStorageBackend()
+2 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass @dataclass
class Patent: class Patent:
patent_id: int patent_id: str
pdf_link: str pdf_link: str
pdf_path: str | None = None pdf_path: str | None = None
summary: dict | None = None summary: dict | None = None
@@ -24,6 +24,7 @@ class CompanyAnalysisResult:
patent_count: int patent_count: int
success: bool success: bool
error: str | None = None error: str | None = None
model: str | None = None
timestamp: datetime = field(default_factory=datetime.now) timestamp: datetime = field(default_factory=datetime.now)
+139
View File
@@ -0,0 +1,139 @@
"""Webhook notifications for job completion and alert events.
Sends JSON payloads to configured webhook URLs with retry logic.
Supports generic HTTP POST and Slack-compatible text payloads.
"""
import logging
import os
import time
from datetime import datetime
from typing import Any
import requests
logger = logging.getLogger(__name__)
# Comma-separated list of webhook URLs (env var based config)
_WEBHOOK_URLS_RAW = os.getenv("WEBHOOK_URLS", "")
WEBHOOK_URLS: list[str] = [
url.strip() for url in _WEBHOOK_URLS_RAW.split(",") if url.strip()
]
MAX_RETRIES = 3
BACKOFF_BASE = 2 # seconds
def _is_slack_url(url: str) -> bool:
"""Check if a URL looks like a Slack incoming webhook."""
return "hooks.slack.com" in url or "discord.com/api/webhooks" in url
def _build_payload(event_type: str, data: dict[str, Any], slack: bool = False) -> dict:
"""Build the webhook payload.
Args:
event_type: Type of event (e.g., "job_completed", "alert")
data: Event-specific data
slack: If True, wrap in Slack-compatible ``text`` format
Returns:
JSON-serializable payload dict
"""
payload = {
"event": event_type,
"timestamp": datetime.utcnow().isoformat() + "Z",
**data,
}
if slack:
# Build a human-readable summary for Slack/Discord
lines = [f"*[SPARC] {event_type}*"]
for key, value in data.items():
lines.append(f" {key}: {value}")
return {"text": "\n".join(lines)}
return payload
def _send_with_retry(url: str, payload: dict) -> bool:
"""Send a POST request with exponential backoff retry.
Args:
url: Webhook URL
payload: JSON payload to send
Returns:
True if delivered successfully, False after all retries exhausted
"""
for attempt in range(1, MAX_RETRIES + 1):
try:
response = requests.post(url, json=payload, timeout=10)
if response.status_code < 300:
logger.debug("Webhook delivered to %s (attempt %d)", url, attempt)
return True
logger.warning(
"Webhook %s returned %d (attempt %d/%d)",
url, response.status_code, attempt, MAX_RETRIES,
)
except requests.RequestException as e:
logger.warning(
"Webhook delivery failed for %s (attempt %d/%d): %s",
url, attempt, MAX_RETRIES, e,
)
if attempt < MAX_RETRIES:
wait = BACKOFF_BASE ** attempt
time.sleep(wait)
logger.error("Webhook permanently failed for %s after %d attempts", url, MAX_RETRIES)
return False
def notify(event_type: str, data: dict[str, Any]) -> None:
"""Fire all configured webhooks for an event.
Safe to call even when no webhooks are configured (returns immediately).
Args:
event_type: Event identifier (e.g., "job_completed", "patent_alert")
data: Event data to include in the payload
"""
if not WEBHOOK_URLS:
return
for url in WEBHOOK_URLS:
slack = _is_slack_url(url)
payload = _build_payload(event_type, data, slack=slack)
_send_with_retry(url, payload)
def notify_job_completed(
job_id: str,
status: str,
total_companies: int,
successful: int,
failed: int,
) -> None:
"""Send notification when a batch job completes."""
notify("job_completed", {
"job_id": job_id,
"status": status,
"total_companies": total_companies,
"successful": successful,
"failed": failed,
"summary": f"Batch job {job_id}: {successful}/{total_companies} succeeded",
})
def notify_alert(
company_name: str,
alert_type: str,
message: str,
) -> None:
"""Send notification for a tracked company alert."""
notify("patent_alert", {
"company_name": company_name,
"alert_type": alert_type,
"message": message,
})
+34 -7
View File
@@ -3,15 +3,15 @@ services:
image: postgres:16-alpine image: postgres:16-alpine
container_name: sparc-postgres container_name: sparc-postgres
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: sparc POSTGRES_DB: ${POSTGRES_DB}
ports: ports:
- "5432:5432" - "5432:5432"
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"] test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"]
interval: 5s interval: 5s
timeout: 5s timeout: 5s
retries: 5 retries: 5
@@ -22,7 +22,7 @@ services:
container_name: sparc-init-db container_name: sparc-init-db
command: python scripts/init_database.py command: python scripts/init_database.py
environment: environment:
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy
@@ -35,9 +35,11 @@ services:
environment: environment:
API_KEY: ${API_KEY} API_KEY: ${API_KEY}
OPENROUTER_API_KEY: ${OPENROUTER_API_KEY} OPENROUTER_API_KEY: ${OPENROUTER_API_KEY}
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
USE_CACHE: "true" USE_CACHE: "true"
JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production} JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production}
CORS_ORIGINS: ${CORS_ORIGINS:-}
APP_ENV: ${APP_ENV:-development}
ROOT_PATH: /api ROOT_PATH: /api
ports: ports:
- "8000:8000" - "8000:8000"
@@ -47,9 +49,32 @@ services:
init-db: init-db:
condition: service_completed_successfully condition: service_completed_successfully
volumes: volumes:
- ./patents:/app/patents - patent_data:/app/patents
restart: unless-stopped restart: unless-stopped
# Optional: MinIO for S3-compatible local object storage
# Enable by setting STORAGE_BACKEND=s3 in .env
minio:
image: minio/minio:latest
container_name: sparc-minio
command: server /data --console-address ":9001"
environment:
MINIO_ROOT_USER: ${AWS_ACCESS_KEY_ID:-minioadmin}
MINIO_ROOT_PASSWORD: ${AWS_SECRET_ACCESS_KEY:-minioadmin}
ports:
- "9000:9000"
- "9001:9001"
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 10s
timeout: 5s
retries: 3
restart: unless-stopped
profiles:
- s3
dashboard: dashboard:
build: ./frontend build: ./frontend
container_name: sparc-dashboard container_name: sparc-dashboard
@@ -61,3 +86,5 @@ services:
volumes: volumes:
postgres_data: postgres_data:
patent_data:
minio_data:
+76 -1
View File
@@ -276,7 +276,7 @@ The `docker-compose.yml` includes all services needed for production:
|---------|-----------|------|-------------| |---------|-----------|------|-------------|
| `postgres` | sparc-postgres | 5432 | PostgreSQL database | | `postgres` | sparc-postgres | 5432 | PostgreSQL database |
| `init-db` | sparc-init-db | - | One-time database initialization (seeds admin user) | | `init-db` | sparc-init-db | - | One-time database initialization (seeds admin user) |
| `api` | sparc-api | 8000 | FastAPI REST API with JWT auth | | `api` | sparc-api | 8000 | FastAPI REST API with JWT auth (patent PDFs stored in `patent_data` volume) |
| `dashboard` | sparc-dashboard | 8080 | React TypeScript web UI | | `dashboard` | sparc-dashboard | 8080 | React TypeScript web UI |
### Common Docker Compose Commands ### Common Docker Compose Commands
@@ -307,6 +307,81 @@ docker-compose restart api
--- ---
## Patent PDF Storage
The SPARC API downloads patent PDFs during analysis and stores them at `/app/patents` inside the container. These files are used for subsequent single-patent analysis requests and as a local cache to avoid re-downloading. If this directory is not persisted, all downloaded PDFs are lost when the container is recreated.
### Docker Compose (default)
The default `docker-compose.yml` declares a named volume called `patent_data` that is mounted at `/app/patents`:
```yaml
# In the api service:
volumes:
- patent_data:/app/patents
# At the top-level volumes section:
volumes:
patent_data:
```
This means PDFs survive `docker compose down` and `docker compose up` cycles. To remove patent data intentionally, run:
```bash
docker compose down -v # WARNING: also removes postgres_data
# or selectively:
docker volume rm sparc_patent_data
```
If you prefer a bind mount (e.g., for easy host-side access during development), replace the volume with:
```yaml
volumes:
- ./patents:/app/patents
```
### Kubernetes
For Kubernetes deployments, create a PersistentVolumeClaim and mount it into the API pod:
```yaml
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: sparc-patent-data
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 5Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: sparc-api
spec:
template:
spec:
containers:
- name: api
volumeMounts:
- name: patent-data
mountPath: /app/patents
volumes:
- name: patent-data
persistentVolumeClaim:
claimName: sparc-patent-data
```
Adjust the storage size based on expected patent volume. Each patent PDF is typically 1-5 MB.
### S3 Object Storage (alternative)
For production deployments that need shared or highly durable storage, set `STORAGE_BACKEND=s3` in your `.env` file. This stores patent PDFs in an S3-compatible bucket (AWS S3 or MinIO) instead of the local filesystem, eliminating the need for a persistent volume. See the S3/MinIO section in `.env.example` for configuration details.
---
## Troubleshooting ## Troubleshooting
### Database Connection Issues ### Database Connection Issues
+9
View File
@@ -7,6 +7,15 @@
<title>SPARC Dashboard</title> <title>SPARC Dashboard</title>
</head> </head>
<body> <body>
<script>
// Prevent FOUC: apply saved theme before first render
(function() {
var theme = localStorage.getItem('theme');
if (theme === 'dark' || (!theme && window.matchMedia('(prefers-color-scheme: dark)').matches)) {
document.documentElement.classList.add('dark');
}
})();
</script>
<div id="root"></div> <div id="root"></div>
<script type="module" src="/src/main.tsx"></script> <script type="module" src="/src/main.tsx"></script>
</body> </body>
+4728
View File
File diff suppressed because it is too large Load Diff
+5 -1
View File
@@ -7,12 +7,15 @@
"dev": "vite", "dev": "vite",
"build": "tsc -b && vite build", "build": "tsc -b && vite build",
"lint": "eslint .", "lint": "eslint .",
"generate": "openapi-typescript http://localhost:8000/api/openapi.json -o src/api/schema.d.ts",
"generate:local": "openapi-typescript src/api/openapi.json -o src/api/schema.d.ts",
"typecheck": "tsc --noEmit",
"preview": "vite preview" "preview": "vite preview"
}, },
"dependencies": { "dependencies": {
"@tanstack/react-query": "^5.51.0", "@tanstack/react-query": "^5.51.0",
"axios": "^1.7.2", "axios": "^1.7.2",
"lucide-react": "^0.400.0", "lucide-react": "^1.7.0",
"react": "^18.3.1", "react": "^18.3.1",
"react-dom": "^18.3.1", "react-dom": "^18.3.1",
"react-router-dom": "^6.24.0", "react-router-dom": "^6.24.0",
@@ -30,6 +33,7 @@
"globals": "^15.8.0", "globals": "^15.8.0",
"postcss": "^8.4.39", "postcss": "^8.4.39",
"tailwindcss": "^3.4.4", "tailwindcss": "^3.4.4",
"openapi-typescript": "^7.0.0",
"typescript": "~5.5.3", "typescript": "~5.5.3",
"typescript-eslint": "^8.0.0", "typescript-eslint": "^8.0.0",
"vite": "^5.3.3" "vite": "^5.3.3"
+5
View File
@@ -1,6 +1,7 @@
import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom'; import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { AuthProvider } from './context/AuthContext'; import { AuthProvider } from './context/AuthContext';
import { ThemeProvider } from './context/ThemeContext';
import { Layout } from './components/Layout'; import { Layout } from './components/Layout';
import { ProtectedRoute } from './components/ProtectedRoute'; import { ProtectedRoute } from './components/ProtectedRoute';
import { Login } from './pages/Login'; import { Login } from './pages/Login';
@@ -10,6 +11,7 @@ import { Batch } from './pages/Batch';
import { AnalyticsPage } from './pages/Analytics'; import { AnalyticsPage } from './pages/Analytics';
import { About } from './pages/About'; import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers'; import { AdminUsers } from './pages/AdminUsers';
import { Compare } from './pages/Compare';
const queryClient = new QueryClient({ const queryClient = new QueryClient({
defaultOptions: { defaultOptions: {
@@ -22,6 +24,7 @@ const queryClient = new QueryClient({
function App() { function App() {
return ( return (
<ThemeProvider>
<QueryClientProvider client={queryClient}> <QueryClientProvider client={queryClient}>
<AuthProvider> <AuthProvider>
<BrowserRouter> <BrowserRouter>
@@ -41,6 +44,7 @@ function App() {
<Route path="/analysis" element={<Analysis />} /> <Route path="/analysis" element={<Analysis />} />
<Route path="/batch" element={<Batch />} /> <Route path="/batch" element={<Batch />} />
<Route path="/analytics" element={<AnalyticsPage />} /> <Route path="/analytics" element={<AnalyticsPage />} />
<Route path="/compare" element={<Compare />} />
<Route path="/about" element={<About />} /> <Route path="/about" element={<About />} />
{/* Admin routes */} {/* Admin routes */}
@@ -61,6 +65,7 @@ function App() {
</BrowserRouter> </BrowserRouter>
</AuthProvider> </AuthProvider>
</QueryClientProvider> </QueryClientProvider>
</ThemeProvider>
); );
} }
+71 -4
View File
@@ -89,29 +89,53 @@ export const authApi = {
}, },
}; };
// Model types
export interface ModelInfo {
id: string;
name: string;
provider: string;
}
export interface ModelsResponse {
models: ModelInfo[];
default: string;
}
// Analysis API // Analysis API
export const analysisApi = { export const analysisApi = {
analyzeCompany: async (companyName: string): Promise<CompanyAnalysis> => { analyzeCompany: async (companyName: string, model?: string): Promise<CompanyAnalysis> => {
const response = await api.get<CompanyAnalysis>(`/analyze/${encodeURIComponent(companyName)}`); const params = new URLSearchParams();
if (model) params.append('model', model);
const qs = params.toString();
const response = await api.get<CompanyAnalysis>(
`/analyze/${encodeURIComponent(companyName)}${qs ? `?${qs}` : ''}`
);
return response.data; return response.data;
}, },
analyzeBatch: async (companies: string[], maxWorkers = 3): Promise<BatchAnalysisResult> => { analyzeBatch: async (companies: string[], maxWorkers = 3, model?: string): Promise<BatchAnalysisResult> => {
const response = await api.post<BatchAnalysisResult>('/analyze/batch', { const response = await api.post<BatchAnalysisResult>('/analyze/batch', {
companies, companies,
max_workers: maxWorkers, max_workers: maxWorkers,
...(model ? { model } : {}),
}); });
return response.data; return response.data;
}, },
analyzeBatchAsync: async (companies: string[], maxWorkers = 3): Promise<JobStatus> => { analyzeBatchAsync: async (companies: string[], maxWorkers = 3, model?: string): Promise<JobStatus> => {
const response = await api.post<JobStatus>('/analyze/batch/async', { const response = await api.post<JobStatus>('/analyze/batch/async', {
companies, companies,
max_workers: maxWorkers, max_workers: maxWorkers,
...(model ? { model } : {}),
}); });
return response.data; return response.data;
}, },
listModels: async (): Promise<ModelsResponse> => {
const response = await api.get<ModelsResponse>('/models');
return response.data;
},
getJobStatus: async (jobId: string): Promise<JobStatus> => { getJobStatus: async (jobId: string): Promise<JobStatus> => {
const response = await api.get<JobStatus>(`/jobs/${jobId}`); const response = await api.get<JobStatus>(`/jobs/${jobId}`);
return response.data; return response.data;
@@ -126,12 +150,55 @@ export const analysisApi = {
}, },
}; };
// Export API
export const exportApi = {
exportCsv: async (companyName: string): Promise<void> => {
const response = await api.get(`/export/${encodeURIComponent(companyName)}`, {
responseType: 'blob',
});
const url = window.URL.createObjectURL(new Blob([response.data]));
const link = document.createElement('a');
link.href = url;
link.setAttribute('download', `sparc_${companyName.toLowerCase().replace(/\s+/g, '_')}_export.csv`);
document.body.appendChild(link);
link.click();
link.remove();
window.URL.revokeObjectURL(url);
},
exportPdf: async (companyName: string): Promise<void> => {
const response = await api.get(`/export/${encodeURIComponent(companyName)}/pdf`, {
responseType: 'blob',
});
const safeName = companyName.toLowerCase().replace(/\s+/g, '_');
const date = new Date().toISOString().split('T')[0];
const url = window.URL.createObjectURL(new Blob([response.data], { type: 'application/pdf' }));
const link = document.createElement('a');
link.href = url;
link.setAttribute('download', `${safeName}-analysis-${date}.pdf`);
document.body.appendChild(link);
link.click();
link.remove();
window.URL.revokeObjectURL(url);
},
};
// Analytics API // Analytics API
export interface TrendData {
by_month: Array<{ month: string; company_name: string; count: number }>;
by_type_over_time: Array<{ month: string; analysis_type: string; count: number }>;
period_days: number;
}
export const analyticsApi = { export const analyticsApi = {
getAnalytics: async (days = 30): Promise<Analytics> => { getAnalytics: async (days = 30): Promise<Analytics> => {
const response = await api.get<Analytics>(`/analytics?days=${days}`); const response = await api.get<Analytics>(`/analytics?days=${days}`);
return response.data; return response.data;
}, },
getTrends: async (days = 90): Promise<TrendData> => {
const response = await api.get<TrendData>(`/analytics/trends?days=${days}`);
return response.data;
},
}; };
// Admin API // Admin API
File diff suppressed because it is too large Load Diff
+975
View File
@@ -0,0 +1,975 @@
/**
* This file was auto-generated by openapi-typescript.
* Do not make direct changes to the file.
*/
export interface paths {
"/auth/register": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Register
* @description Register a new user.
*
* The first registered user automatically becomes an admin.
*/
post: operations["register_auth_register_post"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/auth/login": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Login
* @description Authenticate user and return JWT tokens.
*/
post: operations["login_auth_login_post"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/auth/refresh": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Refresh Token
* @description Refresh access token using refresh token.
*/
post: operations["refresh_token_auth_refresh_post"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/auth/me": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Me
* @description Get current authenticated user.
*/
get: operations["get_me_auth_me_get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/admin/users": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* List Users
* @description List all users (admin only).
*/
get: operations["list_users_admin_users_get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/admin/users/{user_id}/role": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
/**
* Update User Role
* @description Update a user's role (admin only).
*/
patch: operations["update_user_role_admin_users__user_id__role_patch"];
trace?: never;
};
"/admin/users/{user_id}": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post?: never;
/**
* Delete User
* @description Delete a user (admin only).
*/
delete: operations["delete_user_admin_users__user_id__delete"];
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/analytics": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Analytics
* @description Get analytics data (authenticated users only).
*/
get: operations["get_analytics_analytics_get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/health": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Health Check
* @description Check API health status.
*/
get: operations["health_check_health_get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/analyze/{company_name}": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Analyze Company
* @description Analyze a single company's patent portfolio.
*
* This endpoint retrieves recent patents for the specified company,
* parses them, and uses AI to generate a comprehensive analysis.
*
* Args:
* company_name: Name of the company to analyze (e.g., "nvidia", "intel")
*
* Returns:
* Analysis results including patent count, AI insights, and success status
*/
get: operations["analyze_company_analyze__company_name__get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/analyze/batch": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Analyze Companies Batch
* @description Analyze multiple companies' patent portfolios.
*
* Processes companies concurrently for improved performance.
* Limited to 20 companies per request.
*
* Args:
* request: List of company names and optional worker count
*
* Returns:
* Batch results with individual company analyses and summary statistics
*/
post: operations["analyze_companies_batch_analyze_batch_post"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/analyze/batch/async": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Analyze Companies Async
* @description Start an asynchronous batch analysis job.
*
* Returns immediately with a job ID that can be used to poll for status.
* Useful for large batch analyses that may take a long time.
*
* Args:
* request: List of company names and optional worker count
*
* Returns:
* Job status with job_id for polling
*/
post: operations["analyze_companies_async_analyze_batch_async_post"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/jobs/{job_id}": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Job Status
* @description Get the status of a background analysis job.
*
* Args:
* job_id: The job ID returned from the async batch endpoint
*
* Returns:
* Current job status including progress and results when complete
*/
get: operations["get_job_status_jobs__job_id__get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/jobs": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* List Jobs
* @description List all analysis jobs.
*
* Args:
* status: Optional filter by job status
* limit: Maximum number of jobs to return (default 10, max 100)
*
* Returns:
* List of job statuses
*/
get: operations["list_jobs_jobs_get"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
}
export type webhooks = Record<string, never>;
export interface components {
schemas: {
/**
* AnalyticsResponse
* @description Analytics response model.
*/
AnalyticsResponse: {
/** Total Messages */
total_messages: number;
/** By Company */
by_company: {
[key: string]: unknown;
}[];
/** By Type */
by_type: {
[key: string]: unknown;
}[];
/** Period Days */
period_days: number;
};
/**
* BatchAnalysisRequest
* @description Request model for batch company analysis.
*/
BatchAnalysisRequest: {
/**
* Companies
* @description List of company names to analyze
*/
companies: string[];
/**
* Max Workers
* @description Max concurrent analyses
* @default 3
*/
max_workers: number;
};
/**
* BatchAnalysisResponse
* @description Response model for batch company analysis.
*/
BatchAnalysisResponse: {
/** Results */
results: components["schemas"]["CompanyAnalysisResponse"][];
/** Total Companies */
total_companies: number;
/** Successful */
successful: number;
/** Failed */
failed: number;
/**
* Timestamp
* Format: date-time
*/
timestamp: string;
};
/**
* CompanyAnalysisResponse
* @description Response model for single company analysis.
*/
CompanyAnalysisResponse: {
/** Company Name */
company_name: string;
/** Analysis */
analysis: string;
/** Patent Count */
patent_count: number;
/** Success */
success: boolean;
/** Error */
error?: string | null;
/**
* Timestamp
* Format: date-time
*/
timestamp: string;
};
/** HTTPValidationError */
HTTPValidationError: {
/** Detail */
detail?: components["schemas"]["ValidationError"][];
};
/**
* HealthResponse
* @description Health check response.
*/
HealthResponse: {
/** Status */
status: string;
/** Version */
version: string;
/**
* Timestamp
* Format: date-time
*/
timestamp: string;
};
/**
* JobStatus
* @description Status of a background analysis job.
*/
JobStatus: {
/** Job Id */
job_id: string;
/** Status */
status: string;
/** Progress */
progress: number;
/** Total Companies */
total_companies: number;
/** Completed Companies */
completed_companies: number;
result?: components["schemas"]["BatchAnalysisResponse"] | null;
/** Error */
error?: string | null;
};
/**
* LoginRequest
* @description User login request.
*/
LoginRequest: {
/**
* Email
* Format: email
*/
email: string;
/** Password */
password: string;
};
/**
* RefreshRequest
* @description Token refresh request.
*/
RefreshRequest: {
/** Refresh Token */
refresh_token: string;
};
/**
* RegisterRequest
* @description User registration request.
*/
RegisterRequest: {
/**
* Email
* Format: email
*/
email: string;
/**
* Password
* @description Password (min 8 characters)
*/
password: string;
};
/**
* TokenResponse
* @description Token response model.
*/
TokenResponse: {
/** Access Token */
access_token: string;
/** Refresh Token */
refresh_token: string;
/**
* Token Type
* @default bearer
*/
token_type: string;
};
/**
* UpdateRoleRequest
* @description Update user role request.
*/
UpdateRoleRequest: {
/** Role */
role: string;
};
/**
* UserResponse
* @description User response model.
*/
UserResponse: {
/** Id */
id: number;
/** Email */
email: string;
/** Role */
role: string;
/**
* Created At
* Format: date-time
*/
created_at: string;
};
/** ValidationError */
ValidationError: {
/** Location */
loc: (string | number)[];
/** Message */
msg: string;
/** Error Type */
type: string;
/** Input */
input?: unknown;
/** Context */
ctx?: Record<string, never>;
};
};
responses: never;
parameters: never;
requestBodies: never;
headers: never;
pathItems: never;
}
export type $defs = Record<string, never>;
export interface operations {
register_auth_register_post: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["RegisterRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["UserResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
login_auth_login_post: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["LoginRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["TokenResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
refresh_token_auth_refresh_post: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["RefreshRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["TokenResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_me_auth_me_get: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["UserResponse"];
};
};
};
};
list_users_admin_users_get: {
parameters: {
query?: {
limit?: number;
offset?: number;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["UserResponse"][];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
update_user_role_admin_users__user_id__role_patch: {
parameters: {
query?: never;
header?: never;
path: {
user_id: number;
};
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["UpdateRoleRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["UserResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
delete_user_admin_users__user_id__delete: {
parameters: {
query?: never;
header?: never;
path: {
user_id: number;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_analytics_analytics_get: {
parameters: {
query?: {
days?: number;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["AnalyticsResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
health_check_health_get: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HealthResponse"];
};
};
};
};
analyze_company_analyze__company_name__get: {
parameters: {
query?: never;
header?: never;
path: {
company_name: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["CompanyAnalysisResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
analyze_companies_batch_analyze_batch_post: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["BatchAnalysisRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["BatchAnalysisResponse"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
analyze_companies_async_analyze_batch_async_post: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["BatchAnalysisRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JobStatus"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_job_status_jobs__job_id__get: {
parameters: {
query?: never;
header?: never;
path: {
job_id: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JobStatus"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
list_jobs_jobs_get: {
parameters: {
query?: {
/** @description Filter by status: pending, running, completed, failed */
status?: string | null;
limit?: number;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JobStatus"][];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
}
+12 -2
View File
@@ -1,9 +1,11 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom'; import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext'; import { useAuth } from '../context/AuthContext';
import { Search, Layers, BarChart3, Info, Users, LogOut } from 'lucide-react'; import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react';
export function Layout() { export function Layout() {
const { user, isAdmin, logout } = useAuth(); const { user, isAdmin, logout } = useAuth();
const { theme, toggleTheme } = useTheme();
const navigate = useNavigate(); const navigate = useNavigate();
const handleLogout = () => { const handleLogout = () => {
@@ -15,6 +17,7 @@ export function Layout() {
{ to: '/analysis', icon: Search, label: 'Analysis' }, { to: '/analysis', icon: Search, label: 'Analysis' },
{ to: '/batch', icon: Layers, label: 'Batch' }, { to: '/batch', icon: Layers, label: 'Batch' },
{ to: '/analytics', icon: BarChart3, label: 'Analytics' }, { to: '/analytics', icon: BarChart3, label: 'Analytics' },
{ to: '/compare', icon: GitCompareArrows, label: 'Compare' },
{ to: '/about', icon: Info, label: 'About' }, { to: '/about', icon: Info, label: 'About' },
]; ];
@@ -23,7 +26,7 @@ export function Layout() {
} }
return ( return (
<div className="min-h-screen bg-gradient-to-br from-bg-dark to-indigo-950"> <div className="min-h-screen bg-gradient-to-br from-bg-dark to-slate-100 dark:to-indigo-950">
{/* Header */} {/* Header */}
<header className="bg-bg-card/80 backdrop-blur-lg border-b border-primary/20"> <header className="bg-bg-card/80 backdrop-blur-lg border-b border-primary/20">
<div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8"> <div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8">
@@ -63,6 +66,13 @@ export function Layout() {
{/* User menu */} {/* User menu */}
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
<button
onClick={toggleTheme}
className="p-2 rounded-lg text-text-secondary hover:text-text-primary hover:bg-bg-card-hover transition-all"
aria-label={theme === 'dark' ? 'Switch to light mode' : 'Switch to dark mode'}
>
{theme === 'dark' ? <Sun size={18} /> : <Moon size={18} />}
</button>
<div className="text-right hidden sm:block"> <div className="text-right hidden sm:block">
<div className="text-sm font-medium text-text-primary">{user?.email}</div> <div className="text-sm font-medium text-text-primary">{user?.email}</div>
<div className="text-xs text-text-secondary capitalize">{user?.role}</div> <div className="text-xs text-text-secondary capitalize">{user?.role}</div>
+1 -1
View File
@@ -12,7 +12,7 @@ export function ProtectedRoute({ children, requireAdmin = false }: ProtectedRout
if (isLoading) { if (isLoading) {
return ( return (
<div className="min-h-screen bg-gradient-to-br from-bg-dark to-indigo-950 flex items-center justify-center"> <div className="min-h-screen bg-gradient-to-br from-bg-dark to-slate-100 dark:to-indigo-950 flex items-center justify-center">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div> <div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div>
</div> </div>
); );
+48
View File
@@ -0,0 +1,48 @@
import { createContext, useContext, useEffect, useState } from 'react';
type Theme = 'light' | 'dark';
interface ThemeContextType {
theme: Theme;
toggleTheme: () => void;
}
const ThemeContext = createContext<ThemeContextType | undefined>(undefined);
function getInitialTheme(): Theme {
const stored = localStorage.getItem('theme');
if (stored === 'light' || stored === 'dark') return stored;
return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
}
export function ThemeProvider({ children }: { children: React.ReactNode }) {
const [theme, setTheme] = useState<Theme>(getInitialTheme);
useEffect(() => {
const root = document.documentElement;
if (theme === 'dark') {
root.classList.add('dark');
} else {
root.classList.remove('dark');
}
localStorage.setItem('theme', theme);
}, [theme]);
const toggleTheme = () => {
setTheme((prev) => (prev === 'dark' ? 'light' : 'dark'));
};
return (
<ThemeContext.Provider value={{ theme, toggleTheme }}>
{children}
</ThemeContext.Provider>
);
}
export function useTheme() {
const context = useContext(ThemeContext);
if (!context) {
throw new Error('useTheme must be used within a ThemeProvider');
}
return context;
}
+41
View File
@@ -0,0 +1,41 @@
import { useTheme } from './ThemeContext';
/**
* Returns theme-aware color values for recharts components.
*
* Recharts accepts only raw color strings (not CSS variables),
* so this hook bridges the Tailwind/CSS-variable theme system
* to the imperative recharts API.
*/
export function useChartTheme() {
const { theme } = useTheme();
const isDark = theme === 'dark';
return {
/** Axis tick and grid line stroke color */
axisStroke: isDark ? '#94a3b8' : '#64748b',
/** Tooltip container background */
tooltipBg: isDark ? '#1e293b' : '#ffffff',
/** Tooltip container border */
tooltipBorder: isDark
? '1px solid rgba(99, 102, 241, 0.3)'
: '1px solid rgba(99, 102, 241, 0.2)',
/** Tooltip label text color */
tooltipLabelColor: isDark ? '#f8fafc' : '#0f172a',
/** Tooltip item text color */
tooltipItemColor: isDark ? '#e2e8f0' : '#334155',
/** Convenience: full contentStyle object for recharts Tooltip */
tooltipContentStyle: {
backgroundColor: isDark ? '#1e293b' : '#ffffff',
border: isDark
? '1px solid rgba(99, 102, 241, 0.3)'
: '1px solid rgba(99, 102, 241, 0.2)',
borderRadius: '8px',
color: isDark ? '#f8fafc' : '#0f172a',
},
/** Convenience: labelStyle for recharts Tooltip */
tooltipLabelStyle: {
color: isDark ? '#f8fafc' : '#0f172a',
},
};
}
+22 -2
View File
@@ -2,6 +2,26 @@
@tailwind components; @tailwind components;
@tailwind utilities; @tailwind utilities;
/* Light mode (default) */
:root {
--color-bg-dark: #f1f5f9;
--color-bg-card: #ffffff;
--color-bg-card-hover: #e2e8f0;
--color-text-primary: #0f172a;
--color-text-secondary: #475569;
--color-border: #cbd5e1;
}
/* Dark mode */
.dark {
--color-bg-dark: #0f172a;
--color-bg-card: #1e293b;
--color-bg-card-hover: #334155;
--color-text-primary: #f8fafc;
--color-text-secondary: #94a3b8;
--color-border: #334155;
}
body { body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
-webkit-font-smoothing: antialiased; -webkit-font-smoothing: antialiased;
@@ -15,7 +35,7 @@ body {
} }
::-webkit-scrollbar-track { ::-webkit-scrollbar-track {
background: #1e293b; background: var(--color-bg-card);
} }
::-webkit-scrollbar-thumb { ::-webkit-scrollbar-thumb {
@@ -30,5 +50,5 @@ body {
/* Selection */ /* Selection */
::selection { ::selection {
background: rgba(99, 102, 241, 0.3); background: rgba(99, 102, 241, 0.3);
color: #f8fafc; color: var(--color-text-primary);
} }
+56 -6
View File
@@ -1,15 +1,21 @@
import { useState } from 'react'; import { useState } from 'react';
import { useMutation } from '@tanstack/react-query'; import { useMutation, useQuery } from '@tanstack/react-query';
import { analysisApi } from '../api/client'; import { analysisApi, exportApi } from '../api/client';
import { Search, CheckCircle, AlertCircle, Clock, FileText } from 'lucide-react'; import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown } from 'lucide-react';
import type { CompanyAnalysis } from '../types'; import type { CompanyAnalysis } from '../types';
export function Analysis() { export function Analysis() {
const [companyName, setCompanyName] = useState(''); const [companyName, setCompanyName] = useState('');
const [selectedModel, setSelectedModel] = useState('');
const [result, setResult] = useState<CompanyAnalysis | null>(null); const [result, setResult] = useState<CompanyAnalysis | null>(null);
const modelsQuery = useQuery({
queryKey: ['models'],
queryFn: () => analysisApi.listModels(),
});
const mutation = useMutation({ const mutation = useMutation({
mutationFn: (name: string) => analysisApi.analyzeCompany(name), mutationFn: (name: string) => analysisApi.analyzeCompany(name, selectedModel || undefined),
onSuccess: (data) => setResult(data), onSuccess: (data) => setResult(data),
}); });
@@ -33,7 +39,8 @@ export function Analysis() {
</div> </div>
{/* Search Form */} {/* Search Form */}
<form onSubmit={handleSubmit} className="flex gap-4"> <form onSubmit={handleSubmit} className="space-y-4">
<div className="flex gap-4">
<div className="flex-1 relative"> <div className="flex-1 relative">
<Search className="absolute left-4 top-1/2 -translate-y-1/2 text-text-secondary" size={18} /> <Search className="absolute left-4 top-1/2 -translate-y-1/2 text-text-secondary" size={18} />
<input <input
@@ -58,6 +65,31 @@ export function Analysis() {
</> </>
)} )}
</button> </button>
</div>
{/* Model Selector */}
<div className="flex items-center gap-3">
<label className="text-sm font-medium text-text-secondary whitespace-nowrap">
LLM Model
</label>
<div className="relative flex-1 max-w-xs">
<select
value={selectedModel}
onChange={(e) => setSelectedModel(e.target.value)}
className="w-full appearance-none bg-bg-card/80 border border-primary/30 rounded-lg pl-3 pr-8 py-2 text-sm text-text-primary focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all cursor-pointer"
>
<option value="">
{modelsQuery.data ? `Default (${modelsQuery.data.default})` : 'Default'}
</option>
{modelsQuery.data?.models.map((m) => (
<option key={m.id} value={m.id}>
{m.name} ({m.provider})
</option>
))}
</select>
<ChevronDown className="absolute right-2 top-1/2 -translate-y-1/2 text-text-secondary pointer-events-none" size={16} />
</div>
</div>
</form> </form>
{/* Error */} {/* Error */}
@@ -106,9 +138,27 @@ export function Analysis() {
{/* Analysis Content */} {/* Analysis Content */}
{result.success && result.analysis && ( {result.success && result.analysis && (
<div className="bg-bg-card/60 backdrop-blur-lg border border-primary/15 rounded-2xl p-6"> <div className="bg-bg-card/60 backdrop-blur-lg border border-primary/15 rounded-2xl p-6">
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-4"> <div className="flex items-center justify-between border-b-2 border-primary/30 pb-2 mb-4">
<h3 className="text-lg font-semibold text-text-primary">
AI Analysis Results AI Analysis Results
</h3> </h3>
<div className="flex items-center gap-2">
<button
onClick={() => exportApi.exportCsv(result.company_name)}
className="flex items-center gap-2 text-sm bg-primary/20 hover:bg-primary/30 text-primary font-medium px-3 py-1.5 rounded-lg transition-colors"
>
<Download size={14} />
Export CSV
</button>
<button
onClick={() => exportApi.exportPdf(result.company_name)}
className="flex items-center gap-2 text-sm bg-primary/20 hover:bg-primary/30 text-primary font-medium px-3 py-1.5 rounded-lg transition-colors"
>
<FileText size={14} />
Export PDF
</button>
</div>
</div>
<div className="prose prose-invert max-w-none"> <div className="prose prose-invert max-w-none">
<div className="text-text-primary whitespace-pre-wrap leading-relaxed"> <div className="text-text-primary whitespace-pre-wrap leading-relaxed">
{result.analysis} {result.analysis}
+148 -23
View File
@@ -2,22 +2,52 @@ import { useState } from 'react';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { analyticsApi } from '../api/client'; import { analyticsApi } from '../api/client';
import { AlertCircle, Database } from 'lucide-react'; import { AlertCircle, Database } from 'lucide-react';
import { PieChart, Pie, Cell, BarChart, Bar, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts'; import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts';
import { useChartTheme } from '../context/useChartTheme';
const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6']; const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6'];
export function AnalyticsPage() { export function AnalyticsPage() {
const [days, setDays] = useState(30); const [days, setDays] = useState(30);
const chartTheme = useChartTheme();
const { data, isLoading, isError } = useQuery({ const { data, isLoading, isError, refetch } = useQuery({
queryKey: ['analytics', days], queryKey: ['analytics', days],
queryFn: () => analyticsApi.getAnalytics(days), queryFn: () => analyticsApi.getAnalytics(days),
}); });
const trendsQuery = useQuery({
queryKey: ['analytics-trends', days],
queryFn: () => analyticsApi.getTrends(days),
});
if (isLoading) { if (isLoading) {
return ( return (
<div className="flex items-center justify-center min-h-[400px]"> <div className="space-y-6">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div> <div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Analytics Dashboard
</h2>
<p className="text-text-secondary">Loading analytics data...</p>
</div>
{/* Skeleton cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
{[1, 2, 3].map((i) => (
<div key={i} className="bg-gradient-to-br from-primary/10 to-secondary/10 border border-primary/20 rounded-xl p-5 text-center animate-pulse">
<div className="h-9 w-16 bg-primary/20 rounded mx-auto mb-2" />
<div className="h-4 w-24 bg-primary/10 rounded mx-auto" />
</div>
))}
</div>
{/* Skeleton charts */}
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{[1, 2].map((i) => (
<div key={i} className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6 animate-pulse">
<div className="h-5 w-40 bg-primary/20 rounded mb-4" />
<div className="h-[300px] bg-primary/5 rounded" />
</div>
))}
</div>
</div> </div>
); );
} }
@@ -33,15 +63,18 @@ export function AnalyticsPage() {
<div className="bg-gradient-to-br from-primary/10 to-secondary/5 border border-primary/20 rounded-xl p-6"> <div className="bg-gradient-to-br from-primary/10 to-secondary/5 border border-primary/20 rounded-xl p-6">
<div className="flex items-center gap-3 text-warning mb-2"> <div className="flex items-center gap-3 text-warning mb-2">
<Database size={24} /> <Database size={24} />
<span className="font-semibold">Database Not Connected</span> <span className="font-semibold">Unable to Load Analytics</span>
</div> </div>
<p className="text-text-secondary"> <p className="text-text-secondary">
Set <code className="bg-bg-card px-2 py-1 rounded">USE_DATABASE=true</code> in your .env file to enable analytics tracking. Could not connect to the analytics database. Ensure PostgreSQL is running and
<code className="bg-bg-card px-2 py-1 rounded mx-1">DATABASE_URL</code> is configured correctly.
</p> </p>
</div> <button
<div className="flex items-center gap-2 bg-secondary/10 border border-secondary/20 text-secondary rounded-xl px-4 py-3"> onClick={() => refetch()}
<AlertCircle size={18} /> className="mt-3 text-sm bg-primary/20 hover:bg-primary/30 text-primary font-medium px-4 py-2 rounded-lg transition-colors"
<span>Analytics features require storing analysis results in PostgreSQL for historical tracking.</span> >
Retry
</button>
</div> </div>
</div> </div>
); );
@@ -129,11 +162,7 @@ export function AnalyticsPage() {
))} ))}
</Pie> </Pie>
<Tooltip <Tooltip
contentStyle={{ contentStyle={chartTheme.tooltipContentStyle}
backgroundColor: '#1e293b',
border: '1px solid rgba(99, 102, 241, 0.3)',
borderRadius: '8px',
}}
/> />
<Legend /> <Legend />
</PieChart> </PieChart>
@@ -147,15 +176,11 @@ export function AnalyticsPage() {
<h3 className="text-lg font-semibold text-text-primary mb-4">Analysis Types</h3> <h3 className="text-lg font-semibold text-text-primary mb-4">Analysis Types</h3>
<ResponsiveContainer width="100%" height={300}> <ResponsiveContainer width="100%" height={300}>
<BarChart data={typeData}> <BarChart data={typeData}>
<XAxis dataKey="name" stroke="#94a3b8" fontSize={12} /> <XAxis dataKey="name" stroke={chartTheme.axisStroke} fontSize={12} />
<YAxis stroke="#94a3b8" fontSize={12} /> <YAxis stroke={chartTheme.axisStroke} fontSize={12} />
<Tooltip <Tooltip
contentStyle={{ contentStyle={chartTheme.tooltipContentStyle}
backgroundColor: '#1e293b', labelStyle={chartTheme.tooltipLabelStyle}
border: '1px solid rgba(99, 102, 241, 0.3)',
borderRadius: '8px',
}}
labelStyle={{ color: '#f8fafc' }}
/> />
<Bar dataKey="count" fill="#6366f1" radius={[4, 4, 0, 0]} /> <Bar dataKey="count" fill="#6366f1" radius={[4, 4, 0, 0]} />
</BarChart> </BarChart>
@@ -163,6 +188,106 @@ export function AnalyticsPage() {
</div> </div>
)} )}
</div> </div>
{/* Trend Charts */}
{trendsQuery.data && (
<div className="space-y-6">
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2">
Trends Over Time
</h3>
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{/* Patent count over time per company (line chart) */}
{trendsQuery.data.by_month.length > 0 && (() => {
// Pivot data: each month as a row, companies as columns
const companies = [...new Set(trendsQuery.data!.by_month.map(d => d.company_name))];
const months = [...new Set(trendsQuery.data!.by_month.map(d => d.month))].sort();
const pivoted = months.map(month => {
const row: Record<string, string | number> = { month };
for (const c of companies) {
const entry = trendsQuery.data!.by_month.find(d => d.month === month && d.company_name === c);
row[c] = entry?.count || 0;
}
return row;
});
return (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6">
<h4 className="text-md font-semibold text-text-primary mb-4">Analyses per Company Over Time</h4>
<ResponsiveContainer width="100%" height={300}>
<LineChart data={pivoted}>
<XAxis dataKey="month" stroke={chartTheme.axisStroke} fontSize={12} />
<YAxis stroke={chartTheme.axisStroke} fontSize={12} />
<Tooltip
contentStyle={chartTheme.tooltipContentStyle}
labelStyle={chartTheme.tooltipLabelStyle}
/>
<Legend />
{companies.map((company, idx) => (
<Line
key={company}
type="monotone"
dataKey={company}
stroke={COLORS[idx % COLORS.length]}
strokeWidth={2}
dot={{ r: 4 }}
name={company.toUpperCase()}
/>
))}
</LineChart>
</ResponsiveContainer>
</div>
);
})()}
{/* Analysis type distribution over time (stacked bar) */}
{trendsQuery.data.by_type_over_time.length > 0 && (() => {
const types = [...new Set(trendsQuery.data!.by_type_over_time.map(d => d.analysis_type))];
const months = [...new Set(trendsQuery.data!.by_type_over_time.map(d => d.month))].sort();
const pivoted = months.map(month => {
const row: Record<string, string | number> = { month };
for (const t of types) {
const entry = trendsQuery.data!.by_type_over_time.find(d => d.month === month && d.analysis_type === t);
row[t] = entry?.count || 0;
}
return row;
});
return (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6">
<h4 className="text-md font-semibold text-text-primary mb-4">Analysis Types Over Time</h4>
<ResponsiveContainer width="100%" height={300}>
<BarChart data={pivoted}>
<XAxis dataKey="month" stroke={chartTheme.axisStroke} fontSize={12} />
<YAxis stroke={chartTheme.axisStroke} fontSize={12} />
<Tooltip
contentStyle={chartTheme.tooltipContentStyle}
labelStyle={chartTheme.tooltipLabelStyle}
/>
<Legend />
{types.map((type, idx) => (
<Bar
key={type}
dataKey={type}
stackId="types"
fill={COLORS[idx % COLORS.length]}
name={type}
/>
))}
</BarChart>
</ResponsiveContainer>
</div>
);
})()}
</div>
{trendsQuery.data.by_month.length === 0 && (
<div className="text-text-secondary text-center py-8">
No trend data available yet. Run analyses over multiple days to see trends.
</div>
)}
</div>
)}
</div> </div>
); );
} }
+196 -14
View File
@@ -1,20 +1,37 @@
import { useState } from 'react'; import { useState } from 'react';
import { useMutation } from '@tanstack/react-query'; import { useMutation, useQuery } from '@tanstack/react-query';
import { analysisApi } from '../api/client'; import { analysisApi } from '../api/client';
import { Rocket, CheckCircle, AlertCircle, ChevronDown, ChevronUp } from 'lucide-react'; import { Rocket, CheckCircle, AlertCircle, ChevronDown, ChevronUp, RefreshCw, Inbox } from 'lucide-react';
import { BarChart, Bar, XAxis, YAxis, Tooltip, ResponsiveContainer, Cell } from 'recharts'; import { BarChart, Bar, XAxis, YAxis, Tooltip, ResponsiveContainer, Cell } from 'recharts';
import { useChartTheme } from '../context/useChartTheme';
import type { BatchAnalysisResult } from '../types'; import type { BatchAnalysisResult } from '../types';
export function Batch() { export function Batch() {
const [companiesInput, setCompaniesInput] = useState(''); const [companiesInput, setCompaniesInput] = useState('');
const [maxWorkers, setMaxWorkers] = useState(3); const [maxWorkers, setMaxWorkers] = useState(3);
const [selectedModel, setSelectedModel] = useState('');
const [result, setResult] = useState<BatchAnalysisResult | null>(null); const [result, setResult] = useState<BatchAnalysisResult | null>(null);
const [expandedItems, setExpandedItems] = useState<Set<string>>(new Set()); const [expandedItems, setExpandedItems] = useState<Set<string>>(new Set());
const chartTheme = useChartTheme();
const modelsQuery = useQuery({
queryKey: ['models'],
queryFn: () => analysisApi.listModels(),
});
const jobsQuery = useQuery({
queryKey: ['jobs'],
queryFn: () => analysisApi.listJobs(undefined, 20),
});
const mutation = useMutation({ const mutation = useMutation({
mutationFn: ({ companies, workers }: { companies: string[]; workers: number }) => mutationFn: ({ companies, workers }: { companies: string[]; workers: number }) =>
analysisApi.analyzeBatch(companies, workers), analysisApi.analyzeBatch(companies, workers, selectedModel || undefined),
onSuccess: (data) => setResult(data), onSuccess: (data) => {
setResult(data);
jobsQuery.refetch();
},
}); });
const handleSubmit = (e: React.FormEvent) => { const handleSubmit = (e: React.FormEvent) => {
@@ -85,6 +102,29 @@ export function Batch() {
<div className="text-center text-text-primary font-semibold">{maxWorkers}</div> <div className="text-center text-text-primary font-semibold">{maxWorkers}</div>
</div> </div>
<div>
<label className="block text-sm font-medium text-text-secondary mb-2">
LLM Model
</label>
<div className="relative">
<select
value={selectedModel}
onChange={(e) => setSelectedModel(e.target.value)}
className="w-full appearance-none bg-bg-card/80 border border-primary/30 rounded-lg pl-3 pr-8 py-2 text-sm text-text-primary focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all cursor-pointer"
>
<option value="">
{modelsQuery.data ? `Default (${modelsQuery.data.default})` : 'Default'}
</option>
{modelsQuery.data?.models.map((m) => (
<option key={m.id} value={m.id}>
{m.name} ({m.provider})
</option>
))}
</select>
<ChevronDown className="absolute right-2 top-1/2 -translate-y-1/2 text-text-secondary pointer-events-none" size={16} />
</div>
</div>
<button <button
type="submit" type="submit"
disabled={mutation.isPending || !companiesInput.trim()} disabled={mutation.isPending || !companiesInput.trim()}
@@ -114,9 +154,38 @@ export function Batch() {
{/* Error */} {/* Error */}
{mutation.isError && ( {mutation.isError && (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3"> <div className="bg-error/10 border border-error/20 rounded-xl px-4 py-3">
<div className="flex items-center gap-2 text-error">
<AlertCircle size={18} /> <AlertCircle size={18} />
<span>Batch analysis failed. Please try again.</span> <span className="font-semibold">Batch analysis failed</span>
</div>
<p className="text-text-secondary text-sm mt-1 ml-7">
{mutation.error instanceof Error ? mutation.error.message : 'An unexpected error occurred.'}
{' '}Check your connection and try again.
</p>
<div className="ml-7 mt-2 flex items-center gap-3">
<button
onClick={() => {
const companies = companiesInput
.split(/[,\n]/)
.map((c) => c.trim())
.filter((c) => c.length > 0);
if (companies.length > 0) {
mutation.mutate({ companies, workers: maxWorkers });
}
}}
className="text-sm text-primary hover:text-primary-dark underline flex items-center gap-1"
>
<RefreshCw size={14} />
Retry
</button>
<button
onClick={() => mutation.reset()}
className="text-sm text-text-secondary hover:text-text-primary underline"
>
Dismiss
</button>
</div>
</div> </div>
)} )}
@@ -144,15 +213,11 @@ export function Batch() {
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6"> <div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6">
<ResponsiveContainer width="100%" height={300}> <ResponsiveContainer width="100%" height={300}>
<BarChart data={chartData}> <BarChart data={chartData}>
<XAxis dataKey="name" stroke="#94a3b8" fontSize={12} /> <XAxis dataKey="name" stroke={chartTheme.axisStroke} fontSize={12} />
<YAxis stroke="#94a3b8" fontSize={12} /> <YAxis stroke={chartTheme.axisStroke} fontSize={12} />
<Tooltip <Tooltip
contentStyle={{ contentStyle={chartTheme.tooltipContentStyle}
backgroundColor: '#1e293b', labelStyle={chartTheme.tooltipLabelStyle}
border: '1px solid rgba(99, 102, 241, 0.3)',
borderRadius: '8px',
}}
labelStyle={{ color: '#f8fafc' }}
/> />
<Bar dataKey="patents" radius={[4, 4, 0, 0]}> <Bar dataKey="patents" radius={[4, 4, 0, 0]}>
{chartData.map((entry, index) => ( {chartData.map((entry, index) => (
@@ -218,6 +283,123 @@ export function Batch() {
</div> </div>
</div> </div>
)} )}
{/* Job History */}
<div>
<h3 className="text-lg font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-4">
Job History
</h3>
{/* Loading skeleton */}
{jobsQuery.isLoading && (
<div className="space-y-3">
{[...Array(3)].map((_, i) => (
<div
key={i}
className="bg-bg-card/60 border border-primary/15 rounded-xl p-4 animate-pulse"
>
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<div className="h-5 w-5 rounded-full bg-primary/20" />
<div className="h-4 w-32 rounded bg-primary/20" />
<div className="h-4 w-20 rounded bg-primary/10" />
</div>
<div className="h-6 w-20 rounded-full bg-primary/15" />
</div>
<div className="mt-3 flex gap-4">
<div className="h-3 w-24 rounded bg-primary/10" />
<div className="h-3 w-16 rounded bg-primary/10" />
</div>
</div>
))}
</div>
)}
{/* Job history error */}
{jobsQuery.isError && (
<div className="bg-error/10 border border-error/20 rounded-xl px-4 py-3">
<div className="flex items-center gap-2 text-error">
<AlertCircle size={18} />
<span className="font-semibold">Failed to load job history</span>
</div>
<p className="text-text-secondary text-sm mt-1 ml-7">
{jobsQuery.error instanceof Error ? jobsQuery.error.message : 'Could not retrieve past jobs.'}
</p>
<button
onClick={() => jobsQuery.refetch()}
className="ml-7 mt-2 text-sm text-primary hover:text-primary-dark underline flex items-center gap-1"
>
<RefreshCw size={14} />
Retry
</button>
</div>
)}
{/* Empty state */}
{jobsQuery.isSuccess && jobsQuery.data.length === 0 && !result && (
<div className="bg-bg-card/60 border border-primary/15 border-dashed rounded-xl p-8 text-center">
<Inbox className="mx-auto text-text-secondary/40 mb-3" size={40} />
<p className="text-text-secondary font-medium">No batch jobs yet</p>
<p className="text-text-secondary/70 text-sm mt-1">
Submit a batch analysis above to get started. Your job history will appear here.
</p>
</div>
)}
{/* Job list */}
{jobsQuery.isSuccess && jobsQuery.data.length > 0 && (
<div className="space-y-3">
{jobsQuery.data.map((job) => (
<div
key={job.job_id}
className="bg-bg-card/60 border border-primary/15 rounded-xl p-4"
>
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
{job.status === 'completed' && <CheckCircle className="text-success" size={18} />}
{job.status === 'failed' && <AlertCircle className="text-error" size={18} />}
{(job.status === 'pending' || job.status === 'running') && (
<div className="animate-spin rounded-full h-[18px] w-[18px] border-t-2 border-b-2 border-secondary" />
)}
<span className="font-mono text-sm text-text-primary">{job.job_id.slice(0, 8)}</span>
<span className="text-text-secondary text-sm">
{job.total_companies} {job.total_companies === 1 ? 'company' : 'companies'}
</span>
</div>
<span
className={`text-xs font-semibold px-2.5 py-1 rounded-full ${
job.status === 'completed'
? 'bg-success/15 text-success'
: job.status === 'failed'
? 'bg-error/15 text-error'
: 'bg-secondary/15 text-secondary'
}`}
>
{job.status}
</span>
</div>
{(job.status === 'running' || job.status === 'pending') && job.total_companies > 0 && (
<div className="mt-3">
<div className="flex items-center justify-between text-xs text-text-secondary mb-1">
<span>Progress</span>
<span>{job.completed_companies}/{job.total_companies}</span>
</div>
<div className="h-1.5 bg-bg-dark rounded-full overflow-hidden">
<div
className="h-full bg-gradient-to-r from-primary to-secondary rounded-full transition-all duration-300"
style={{ width: `${(job.completed_companies / job.total_companies) * 100}%` }}
/>
</div>
</div>
)}
{job.status === 'failed' && job.error && (
<p className="mt-2 text-sm text-error/80">{job.error}</p>
)}
</div>
))}
</div>
)}
</div>
</div> </div>
); );
} }
+161
View File
@@ -0,0 +1,161 @@
import { useState } from 'react';
import { useSearchParams } from 'react-router-dom';
import { useQuery } from '@tanstack/react-query';
import { analysisApi } from '../api/client';
import { GitCompareArrows, AlertCircle, FileText, Clock } from 'lucide-react';
import type { CompanyAnalysis } from '../types';
function CompanyPanel({ data, isLoading, isError }: { data?: CompanyAnalysis; isLoading: boolean; isError: boolean }) {
if (isLoading) {
return (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6 animate-pulse">
<div className="h-6 w-32 bg-primary/20 rounded mb-4" />
<div className="space-y-3">
<div className="h-4 bg-primary/10 rounded w-full" />
<div className="h-4 bg-primary/10 rounded w-3/4" />
<div className="h-4 bg-primary/10 rounded w-5/6" />
</div>
</div>
);
}
if (isError) {
return (
<div className="bg-error/10 border border-error/20 rounded-2xl p-6">
<div className="flex items-center gap-2 text-error">
<AlertCircle size={18} />
<span>Failed to load analysis. Check the company name and try again.</span>
</div>
</div>
);
}
if (!data) return null;
return (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6 space-y-4">
<h3 className="text-lg font-bold text-text-primary border-b-2 border-primary/30 pb-2">
{data.company_name.toUpperCase()}
</h3>
<div className="grid grid-cols-2 gap-3">
<div className="bg-primary/10 rounded-lg p-3 text-center">
<FileText className="mx-auto mb-1 text-primary" size={18} />
<div className="text-xl font-bold text-text-primary">{data.patent_count}</div>
<div className="text-xs text-text-secondary uppercase">Patents</div>
</div>
<div className="bg-primary/10 rounded-lg p-3 text-center">
<Clock className="mx-auto mb-1 text-primary" size={18} />
<div className="text-sm font-medium text-text-primary">
{new Date(data.timestamp).toLocaleDateString()}
</div>
<div className="text-xs text-text-secondary uppercase">Analyzed</div>
</div>
</div>
{data.success && data.analysis ? (
<div className="text-text-primary whitespace-pre-wrap leading-relaxed text-sm">
{data.analysis}
</div>
) : (
<div className="text-error text-sm">{data.error || 'Analysis not available'}</div>
)}
</div>
);
}
export function Compare() {
const [searchParams, setSearchParams] = useSearchParams();
const [companyA, setCompanyA] = useState(searchParams.get('a') || '');
const [companyB, setCompanyB] = useState(searchParams.get('b') || '');
const queryA = searchParams.get('a') || '';
const queryB = searchParams.get('b') || '';
const resultA = useQuery({
queryKey: ['analyze', queryA],
queryFn: () => analysisApi.analyzeCompany(queryA),
enabled: !!queryA,
});
const resultB = useQuery({
queryKey: ['analyze', queryB],
queryFn: () => analysisApi.analyzeCompany(queryB),
enabled: !!queryB,
});
const handleCompare = (e: React.FormEvent) => {
e.preventDefault();
const a = companyA.trim();
const b = companyB.trim();
if (a && b) {
setSearchParams({ a, b });
}
};
return (
<div className="space-y-6">
{/* Header */}
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Portfolio Comparison
</h2>
<p className="text-text-secondary">
Compare patent portfolios of two companies side by side.
</p>
</div>
{/* Input Form */}
<form onSubmit={handleCompare} className="flex flex-col sm:flex-row gap-3 items-end">
<div className="flex-1">
<label className="block text-sm font-medium text-text-secondary mb-1">Company A</label>
<input
type="text"
value={companyA}
onChange={(e) => setCompanyA(e.target.value)}
placeholder="e.g. nvidia"
className="w-full bg-bg-card/80 border border-primary/30 rounded-xl px-4 py-2.5 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all"
/>
</div>
<div className="flex-1">
<label className="block text-sm font-medium text-text-secondary mb-1">Company B</label>
<input
type="text"
value={companyB}
onChange={(e) => setCompanyB(e.target.value)}
placeholder="e.g. intel"
className="w-full bg-bg-card/80 border border-primary/30 rounded-xl px-4 py-2.5 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all"
/>
</div>
<button
type="submit"
disabled={!companyA.trim() || !companyB.trim() || resultA.isLoading || resultB.isLoading}
className="bg-gradient-to-r from-primary to-primary-dark text-white font-semibold py-2.5 px-6 rounded-xl hover:shadow-lg hover:shadow-primary/30 transition-all disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
<GitCompareArrows size={18} />
Compare
</button>
</form>
{/* Comparison Panels */}
{(queryA || queryB) && (
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{queryA && (
<CompanyPanel
data={resultA.data}
isLoading={resultA.isLoading}
isError={resultA.isError}
/>
)}
{queryB && (
<CompanyPanel
data={resultB.data}
isLoading={resultB.isLoading}
isError={resultB.isError}
/>
)}
</div>
)}
</div>
);
}
+1 -1
View File
@@ -31,7 +31,7 @@ export function Login() {
}; };
return ( return (
<div className="min-h-screen bg-gradient-to-br from-bg-dark to-indigo-950 flex items-center justify-center px-4"> <div className="min-h-screen bg-gradient-to-br from-bg-dark to-slate-100 dark:to-indigo-950 flex items-center justify-center px-4">
<div className="w-full max-w-md"> <div className="w-full max-w-md">
{/* Brand */} {/* Brand */}
<div className="text-center mb-8"> <div className="text-center mb-8">
+1 -1
View File
@@ -40,7 +40,7 @@ export function Register() {
}; };
return ( return (
<div className="min-h-screen bg-gradient-to-br from-bg-dark to-indigo-950 flex items-center justify-center px-4"> <div className="min-h-screen bg-gradient-to-br from-bg-dark to-slate-100 dark:to-indigo-950 flex items-center justify-center px-4">
<div className="w-full max-w-md"> <div className="w-full max-w-md">
{/* Brand */} {/* Brand */}
<div className="text-center mb-8"> <div className="text-center mb-8">
+28 -42
View File
@@ -1,46 +1,32 @@
export interface User { /**
id: number; * Application types derived from the auto-generated OpenAPI schema.
email: string; *
role: 'admin' | 'user'; * Run `npm run generate:local` (or `npm run generate` with the API running)
created_at: string; * to regenerate `src/api/schema.d.ts` from the backend OpenAPI spec.
} *
* These aliases keep the rest of the codebase stable while the source of
* truth lives in the generated file.
*/
export interface TokenResponse { import type { components } from '../api/schema';
access_token: string;
refresh_token: string;
token_type: string;
}
export interface CompanyAnalysis { // Re-export schema types under the names the rest of the app expects.
company_name: string; export type User = components['schemas']['UserResponse'];
analysis: string; export type TokenResponse = components['schemas']['TokenResponse'];
patent_count: number; export type CompanyAnalysis = components['schemas']['CompanyAnalysisResponse'];
success: boolean; export type BatchAnalysisResult = components['schemas']['BatchAnalysisResponse'];
error: string | null; export type JobStatus = components['schemas']['JobStatus'];
timestamp: string; export type Analytics = Omit<components['schemas']['AnalyticsResponse'], 'by_company' | 'by_type'> & {
}
export interface BatchAnalysisResult {
results: CompanyAnalysis[];
total_companies: number;
successful: number;
failed: number;
timestamp: string;
}
export interface JobStatus {
job_id: string;
status: 'pending' | 'running' | 'completed' | 'failed';
progress: number;
total_companies: number;
completed_companies: number;
result: BatchAnalysisResult | null;
error: string | null;
}
export interface Analytics {
total_messages: number;
by_company: Array<{ company_name: string; count: number }>; by_company: Array<{ company_name: string; count: number }>;
by_type: Array<{ analysis_type: string; count: number }>; by_type: Array<{ analysis_type: string; count: number }>;
period_days: number; };
}
// Additional generated types that may be useful elsewhere.
export type RegisterRequest = components['schemas']['RegisterRequest'];
export type LoginRequest = components['schemas']['LoginRequest'];
export type RefreshRequest = components['schemas']['RefreshRequest'];
export type UpdateRoleRequest = components['schemas']['UpdateRoleRequest'];
export type HealthResponse = components['schemas']['HealthResponse'];
export type BatchAnalysisRequest = components['schemas']['BatchAnalysisRequest'];
export type ValidationError = components['schemas']['ValidationError'];
export type HTTPValidationError = components['schemas']['HTTPValidationError'];
+7 -6
View File
@@ -4,6 +4,7 @@ export default {
"./index.html", "./index.html",
"./src/**/*.{js,ts,jsx,tsx}", "./src/**/*.{js,ts,jsx,tsx}",
], ],
darkMode: 'class',
theme: { theme: {
extend: { extend: {
colors: { colors: {
@@ -16,15 +17,15 @@ export default {
warning: '#f59e0b', warning: '#f59e0b',
error: '#ef4444', error: '#ef4444',
bg: { bg: {
dark: '#0f172a', dark: 'var(--color-bg-dark)',
card: '#1e293b', card: 'var(--color-bg-card)',
'card-hover': '#334155', 'card-hover': 'var(--color-bg-card-hover)',
}, },
text: { text: {
primary: '#f8fafc', primary: 'var(--color-text-primary)',
secondary: '#94a3b8', secondary: 'var(--color-text-secondary)',
}, },
border: '#334155', border: 'var(--color-border)',
}, },
}, },
}, },
+4
View File
@@ -14,3 +14,7 @@ numpy
pandas pandas
bcrypt bcrypt
PyJWT PyJWT
slowapi
apscheduler
boto3
reportlab
+8
View File
@@ -0,0 +1,8 @@
[lint]
select = ["E", "F", "I"]
ignore = [
"E501", # line too long (handled by formatter)
]
[lint.per-file-ignores]
"tests/*" = ["E402", "F841"] # allow import not at top of file, unused vars (mocks) in tests
+3
View File
@@ -40,6 +40,9 @@ def main():
print("\nTables created:") print("\nTables created:")
print(" - llm_messages: Stores all LLM prompts and responses") print(" - llm_messages: Stores all LLM prompts and responses")
print(" - users: Stores user accounts") print(" - users: Stores user accounts")
print(" - jobs: Stores async batch job state")
print(" - patents: Patent PDF cache")
print(" - serp_queries: SERP query result cache")
print("\nIndexes created:") print("\nIndexes created:")
print(" - idx_messages_timestamp: For time-based queries") print(" - idx_messages_timestamp: For time-based queries")
print(" - idx_messages_company: For company-specific queries") print(" - idx_messages_company: For company-specific queries")
+5 -3
View File
@@ -1,9 +1,11 @@
"""Tests for the high-level company analyzer orchestration.""" """Tests for the high-level company analyzer orchestration."""
from unittest.mock import MagicMock, Mock
import pytest import pytest
from unittest.mock import Mock, patch, call, MagicMock
from SPARC.analyzer import CompanyAnalyzer from SPARC.analyzer import CompanyAnalyzer
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult from SPARC.types import BatchAnalysisResult, Patent, Patents
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -24,7 +26,7 @@ class TestCompanyAnalyzer:
"""Test analyzer initialization with API key.""" """Test analyzer initialization with API key."""
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer") mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
analyzer = CompanyAnalyzer(openrouter_api_key="test-key") _analyzer = CompanyAnalyzer(openrouter_api_key="test-key") # noqa: F841
mock_llm.assert_called_once_with(api_key="test-key") mock_llm.assert_called_once_with(api_key="test-key")
+49 -4
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints.""" """Tests for FastAPI web service endpoints."""
import pytest
from datetime import datetime from datetime import datetime
from unittest.mock import Mock, patch from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from SPARC.api import app, _analyzer, _jobs from SPARC.api import app
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@pytest.fixture @pytest.fixture
@@ -181,3 +182,47 @@ class TestJobEndpoints:
"""Test listing jobs with status filter.""" """Test listing jobs with status filter."""
response = client.get("/jobs?status=completed") response = client.get("/jobs?status=completed")
assert response.status_code == 200 assert response.status_code == 200
class TestModelValidation:
"""Test that unsupported model identifiers are rejected."""
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with unsupported model returns 400."""
response = client.get("/analyze/nvidia?model=fake/nonexistent-model")
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
def test_analyze_accepts_supported_model(self, client, mock_analyzer):
"""GET /analyze/{company} with a supported model succeeds."""
mock_result = CompanyAnalysisResult(
company_name="nvidia",
analysis="test",
patent_count=1,
success=True,
timestamp=datetime.now(),
model="anthropic/claude-3.5-sonnet",
)
mock_analyzer._analyze_company_safe.return_value = mock_result
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet")
assert response.status_code == 200
def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
"""POST /analyze/batch with unsupported model returns 400."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
)
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
def test_list_models_returns_supported(self, client):
"""GET /models returns the allow-list."""
response = client.get("/models")
assert response.status_code == 200
data = response.json()
assert "models" in data
assert "default" in data
assert len(data["models"]) > 0
assert all("id" in m and "name" in m and "provider" in m for m in data["models"])
+302
View File
@@ -0,0 +1,302 @@
"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token, create_refresh_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db(monkeypatch):
"""Mock the database client used by auth endpoints.
Returns a MagicMock with all DB methods pre-configured.
"""
db = MagicMock()
# Default: no users exist
db.get_user_count.return_value = 0
db.get_user_by_id.return_value = None
db.get_user_by_email.return_value = None
db.authenticate_user.return_value = None
db.create_user.return_value = None
db.get_all_users.return_value = []
db.update_user_role.return_value = None
db.delete_user.return_value = False
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _make_admin_user():
return {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _make_regular_user():
return {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _auth_header(user_dict):
"""Create an Authorization header with a valid access token for the given user."""
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
return {"Authorization": f"Bearer {token}"}
class TestRegister:
"""POST /auth/register"""
def test_register_first_user_becomes_admin(self, client, mock_db):
"""First registered user should get admin role."""
mock_db.get_user_count.return_value = 0
mock_db.create_user.return_value = {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/register",
json={"email": "admin@test.com", "password": "securepass123"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "admin@test.com"
assert data["role"] == "admin"
mock_db.create_user.assert_called_once_with(
email="admin@test.com", password="securepass123", role="admin"
)
def test_register_subsequent_user_gets_user_role(self, client, mock_db):
"""Non-first user should get regular user role."""
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = _make_regular_user()
response = client.post(
"/auth/register",
json={"email": "user@test.com", "password": "securepass123"},
)
assert response.status_code == 200
data = response.json()
assert data["role"] == "user"
def test_register_duplicate_email_returns_400(self, client, mock_db):
"""Registering with an existing email should return 400."""
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # indicates duplicate
response = client.post(
"/auth/register",
json={"email": "existing@test.com", "password": "securepass123"},
)
assert response.status_code == 400
assert "already registered" in response.json()["detail"].lower()
class TestLogin:
"""POST /auth/login"""
def test_login_valid_credentials_returns_tokens(self, client, mock_db):
"""Valid credentials should return access and refresh tokens."""
user = _make_regular_user()
mock_db.authenticate_user.return_value = user
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "correctpassword"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
def test_login_invalid_credentials_returns_401(self, client, mock_db):
"""Invalid credentials should return 401."""
mock_db.authenticate_user.return_value = None
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "wrongpassword"},
)
assert response.status_code == 401
assert "invalid" in response.json()["detail"].lower()
class TestGetMe:
"""GET /auth/me"""
def test_valid_access_token_returns_user(self, client, mock_db):
"""A valid access token should return the user's data."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.get("/auth/me", headers=_auth_header(user))
assert response.status_code == 200
data = response.json()
assert data["email"] == "user@test.com"
assert data["id"] == 2
def test_missing_token_returns_401(self, client):
"""No token should return 401 (403 from HTTPBearer)."""
response = client.get("/auth/me")
assert response.status_code in (401, 403)
def test_expired_token_returns_401(self, client, mock_db):
"""An expired token should return 401."""
# Create a token that has already expired
from datetime import timedelta
import jwt as pyjwt
from SPARC.auth import JWT_ALGORITHM, JWT_SECRET
payload = {
"sub": "1",
"email": "user@test.com",
"role": "user",
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
"type": "access",
}
expired_token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
response = client.get(
"/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
)
assert response.status_code == 401
def test_refresh_token_as_access_returns_401(self, client, mock_db):
"""Using a refresh token as an access token should return 401."""
user = _make_regular_user()
refresh_token = create_refresh_token(user["id"], user["email"], user["role"])
response = client.get(
"/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}
)
assert response.status_code == 401
class TestRefreshToken:
"""POST /auth/refresh"""
def test_valid_refresh_token_returns_new_tokens(self, client, mock_db):
"""A valid refresh token should issue new access and refresh tokens."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
refresh = create_refresh_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": refresh}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
def test_invalid_refresh_token_returns_401(self, client, mock_db):
"""An invalid refresh token should return 401."""
response = client.post(
"/auth/refresh", json={"refresh_token": "invalid-token-string"}
)
assert response.status_code == 401
def test_access_token_as_refresh_returns_401(self, client, mock_db):
"""Using an access token as a refresh token should return 401."""
user = _make_regular_user()
access = create_access_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": access}
)
assert response.status_code == 401
class TestAdminUsers:
"""GET /admin/users and PATCH /admin/users/{id}/role"""
def test_admin_can_list_users(self, client, mock_db):
"""Admin token should allow listing users."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.get_all_users.return_value = [admin, _make_regular_user()]
response = client.get("/admin/users", headers=_auth_header(admin))
assert response.status_code == 200
data = response.json()
assert len(data) == 2
def test_regular_user_cannot_list_users(self, client, mock_db):
"""Regular user token should be rejected with 403."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.get("/admin/users", headers=_auth_header(user))
assert response.status_code == 403
def test_no_token_cannot_list_users(self, client):
"""No token should be rejected."""
response = client.get("/admin/users")
assert response.status_code in (401, 403)
def test_admin_can_change_user_role(self, client, mock_db):
"""Admin should be able to change another user's role."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.update_user_role.return_value = {
"id": 2,
"email": "user@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.patch(
"/admin/users/2/role",
json={"role": "admin"},
headers=_auth_header(admin),
)
assert response.status_code == 200
assert response.json()["role"] == "admin"
def test_admin_cannot_change_own_role(self, client, mock_db):
"""Admin should not be able to change their own role."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
response = client.patch(
"/admin/users/1/role",
json={"role": "user"},
headers=_auth_header(admin),
)
assert response.status_code == 400
assert "own role" in response.json()["detail"].lower()
+3 -1
View File
@@ -1,7 +1,9 @@
"""Tests for LLM analysis functionality.""" """Tests for LLM analysis functionality."""
from unittest.mock import Mock
import pytest import pytest
from unittest.mock import Mock, MagicMock, patch
from SPARC.llm import LLMAnalyzer from SPARC.llm import LLMAnalyzer
+97
View File
@@ -0,0 +1,97 @@
"""Tests for rate limiting on auth endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
"""Create test client with rate limiter enabled."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_limiter():
"""Reset rate limiter storage between tests."""
from SPARC.api import limiter
limiter.reset()
yield
class TestRateLimiting:
"""Test rate limiting on login and register endpoints."""
@patch("SPARC.api.get_db_client")
def test_login_allows_requests_under_limit(self, mock_db_client, client):
"""Login endpoint allows requests under the rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Should allow at least a few requests
for _ in range(5):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
# 401 is expected (invalid credentials), not 429
assert response.status_code == 401
@patch("SPARC.api.get_db_client")
def test_login_rate_limited_after_threshold(self, mock_db_client, client):
"""Login endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Send more than the limit (10/minute)
statuses = []
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_register_rate_limited_after_threshold(self, mock_db_client, client):
"""Register endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # triggers 400 (email exists)
mock_db_client.return_value = mock_db
# Send more than the limit (5/minute)
statuses = []
for _ in range(10):
response = client.post(
"/auth/register",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
"""Rate limited responses include a Retry-After header."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Exhaust the limit
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
if response.status_code == 429:
assert "Retry-After" in response.headers
break
+116
View File
@@ -0,0 +1,116 @@
"""Tests for security hardening: JWT secret startup check, CORS config, credential handling."""
import os
from unittest.mock import patch
import pytest
class TestJWTSecretStartupCheck:
"""Test the startup guard that refuses default JWT secret in non-dev environments."""
def test_default_secret_in_production_raises(self):
"""Starting with default secret and APP_ENV=production must raise RuntimeError."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
# Reload config to pick up the new APP_ENV
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
# Patch JWT_SECRET to the default
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
with pytest.raises(RuntimeError, match="FATAL.*JWT_SECRET"):
check_jwt_secret()
# Restore config
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_in_development_succeeds(self):
"""Starting with default secret and APP_ENV=development must not raise."""
with patch.dict(os.environ, {"APP_ENV": "development"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise
check_jwt_secret()
# Restore
importlib.reload(SPARC.config)
def test_custom_secret_in_production_succeeds(self):
"""Starting with a custom secret in production must not raise."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", "my-secure-random-secret-abc123"):
# Should not raise
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_unset_env_succeeds(self):
"""When APP_ENV is unset (defaults to development), default secret is allowed."""
with patch.dict(os.environ, {}, clear=False):
# Remove APP_ENV if present
env = os.environ.copy()
env.pop("APP_ENV", None)
with patch.dict(os.environ, env, clear=True):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise (defaults to development)
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
class TestCORSConfig:
"""Test that CORS origins are configurable via environment variable."""
def test_default_cors_origins(self):
"""When CORS_ORIGINS is unset, defaults to localhost origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"http://localhost:3000",
"http://localhost:5173",
]
def test_custom_cors_origins(self):
"""Setting CORS_ORIGINS configures allowed origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com,https://app.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"https://sparc.example.com",
"https://app.example.com",
]
# Restore
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)
def test_single_cors_origin(self):
"""A single origin without comma works correctly."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == ["https://sparc.example.com"]
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)
+2 -3
View File
@@ -1,9 +1,8 @@
"""Tests for SERP API patent retrieval and parsing functionality.""" """Tests for SERP API patent retrieval and parsing functionality."""
import os
import pytest
from unittest.mock import patch, Mock
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import Mock
from SPARC.serp_api import SERP from SPARC.serp_api import SERP
from SPARC.types import Patent from SPARC.types import Patent