diff --git a/SPARC/api.py b/SPARC/api.py
index a78c132..63e7838 100644
--- a/SPARC/api.py
+++ b/SPARC/api.py
@@ -169,6 +169,12 @@ async def lifespan(app: FastAPI):
         import logging
         logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
     _db.close()
+    # Start the periodic analysis job (it is a no-op while no companies are tracked)
+    from SPARC.scheduler import start_scheduler
+    scheduler = start_scheduler()
     yield
     # Cleanup
+    # Stop the scheduler thread so it does not outlive the application
+    if scheduler is not None:
+        scheduler.shutdown(wait=False)
     _analyzer = None
@@ -369,6 +375,60 @@ async def delete_user(
     return {"message": "User deleted"}
 
 
+# ============== Tracked Companies Endpoints ==============
+
+
+class TrackCompanyRequest(BaseModel):
+    """Request to add a company to tracking."""
+
+    company_name: str = Field(..., min_length=1, max_length=255)
+
+
+@app.get("/admin/tracked", tags=["Admin"])
+async def list_tracked_companies(
+    _: UserResponse = Depends(get_current_admin),
+):
+    """List all tracked companies (admin only)."""
+    db = get_db_client()
+    return db.list_tracked_companies()
+
+
+@app.post("/admin/tracked", tags=["Admin"])
+async def add_tracked_company(
+    request: TrackCompanyRequest,
+    _: UserResponse = Depends(get_current_admin),
+):
+    """Add a company to the tracked list (admin only)."""
+    db = get_db_client()
+    result = db.add_tracked_company(request.company_name)
+    if not result:
+        raise HTTPException(status_code=409, detail="Company already tracked")
+    return result
+
+
+@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
+async def remove_tracked_company(
+    company_name: str,
+    _: UserResponse = Depends(get_current_admin),
+):
+    """Remove a company from the tracked list (admin only)."""
+    db = get_db_client()
+    removed = db.remove_tracked_company(company_name)
+    if not removed:
+        raise HTTPException(status_code=404, detail="Company not found in tracking list")
+    return {"message": f"Stopped tracking {company_name}"}
+
+
+@app.get("/admin/alerts", tags=["Admin"])
+async def list_alerts(
+    limit: int = Query(default=50, ge=1, le=200),
+    _: UserResponse = Depends(get_current_admin),
+):
+    """List recent alerts from scheduled analysis (admin only)."""
+    db = get_db_client()
+    return db.list_alerts(limit=limit)
+
+
 # ============== Analytics Endpoint ==============
 
 
diff --git a/SPARC/database.py b/SPARC/database.py
index 4492311..978fba8 100644
--- a/SPARC/database.py
+++ b/SPARC/database.py
@@ -192,6 +192,42 @@ class DatabaseClient:
                 ON jobs(status)
             """)
 
+            # Create tracked companies table for scheduled analysis
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS tracked_companies (
+                    id SERIAL PRIMARY KEY,
+                    company_name VARCHAR(255) UNIQUE NOT NULL,
+                    last_patent_count INTEGER DEFAULT 0,
+                    last_analysis_at TIMESTAMP,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
+
+            # Updates and deletes match names case-insensitively, so
+            # uniqueness must be case-insensitive too ("Acme" vs "acme").
+            cursor.execute("""
+                CREATE UNIQUE INDEX IF NOT EXISTS idx_tracked_companies_name_lower
+                    ON tracked_companies (LOWER(company_name))
+            """)
+
+            # Create alerts table for significant changes
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS alerts (
+                    id SERIAL PRIMARY KEY,
+                    company_name VARCHAR(255) NOT NULL,
+                    alert_type VARCHAR(50) NOT NULL,
+                    message TEXT NOT NULL,
+                    old_value NUMERIC,
+                    new_value NUMERIC,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
+
+            cursor.execute("""
+                CREATE INDEX IF NOT EXISTS idx_alerts_company
+                    ON alerts(company_name)
+            """)
+
             self.conn.commit()
 
     @staticmethod
@@ -803,3 +839,94 @@ class DatabaseClient:
             with conn.cursor() as cursor:
                 cursor.execute("SELECT COUNT(*) FROM users")
                 return cursor.fetchone()[0]
+
+    # Tracked Companies Methods
+
+    def add_tracked_company(self, company_name: str) -> Optional[Dict]:
+        """Add a company to the tracking list.
+
+        Returns the inserted row, or ``None`` when the company is already
+        tracked (unique-constraint violation). Any other database error is
+        rolled back and re-raised so that callers do not report an outage
+        as a duplicate.
+        """
+        # Local import: psycopg2 is already required by this module's cursors.
+        from psycopg2 import IntegrityError
+
+        with self.get_conn() as conn:
+            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
+                try:
+                    cursor.execute(
+                        "INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
+                        (company_name,),
+                    )
+                    row = cursor.fetchone()
+                    conn.commit()
+                    return dict(row) if row else None
+                except IntegrityError:
+                    # Duplicate name: undo the failed insert and signal "already tracked".
+                    conn.rollback()
+                    return None
+                except Exception:
+                    conn.rollback()
+                    raise
+
+    def remove_tracked_company(self, company_name: str) -> bool:
+        """Remove a company from the tracking list."""
+        with self.get_conn() as conn:
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
+                    (company_name,),
+                )
+                conn.commit()
+                return cursor.rowcount > 0
+
+    def list_tracked_companies(self) -> List[Dict]:
+        """List all tracked companies."""
+        with self.get_conn() as conn:
+            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
+                cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name")
+                return [dict(row) for row in cursor.fetchall()]
+
+    def update_tracked_company(
+        self, company_name: str, patent_count: int
+    ) -> None:
+        """Update the last analysis stats for a tracked company."""
+        with self.get_conn() as conn:
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    """UPDATE tracked_companies
+                       SET last_patent_count = %s, last_analysis_at = CURRENT_TIMESTAMP
+                       WHERE LOWER(company_name) = LOWER(%s)""",
+                    (patent_count, company_name),
+                )
+                conn.commit()
+
+    def store_alert(
+        self,
+        company_name: str,
+        alert_type: str,
+        message: str,
+        old_value: float | None = None,
+        new_value: float | None = None,
+    ) -> None:
+        """Record an alert for a significant change."""
+        with self.get_conn() as conn:
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    """INSERT INTO alerts (company_name, alert_type, message, old_value, new_value)
+                       VALUES (%s, %s, %s, %s, %s)""",
+                    (company_name, alert_type, message, old_value, new_value),
+                )
+                conn.commit()
+
+    def list_alerts(self, limit: int = 50) -> List[Dict]:
+        """List recent alerts."""
+        with self.get_conn() as conn:
+            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
+                cursor.execute(
+                    "SELECT * FROM alerts ORDER BY created_at DESC LIMIT %s",
+                    (limit,),
+                )
+                return [dict(row) for row in cursor.fetchall()]
diff --git a/SPARC/scheduler.py b/SPARC/scheduler.py
new file mode 100644
index 0000000..5af3940
--- /dev/null
+++ b/SPARC/scheduler.py
@@ -0,0 +1,147 @@
+"""Scheduled patent analysis for tracked companies.
+
+Uses APScheduler to periodically re-analyze tracked companies and
+detect significant changes in patent counts.
+"""
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from SPARC import config
+from SPARC.analyzer import CompanyAnalyzer
+from SPARC.database import DatabaseClient
+
+if TYPE_CHECKING:
+    from apscheduler.schedulers.background import BackgroundScheduler
+
+logger = logging.getLogger(__name__)
+
+# Configurable via environment variable (in hours, default 24).
+# Clamped to at least one hour so a zero or negative value cannot
+# turn the job into a tight loop.
+SCHEDULE_INTERVAL_HOURS = max(1, int(os.getenv("SCHEDULE_INTERVAL_HOURS", "24")))
+
+# Patent count change threshold (percentage) to trigger an alert
+CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20"))
+
+
+def _analyze_tracked_company(
+    db: DatabaseClient,
+    analyzer: CompanyAnalyzer,
+    name: str,
+    old_count: int,
+) -> None:
+    """Re-analyze one company, update its record, and alert on a large change."""
+    # NOTE(review): this calls a private CompanyAnalyzer method; confirm
+    # there is no public equivalent that should be used instead.
+    result = analyzer._analyze_company_safe(name)
+    if not result.success:
+        logger.warning("Scheduled analysis failed for %s: %s", name, result.error)
+        return
+
+    new_count = result.patent_count
+
+    # Update tracking record
+    db.update_tracked_company(name, new_count)
+
+    if old_count <= 0:
+        if new_count > 0:
+            # First analysis -- record baseline
+            logger.info("Baseline for %s: %d patents", name, new_count)
+        return
+
+    # Check for significant change
+    delta_pct = abs(new_count - old_count) / old_count * 100
+    if delta_pct < CHANGE_THRESHOLD_PERCENT:
+        return
+
+    direction = "increased" if new_count > old_count else "decreased"
+    message = (
+        f"Patent count for {name} {direction} by {delta_pct:.0f}% "
+        f"({old_count} -> {new_count})"
+    )
+    logger.warning("ALERT: %s", message)
+    db.store_alert(
+        company_name=name,
+        alert_type="patent_count_change",
+        message=message,
+        old_value=old_count,
+        new_value=new_count,
+    )
+
+
+def run_scheduled_analysis() -> None:
+    """Re-analyze all tracked companies and check for significant changes.
+
+    Opens a dedicated database connection for the run and always closes it,
+    including when no companies are tracked or a query raises. A failure
+    for one company is logged and does not stop the remaining companies.
+    """
+    db = DatabaseClient(config.database_url)
+    db.connect()
+    try:
+        db.initialize_schema()
+
+        tracked = db.list_tracked_companies()
+        if not tracked:
+            logger.info("No tracked companies configured; skipping scheduled analysis")
+            return
+
+        logger.info("Running scheduled analysis for %d tracked companies", len(tracked))
+
+        analyzer = CompanyAnalyzer(db_client=db)
+
+        for company_row in tracked:
+            name = company_row["company_name"]
+            old_count = company_row.get("last_patent_count", 0) or 0
+
+            try:
+                _analyze_tracked_company(db, analyzer, name, old_count)
+            except Exception as e:
+                # Keep going: one bad company must not abort the whole run.
+                logger.exception("Error analyzing tracked company %s: %s", name, e)
+    finally:
+        # Release the connection on every exit path, including early return.
+        db.close()
+
+    logger.info("Scheduled analysis complete")
+
+
+def start_scheduler() -> "BackgroundScheduler | None":
+    """Start the APScheduler background scheduler.
+
+    Safe to call at application startup. If apscheduler is not installed,
+    the function logs a warning and returns without starting anything.
+
+    The scheduler runs inside the calling process, so a deployment with
+    several worker processes runs the job once per worker.
+
+    Returns:
+        The running scheduler, so the caller can shut it down on exit, or
+        ``None`` when apscheduler is unavailable.
+    """
+    try:
+        from apscheduler.schedulers.background import BackgroundScheduler
+    except ImportError:
+        logger.warning(
+            "apscheduler not installed; scheduled analysis disabled. "
+            "Install with: pip install apscheduler"
+        )
+        return None
+
+    scheduler = BackgroundScheduler()
+    scheduler.add_job(
+        run_scheduled_analysis,
+        "interval",
+        hours=SCHEDULE_INTERVAL_HOURS,
+        id="scheduled_patent_analysis",
+        replace_existing=True,
+    )
+    scheduler.start()
+    logger.info(
+        "Scheduled patent analysis started (every %d hours, threshold %d%%)",
+        SCHEDULE_INTERVAL_HOURS,
+        CHANGE_THRESHOLD_PERCENT,
+    )
+    return scheduler
diff --git a/requirements.txt b/requirements.txt
index e854576..25affa3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ pandas
 bcrypt
 PyJWT
 slowapi
+apscheduler