Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company c317632edb ci: add pytest and ruff linting to CI, fix all lint errors
- Add test job to build.yaml that runs pytest and ruff before building images
- Add standalone test.yaml workflow for PRs
- Add ruff.toml with E/F/I rules configured
- Fix all ruff lint errors: sort imports, remove unused imports, fix re-exports
- Build jobs now depend on test job passing (needs: test)

Closes leeworks-agents/SPARC#18
Closes leeworks-agents/SPARC#19

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 06:04:24 +00:00
7 changed files with 205 additions and 281 deletions
+25 -40
View File
@@ -5,13 +5,10 @@ to provide company performance estimation based on patent portfolios.
"""
import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable
from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient
from SPARC.llm import LLMAnalyzer
from SPARC.serp_api import SERP
@@ -55,13 +52,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None:
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="")
for pid in cached_ids
])
else:
logger.info("Retrieving patents for %s...", company_name)
print(f"Retrieving patents for {company_name}...")
patents = SERP.query(company_name)
# Cache the SERP results
if patents.patents:
@@ -69,13 +66,12 @@ class CompanyAnalyzer:
company_name=company_name,
query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
)
if not patents.patents:
return f"No patents found for {company_name}"
logger.info("Found %d patents. Processing...", len(patents.patents))
print(f"Found {len(patents.patents)} patents. Processing...")
# Download, parse, and minimize patents in parallel
processed_patents = []
@@ -91,12 +87,12 @@ class CompanyAnalyzer:
if result:
processed_patents.append(result)
except Exception as e:
logger.warning("Failed to process %s: %s", patent.patent_id, e)
print(f"Warning: Failed to process {patent.patent_id}: {e}")
if not processed_patents:
return f"Failed to process any patents for {company_name}"
logger.info("Analyzing portfolio with LLM...")
print("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio(
@@ -108,10 +104,12 @@ class CompanyAnalyzer:
def analyze_single_patent(self, patent_id: str, company_name: str) -> str:
"""Analyze a single patent by ID.
If the patent PDF is not already on disk, this method attempts to
download it automatically by looking up the PDF link in the database
cache. If the link is not cached either, a ``FileNotFoundError`` is
raised with instructions on how to obtain the PDF.
Prerequisite:
The patent PDF must already exist at ``patents/{patent_id}.pdf``
before calling this method. PDFs are downloaded automatically when
using the batch analysis pipeline (``analyze_company`` or the
``/analyze/batch`` API endpoint). For standalone usage, download
the PDF manually or call ``SERP.save_patents()`` first.
Args:
patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
@@ -121,29 +119,16 @@ class CompanyAnalyzer:
Analysis of the specific patent's innovation quality
Raises:
FileNotFoundError: If the patent PDF cannot be found or downloaded.
FileNotFoundError: If the patent PDF is not found at the expected path.
"""
import os
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
patent_path = f"patents/{patent_id}.pdf"
if not os.path.exists(patent_path):
# Attempt to download the PDF automatically from cached metadata
cached = self.db.get_cached_patent(patent_id)
pdf_link = cached.get("pdf_link") if cached else None
if pdf_link:
logger.info("PDF not on disk; downloading %s from cached link", patent_id)
patent = SERP.save_patents(
Patent(patent_id=patent_id, pdf_link=pdf_link)
)
patent_path = patent.pdf_path
else:
raise FileNotFoundError(
f"Patent PDF not found at '{patent_path}' and no download link is "
f"cached for '{patent_id}'. Run a company analysis first to populate "
f"the cache, or call SERP.save_patents() with the patent's PDF link."
f"Patent PDF not found at '{patent_path}'. "
f"Download the PDF first using SERP.save_patents() or the batch analysis pipeline."
)
try:
@@ -198,7 +183,7 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e:
logger.warning("Failed to process %s: %s", patent.patent_id, e)
print(f"Warning: Failed to process {patent.patent_id}: {e}")
return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
@@ -269,7 +254,7 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = []
total = len(companies)
logger.info("Starting batch analysis of %d companies...", total)
print(f"Starting batch analysis of {total} companies...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = {
@@ -286,8 +271,8 @@ class CompanyAnalyzer:
result = future.result()
results.append(result)
status = "OK" if result.success else "FAIL"
logger.info("[%d/%d] %s %s", completed, total, status, company)
status = "" if result.success else ""
print(f"[{completed}/{total}] {status} {company}")
if progress_callback:
progress_callback(company, completed, total)
@@ -302,12 +287,12 @@ class CompanyAnalyzer:
error=str(e),
)
)
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
print(f"[{completed}/{total}] ✗ {company}: {e}")
successful = sum(1 for r in results if r.success)
failed = total - successful
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
return BatchAnalysisResult(
results=results,
@@ -333,20 +318,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = []
total = len(companies)
logger.info("Starting sequential analysis of %d companies...", total)
print(f"Starting sequential analysis of {total} companies...")
for idx, company in enumerate(companies, 1):
logger.info("[%d/%d] Analyzing %s...", idx, total, company)
print(f"\n[{idx}/{total}] Analyzing {company}...")
result = self._analyze_company_safe(company)
results.append(result)
status = "OK" if result.success else "FAIL"
logger.info("[%d/%d] %s %s", idx, total, status, company)
status = "" if result.success else ""
print(f"[{idx}/{total}] {status} {company}")
successful = sum(1 for r in results if r.success)
failed = total - successful
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
return BatchAnalysisResult(
results=results,
+1 -37
View File
@@ -21,13 +21,11 @@ from SPARC.auth import (
TokenResponse,
UserResponse,
check_jwt_secret,
close_db_client,
create_tokens,
decode_token,
get_current_admin,
get_current_user,
get_db_client,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -157,7 +155,6 @@ async def lifespan(app: FastAPI):
"""Initialize resources on startup, clean up on shutdown."""
global _analyzer
check_jwt_secret()
init_db_client()
_analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient
@@ -170,9 +167,8 @@ async def lifespan(app: FastAPI):
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
yield
# Cleanup
# Cleanup if needed
_analyzer = None
close_db_client()
app = FastAPI(
@@ -429,38 +425,6 @@ async def analyze_company(
return _convert_result(result)
@app.get(
"/analyze/patent/{patent_id}",
tags=["Analysis"],
)
async def analyze_single_patent(
patent_id: str,
company_name: str = Query(description="Company name for analysis context"),
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
If the patent PDF is not already cached locally, the system will attempt
to download it automatically from a previously cached link. If no link
is available, a 404 error is returned.
Args:
patent_id: Patent publication ID (e.g. "US-11234567-B2")
company_name: Company name for analysis context
Returns:
Analysis text for the patent
"""
if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized")
try:
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@app.post(
"/analyze/batch",
response_model=BatchAnalysisResponse,
+4 -29
View File
@@ -146,36 +146,11 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient:
"""Get the shared pooled database client for auth operations.
Returns the module-level singleton DatabaseClient. If not yet initialized
(e.g., during tests), creates a new instance as a fallback.
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
"""Get database client for auth operations."""
client = DatabaseClient(config.database_url)
client.connect()
return client
async def get_current_user(
-14
View File
@@ -2,20 +2,12 @@
Loads environment variables from .env file for API keys and other secrets.
"""
import logging
import os
from dotenv import load_dotenv
load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search
api_key = os.getenv("API_KEY")
@@ -39,12 +31,6 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "")
+41 -28
View File
@@ -222,6 +222,8 @@ class DatabaseClient:
Returns:
Cached message dict if found, None otherwise
"""
self.connect()
prompt_hash = self.hash_prompt(prompt)
query = """
@@ -244,8 +246,7 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1"
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
result = cursor.fetchone()
return dict(result) if result else None
@@ -276,10 +277,11 @@ class DatabaseClient:
Returns:
The ID of the inserted record
"""
self.connect()
prompt_hash = self.hash_prompt(prompt)
with self.get_conn() as conn:
with conn.cursor() as cursor:
with self.conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO llm_messages
@@ -301,7 +303,7 @@ class DatabaseClient:
)
message_id = cursor.fetchone()[0]
conn.commit()
self.conn.commit()
return message_id
@@ -323,6 +325,8 @@ class DatabaseClient:
Returns:
List of message dictionaries
"""
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1"
params = []
@@ -337,8 +341,7 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset])
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
@@ -351,8 +354,9 @@ class DatabaseClient:
Returns:
Dictionary with analytics data
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
@@ -647,11 +651,12 @@ class DatabaseClient:
Returns:
Created user dict or None if email exists
"""
self.connect()
password_hash = self.hash_password(password)
try:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO users (email, password_hash, role)
@@ -661,9 +666,10 @@ class DatabaseClient:
(email, password_hash, role),
)
user = cursor.fetchone()
conn.commit()
self.conn.commit()
return dict(user) if user else None
except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -676,8 +682,9 @@ class DatabaseClient:
Returns:
User dict if authenticated, None otherwise
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM users WHERE email = %s",
(email,),
@@ -702,8 +709,9 @@ class DatabaseClient:
Returns:
User dict or None
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE id = %s",
(user_id,),
@@ -720,8 +728,9 @@ class DatabaseClient:
Returns:
User dict or None
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE email = %s",
(email,),
@@ -739,8 +748,9 @@ class DatabaseClient:
Returns:
List of user dicts
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, email, role, created_at
@@ -762,8 +772,9 @@ class DatabaseClient:
Returns:
Updated user dict or None
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
UPDATE users
@@ -774,7 +785,7 @@ class DatabaseClient:
(role, user_id),
)
user = cursor.fetchone()
conn.commit()
self.conn.commit()
return dict(user) if user else None
def delete_user(self, user_id: int) -> bool:
@@ -786,11 +797,12 @@ class DatabaseClient:
Returns:
True if deleted
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
self.connect()
with self.conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
deleted = cursor.rowcount > 0
conn.commit()
self.conn.commit()
return deleted
def get_user_count(self) -> int:
@@ -799,7 +811,8 @@ class DatabaseClient:
Returns:
Number of users
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
self.connect()
with self.conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
+7 -6
View File
@@ -1,6 +1,5 @@
"""LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict
from openai import OpenAI
@@ -8,8 +7,6 @@ from openai import OpenAI
from SPARC import config
from SPARC.database import DatabaseClient
logger = logging.getLogger(__name__)
class LLMAnalyzer:
"""Handles LLM-based analysis of patent content."""
@@ -25,7 +22,7 @@ class LLMAnalyzer:
"""
self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = config.model
self.model = "anthropic/claude-3.5-sonnet"
# Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url)
@@ -64,7 +61,11 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
if self.test_mode:
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
print("=" * 80)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
return "[TEST MODE - No API call made]"
# Check cache first
@@ -166,7 +167,7 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
if self.test_mode:
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
print(prompt)
return "[TEST MODE]"
metadata = {
+1 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass
class Patent:
patent_id: str
patent_id: int
pdf_link: str
pdf_path: str | None = None
summary: dict | None = None