Compare commits

..

1 Commit

Author SHA1 Message Date
agent-company c317632edb ci: add pytest and ruff linting to CI, fix all lint errors
- Add test job to build.yaml that runs pytest and ruff before building images
- Add standalone test.yaml workflow for PRs
- Add ruff.toml with E/F/I rules configured
- Fix all ruff lint errors: sort imports, remove unused imports, fix re-exports
- Build jobs now depend on test job passing (needs: test)

Closes leeworks-agents/SPARC#18
Closes leeworks-agents/SPARC#19

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 06:04:24 +00:00
9 changed files with 193 additions and 501 deletions
+16 -21
View File
@@ -5,13 +5,10 @@ to provide company performance estimation based on patent portfolios.
""" """
import hashlib import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable from typing import Callable
from SPARC import config from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from SPARC.llm import LLMAnalyzer from SPARC.llm import LLMAnalyzer
from SPARC.serp_api import SERP from SPARC.serp_api import SERP
@@ -55,13 +52,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash) cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None: if cached_ids is not None:
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids)) print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
patents = Patents(patents=[ patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="") Patent(patent_id=pid, pdf_link="")
for pid in cached_ids for pid in cached_ids
]) ])
else: else:
logger.info("Retrieving patents for %s...", company_name) print(f"Retrieving patents for {company_name}...")
patents = SERP.query(company_name) patents = SERP.query(company_name)
# Cache the SERP results # Cache the SERP results
if patents.patents: if patents.patents:
@@ -69,13 +66,12 @@ class CompanyAnalyzer:
company_name=company_name, company_name=company_name,
query_hash=query_hash, query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents], patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
) )
if not patents.patents: if not patents.patents:
return f"No patents found for {company_name}" return f"No patents found for {company_name}"
logger.info("Found %d patents. Processing...", len(patents.patents)) print(f"Found {len(patents.patents)} patents. Processing...")
# Download, parse, and minimize patents in parallel # Download, parse, and minimize patents in parallel
processed_patents = [] processed_patents = []
@@ -91,12 +87,12 @@ class CompanyAnalyzer:
if result: if result:
processed_patents.append(result) processed_patents.append(result)
except Exception as e: except Exception as e:
logger.warning("Failed to process %s: %s", patent.patent_id, e) print(f"Warning: Failed to process {patent.patent_id}: {e}")
if not processed_patents: if not processed_patents:
return f"Failed to process any patents for {company_name}" return f"Failed to process any patents for {company_name}"
logger.info("Analyzing portfolio with LLM...") print("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM # Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio( analysis = self.llm_analyzer.analyze_patent_portfolio(
@@ -126,7 +122,6 @@ class CompanyAnalyzer:
FileNotFoundError: If the patent PDF is not found at the expected path. FileNotFoundError: If the patent PDF is not found at the expected path.
""" """
import os import os
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
patent_path = f"patents/{patent_id}.pdf" patent_path = f"patents/{patent_id}.pdf"
@@ -188,7 +183,7 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content} return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e: except Exception as e:
logger.warning("Failed to process %s: %s", patent.patent_id, e) print(f"Warning: Failed to process {patent.patent_id}: {e}")
return None return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
@@ -259,7 +254,7 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
logger.info("Starting batch analysis of %d companies...", total) print(f"Starting batch analysis of {total} companies...")
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = { future_to_company = {
@@ -276,8 +271,8 @@ class CompanyAnalyzer:
result = future.result() result = future.result()
results.append(result) results.append(result)
status = "OK" if result.success else "FAIL" status = "" if result.success else ""
logger.info("[%d/%d] %s %s", completed, total, status, company) print(f"[{completed}/{total}] {status} {company}")
if progress_callback: if progress_callback:
progress_callback(company, completed, total) progress_callback(company, completed, total)
@@ -292,12 +287,12 @@ class CompanyAnalyzer:
error=str(e), error=str(e),
) )
) )
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e) print(f"[{completed}/{total}] ✗ {company}: {e}")
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
logger.info("Batch complete: %d succeeded, %d failed", successful, failed) print(f"\nBatch complete: {successful} succeeded, {failed} failed")
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
@@ -323,20 +318,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
logger.info("Starting sequential analysis of %d companies...", total) print(f"Starting sequential analysis of {total} companies...")
for idx, company in enumerate(companies, 1): for idx, company in enumerate(companies, 1):
logger.info("[%d/%d] Analyzing %s...", idx, total, company) print(f"\n[{idx}/{total}] Analyzing {company}...")
result = self._analyze_company_safe(company) result = self._analyze_company_safe(company)
results.append(result) results.append(result)
status = "OK" if result.success else "FAIL" status = "" if result.success else ""
logger.info("[%d/%d] %s %s", idx, total, status, company) print(f"[{idx}/{total}] {status} {company}")
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
logger.info("Batch complete: %d succeeded, %d failed", successful, failed) print(f"\nBatch complete: {successful} succeeded, {failed} failed")
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
+1 -62
View File
@@ -21,13 +21,11 @@ from SPARC.auth import (
TokenResponse, TokenResponse,
UserResponse, UserResponse,
check_jwt_secret, check_jwt_secret,
close_db_client,
create_tokens, create_tokens,
decode_token, decode_token,
get_current_admin, get_current_admin,
get_current_user, get_current_user,
get_db_client, get_db_client,
init_db_client,
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -157,7 +155,6 @@ async def lifespan(app: FastAPI):
"""Initialize resources on startup, clean up on shutdown.""" """Initialize resources on startup, clean up on shutdown."""
global _analyzer global _analyzer
check_jwt_secret() check_jwt_secret()
init_db_client()
_analyzer = CompanyAnalyzer() _analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed # Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
@@ -169,13 +166,9 @@ async def lifespan(app: FastAPI):
import logging import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale) logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close() _db.close()
# Start scheduled analysis if tracked companies are configured
from SPARC.scheduler import start_scheduler
start_scheduler()
yield yield
# Cleanup # Cleanup if needed
_analyzer = None _analyzer = None
close_db_client()
app = FastAPI( app = FastAPI(
@@ -372,60 +365,6 @@ async def delete_user(
return {"message": "User deleted"} return {"message": "User deleted"}
# ============== Tracked Companies Endpoints ==============
class TrackCompanyRequest(BaseModel):
"""Request to add a company to tracking."""
company_name: str = Field(..., min_length=1, max_length=255)
@app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies(
_: UserResponse = Depends(get_current_admin),
):
"""List all tracked companies (admin only)."""
db = get_db_client()
return db.list_tracked_companies()
@app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company(
request: TrackCompanyRequest,
_: UserResponse = Depends(get_current_admin),
):
"""Add a company to the tracked list (admin only)."""
db = get_db_client()
result = db.add_tracked_company(request.company_name)
if not result:
raise HTTPException(status_code=409, detail="Company already tracked")
return result
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company(
company_name: str,
_: UserResponse = Depends(get_current_admin),
):
"""Remove a company from the tracked list (admin only)."""
db = get_db_client()
removed = db.remove_tracked_company(company_name)
if not removed:
raise HTTPException(status_code=404, detail="Company not found in tracking list")
return {"message": f"Stopped tracking {company_name}"}
@app.get("/admin/alerts", tags=["Admin"])
async def list_alerts(
limit: int = Query(default=50, ge=1, le=200),
_: UserResponse = Depends(get_current_admin),
):
"""List recent alerts from scheduled analysis (admin only)."""
db = get_db_client()
return db.list_alerts(limit=limit)
# ============== Analytics Endpoint ============== # ============== Analytics Endpoint ==============
+4 -29
View File
@@ -146,36 +146,11 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient: def get_db_client() -> DatabaseClient:
"""Get the shared pooled database client for auth operations. """Get database client for auth operations."""
client = DatabaseClient(config.database_url)
Returns the module-level singleton DatabaseClient. If not yet initialized client.connect()
(e.g., during tests), creates a new instance as a fallback. return client
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
async def get_current_user( async def get_current_user(
-14
View File
@@ -2,20 +2,12 @@
Loads environment variables from .env file for API keys and other secrets. Loads environment variables from .env file for API keys and other secrets.
""" """
import logging
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search # SerpAPI key for patent search
api_key = os.getenv("API_KEY") api_key = os.getenv("API_KEY")
@@ -39,12 +31,6 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90")) patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5")) patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/) # Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy # This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "") root_path = os.getenv("ROOT_PATH", "")
+164 -258
View File
@@ -192,35 +192,6 @@ class DatabaseClient:
ON jobs(status) ON jobs(status)
""") """)
# Create tracked companies table for scheduled analysis
cursor.execute("""
CREATE TABLE IF NOT EXISTS tracked_companies (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255) UNIQUE NOT NULL,
last_patent_count INTEGER DEFAULT 0,
last_analysis_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create alerts table for significant changes
cursor.execute("""
CREATE TABLE IF NOT EXISTS alerts (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255) NOT NULL,
alert_type VARCHAR(50) NOT NULL,
message TEXT NOT NULL,
old_value NUMERIC,
new_value NUMERIC,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_alerts_company
ON alerts(company_name)
""")
self.conn.commit() self.conn.commit()
@staticmethod @staticmethod
@@ -251,6 +222,8 @@ class DatabaseClient:
Returns: Returns:
Cached message dict if found, None otherwise Cached message dict if found, None otherwise
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
query = """ query = """
@@ -273,11 +246,10 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1" query += " ORDER BY timestamp DESC LIMIT 1"
with self.get_conn() as conn: with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(query, params)
cursor.execute(query, params) result = cursor.fetchone()
result = cursor.fetchone() return dict(result) if result else None
return dict(result) if result else None
def store_message( def store_message(
self, self,
@@ -305,32 +277,33 @@ class DatabaseClient:
Returns: Returns:
The ID of the inserted record The ID of the inserted record
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
with self.get_conn() as conn: with self.conn.cursor() as cursor:
with conn.cursor() as cursor: cursor.execute(
cursor.execute( """
""" INSERT INTO llm_messages
INSERT INTO llm_messages (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id
RETURNING id """,
""", (
( prompt,
prompt, prompt_hash,
prompt_hash, response,
response, company_name,
company_name, analysis_type,
analysis_type, model,
model, json.dumps(metadata) if metadata else None,
json.dumps(metadata) if metadata else None, json.dumps(token_usage) if token_usage else None,
json.dumps(token_usage) if token_usage else None, is_cached,
is_cached, ),
), )
)
message_id = cursor.fetchone()[0] message_id = cursor.fetchone()[0]
conn.commit() self.conn.commit()
return message_id return message_id
@@ -352,6 +325,8 @@ class DatabaseClient:
Returns: Returns:
List of message dictionaries List of message dictionaries
""" """
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1" query = "SELECT * FROM llm_messages WHERE 1=1"
params = [] params = []
@@ -366,10 +341,9 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s" query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset]) params.extend([limit, offset])
with self.get_conn() as conn: with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(query, params)
cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()]
return [dict(row) for row in cursor.fetchall()]
def get_analytics(self, days: int = 30) -> Dict: def get_analytics(self, days: int = 30) -> Dict:
"""Get analytics on message usage. """Get analytics on message usage.
@@ -380,52 +354,53 @@ class DatabaseClient:
Returns: Returns:
Dictionary with analytics data Dictionary with analytics data
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
SELECT COUNT(*) as total_messages
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
""",
(days,),
)
total = cursor.fetchone()["total_messages"]
# Messages by company with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( # Total messages
""" cursor.execute(
SELECT company_name, COUNT(*) as count """
FROM llm_messages SELECT COUNT(*) as total_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' FROM llm_messages
GROUP BY company_name WHERE timestamp >= NOW() - INTERVAL '%s days'
ORDER BY count DESC """,
LIMIT 10 (days,),
""", )
(days,), total = cursor.fetchone()["total_messages"]
)
by_company = cursor.fetchall()
# Messages by type # Messages by company
cursor.execute( cursor.execute(
""" """
SELECT analysis_type, COUNT(*) as count SELECT company_name, COUNT(*) as count
FROM llm_messages FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days' WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY analysis_type GROUP BY company_name
ORDER BY count DESC ORDER BY count DESC
""", LIMIT 10
(days,), """,
) (days,),
by_type = cursor.fetchall() )
by_company = cursor.fetchall()
return { # Messages by type
"total_messages": total, cursor.execute(
"by_company": [dict(row) for row in by_company], """
"by_type": [dict(row) for row in by_type], SELECT analysis_type, COUNT(*) as count
"period_days": days, FROM llm_messages
} WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY analysis_type
ORDER BY count DESC
""",
(days,),
)
by_type = cursor.fetchall()
return {
"total_messages": total,
"by_company": [dict(row) for row in by_company],
"by_type": [dict(row) for row in by_type],
"period_days": days,
}
# Patent Cache Methods # Patent Cache Methods
@@ -676,23 +651,25 @@ class DatabaseClient:
Returns: Returns:
Created user dict or None if email exists Created user dict or None if email exists
""" """
self.connect()
password_hash = self.hash_password(password) password_hash = self.hash_password(password)
try: try:
with self.get_conn() as conn: with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(
cursor.execute( """
""" INSERT INTO users (email, password_hash, role)
INSERT INTO users (email, password_hash, role) VALUES (%s, %s, %s)
VALUES (%s, %s, %s) RETURNING id, email, role, created_at
RETURNING id, email, role, created_at """,
""", (email, password_hash, role),
(email, password_hash, role), )
) user = cursor.fetchone()
user = cursor.fetchone() self.conn.commit()
conn.commit()
return dict(user) if user else None return dict(user) if user else None
except psycopg2.errors.UniqueViolation: except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]: def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -705,22 +682,23 @@ class DatabaseClient:
Returns: Returns:
User dict if authenticated, None otherwise User dict if authenticated, None otherwise
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM users WHERE email = %s",
(email,),
)
user = cursor.fetchone()
if user and self.verify_password(password, user["password_hash"]): with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
return { cursor.execute(
"id": user["id"], "SELECT * FROM users WHERE email = %s",
"email": user["email"], (email,),
"role": user["role"], )
"created_at": user["created_at"], user = cursor.fetchone()
}
return None if user and self.verify_password(password, user["password_hash"]):
return {
"id": user["id"],
"email": user["email"],
"role": user["role"],
"created_at": user["created_at"],
}
return None
def get_user_by_id(self, user_id: int) -> Optional[Dict]: def get_user_by_id(self, user_id: int) -> Optional[Dict]:
"""Get a user by ID. """Get a user by ID.
@@ -731,14 +709,15 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
"SELECT id, email, role, created_at FROM users WHERE id = %s", cursor.execute(
(user_id,), "SELECT id, email, role, created_at FROM users WHERE id = %s",
) (user_id,),
user = cursor.fetchone() )
return dict(user) if user else None user = cursor.fetchone()
return dict(user) if user else None
def get_user_by_email(self, email: str) -> Optional[Dict]: def get_user_by_email(self, email: str) -> Optional[Dict]:
"""Get a user by email. """Get a user by email.
@@ -749,14 +728,15 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
"SELECT id, email, role, created_at FROM users WHERE email = %s", cursor.execute(
(email,), "SELECT id, email, role, created_at FROM users WHERE email = %s",
) (email,),
user = cursor.fetchone() )
return dict(user) if user else None user = cursor.fetchone()
return dict(user) if user else None
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]: def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
"""Get all users (admin only). """Get all users (admin only).
@@ -768,18 +748,19 @@ class DatabaseClient:
Returns: Returns:
List of user dicts List of user dicts
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
""" cursor.execute(
SELECT id, email, role, created_at """
FROM users SELECT id, email, role, created_at
ORDER BY created_at DESC FROM users
LIMIT %s OFFSET %s ORDER BY created_at DESC
""", LIMIT %s OFFSET %s
(limit, offset), """,
) (limit, offset),
return [dict(row) for row in cursor.fetchall()] )
return [dict(row) for row in cursor.fetchall()]
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]: def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
"""Update a user's role (admin only). """Update a user's role (admin only).
@@ -791,19 +772,20 @@ class DatabaseClient:
Returns: Returns:
Updated user dict or None Updated user dict or None
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
""" cursor.execute(
UPDATE users """
SET role = %s, updated_at = CURRENT_TIMESTAMP UPDATE users
WHERE id = %s SET role = %s, updated_at = CURRENT_TIMESTAMP
RETURNING id, email, role, created_at WHERE id = %s
""", RETURNING id, email, role, created_at
(role, user_id), """,
) (role, user_id),
user = cursor.fetchone() )
conn.commit() user = cursor.fetchone()
self.conn.commit()
return dict(user) if user else None return dict(user) if user else None
def delete_user(self, user_id: int) -> bool: def delete_user(self, user_id: int) -> bool:
@@ -815,11 +797,12 @@ class DatabaseClient:
Returns: Returns:
True if deleted True if deleted
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,)) with self.conn.cursor() as cursor:
deleted = cursor.rowcount > 0 cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
conn.commit() deleted = cursor.rowcount > 0
self.conn.commit()
return deleted return deleted
def get_user_count(self) -> int: def get_user_count(self) -> int:
@@ -828,85 +811,8 @@ class DatabaseClient:
Returns: Returns:
Number of users Number of users
""" """
with self.get_conn() as conn: self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
# Tracked Companies Methods with self.conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users")
def add_tracked_company(self, company_name: str) -> Optional[Dict]: return cursor.fetchone()[0]
"""Add a company to the tracking list."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
try:
cursor.execute(
"INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
(company_name,),
)
row = cursor.fetchone()
conn.commit()
return dict(row) if row else None
except Exception:
conn.rollback()
return None
def remove_tracked_company(self, company_name: str) -> bool:
"""Remove a company from the tracking list."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
(company_name,),
)
conn.commit()
return cursor.rowcount > 0
def list_tracked_companies(self) -> List[Dict]:
"""List all tracked companies."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name")
return [dict(row) for row in cursor.fetchall()]
def update_tracked_company(
self, company_name: str, patent_count: int
) -> None:
"""Update the last analysis stats for a tracked company."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""UPDATE tracked_companies
SET last_patent_count = %s, last_analysis_at = CURRENT_TIMESTAMP
WHERE LOWER(company_name) = LOWER(%s)""",
(patent_count, company_name),
)
conn.commit()
def store_alert(
self,
company_name: str,
alert_type: str,
message: str,
old_value: float | None = None,
new_value: float | None = None,
) -> None:
"""Record an alert for a significant change."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""INSERT INTO alerts (company_name, alert_type, message, old_value, new_value)
VALUES (%s, %s, %s, %s, %s)""",
(company_name, alert_type, message, old_value, new_value),
)
conn.commit()
def list_alerts(self, limit: int = 50) -> List[Dict]:
"""List recent alerts."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM alerts ORDER BY created_at DESC LIMIT %s",
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
+7 -6
View File
@@ -1,6 +1,5 @@
"""LLM integration for patent analysis using OpenRouter.""" """LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict from typing import Dict
from openai import OpenAI from openai import OpenAI
@@ -8,8 +7,6 @@ from openai import OpenAI
from SPARC import config from SPARC import config
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
logger = logging.getLogger(__name__)
class LLMAnalyzer: class LLMAnalyzer:
"""Handles LLM-based analysis of patent content.""" """Handles LLM-based analysis of patent content."""
@@ -25,7 +22,7 @@ class LLMAnalyzer:
""" """
self.test_mode = test_mode self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = config.model self.model = "anthropic/claude-3.5-sonnet"
# Always initialize database client for storage and caching # Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url) self.db_client = DatabaseClient(config.database_url)
@@ -64,7 +61,11 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage.""" Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
if self.test_mode: if self.test_mode:
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt) print("=" * 80)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
return "[TEST MODE - No API call made]" return "[TEST MODE - No API call made]"
# Check cache first # Check cache first
@@ -166,7 +167,7 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook.""" Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
if self.test_mode: if self.test_mode:
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt) print(prompt)
return "[TEST MODE]" return "[TEST MODE]"
metadata = { metadata = {
-109
View File
@@ -1,109 +0,0 @@
"""Scheduled patent analysis for tracked companies.
Uses APScheduler to periodically re-analyze tracked companies and
detect significant changes in patent counts.
"""
import logging
import os
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.database import DatabaseClient
logger = logging.getLogger(__name__)
# Configurable via environment variable (in hours, default 24)
SCHEDULE_INTERVAL_HOURS = int(os.getenv("SCHEDULE_INTERVAL_HOURS", "24"))
# Patent count change threshold (percentage) to trigger an alert
CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20"))
def run_scheduled_analysis() -> None:
"""Re-analyze all tracked companies and check for significant changes."""
db = DatabaseClient(config.database_url)
db.connect()
db.initialize_schema()
tracked = db.list_tracked_companies()
if not tracked:
logger.info("No tracked companies configured; skipping scheduled analysis")
return
logger.info("Running scheduled analysis for %d tracked companies", len(tracked))
analyzer = CompanyAnalyzer(db_client=db)
for company_row in tracked:
name = company_row["company_name"]
old_count = company_row.get("last_patent_count", 0) or 0
try:
result = analyzer._analyze_company_safe(name)
if result.success:
new_count = result.patent_count
# Update tracking record
db.update_tracked_company(name, new_count)
# Check for significant change
if old_count > 0:
delta_pct = abs(new_count - old_count) / old_count * 100
if delta_pct >= CHANGE_THRESHOLD_PERCENT:
direction = "increased" if new_count > old_count else "decreased"
message = (
f"Patent count for {name} {direction} by {delta_pct:.0f}% "
f"({old_count} -> {new_count})"
)
logger.warning("ALERT: %s", message)
db.store_alert(
company_name=name,
alert_type="patent_count_change",
message=message,
old_value=old_count,
new_value=new_count,
)
elif new_count > 0:
# First analysis -- record baseline
logger.info("Baseline for %s: %d patents", name, new_count)
else:
logger.warning("Scheduled analysis failed for %s: %s", name, result.error)
except Exception as e:
logger.error("Error analyzing tracked company %s: %s", name, e)
db.close()
logger.info("Scheduled analysis complete")
def start_scheduler() -> None:
"""Start the APScheduler background scheduler.
Safe to call at application startup. If apscheduler is not installed,
the function logs a warning and returns without starting anything.
"""
try:
from apscheduler.schedulers.background import BackgroundScheduler
except ImportError:
logger.warning(
"apscheduler not installed; scheduled analysis disabled. "
"Install with: pip install apscheduler"
)
return
scheduler = BackgroundScheduler()
scheduler.add_job(
run_scheduled_analysis,
"interval",
hours=SCHEDULE_INTERVAL_HOURS,
id="scheduled_patent_analysis",
replace_existing=True,
)
scheduler.start()
logger.info(
"Scheduled patent analysis started (every %d hours, threshold %d%%)",
SCHEDULE_INTERVAL_HOURS,
CHANGE_THRESHOLD_PERCENT,
)
+1 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass @dataclass
class Patent: class Patent:
patent_id: str patent_id: int
pdf_link: str pdf_link: str
pdf_path: str | None = None pdf_path: str | None = None
summary: dict | None = None summary: dict | None = None
-1
View File
@@ -15,4 +15,3 @@ pandas
bcrypt bcrypt
PyJWT PyJWT
slowapi slowapi
apscheduler