"""Database client for storing and retrieving LLM messages and user authentication."""

import contextlib
import hashlib
import json
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import bcrypt
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool


class DatabaseClient:
    """Handles database operations for message storage and retrieval.

    Most methods check a connection out of a lazily created
    ``ThreadedConnectionPool`` via :meth:`get_conn`. :meth:`initialize_schema`
    still uses the legacy single connection stored in ``self.conn``.
    """

    def __init__(self, database_url: str, minconn: int = 2, maxconn: int = 10):
        """Initialize the database client.

        No connection is opened here; the pool is created on first use.

        Args:
            database_url: PostgreSQL connection string
            minconn: Minimum connections in the pool
            maxconn: Maximum connections in the pool
        """
        self.database_url = database_url
        self._pool: ThreadedConnectionPool | None = None
        self._minconn = minconn
        self._maxconn = maxconn
        # Serializes lazy pool creation/teardown so concurrent first callers
        # cannot each build a pool (all but one of which would be leaked).
        self._pool_lock = threading.Lock()
        # Legacy single connection kept for backwards compatibility
        self.conn = None

    def _ensure_pool(self) -> ThreadedConnectionPool:
        """Create the connection pool if it doesn't exist yet and return it."""
        with self._pool_lock:
            if self._pool is None or self._pool.closed:
                self._pool = ThreadedConnectionPool(
                    self._minconn, self._maxconn, self.database_url
                )
            return self._pool

    @contextlib.contextmanager
    def get_conn(self):
        """Check out a connection from the pool.

        Returns it on exit. Nothing is committed automatically; callers that
        write must call ``conn.commit()`` themselves.

        Yields:
            A psycopg2 connection borrowed from the pool.
        """
        # Keep a local reference so the connection goes back to the pool it
        # came from, even if close() swaps or shuts down self._pool meanwhile.
        pool = self._ensure_pool()
        conn = pool.getconn()
        try:
            yield conn
        finally:
            # closeall() (see close()) also closes connections that are checked
            # out, and putconn() on a closed pool raises PoolError, which would
            # mask any exception already propagating from the with-body.
            if not pool.closed:
                pool.putconn(conn)

    def connect(self):
        """Establish database connection (legacy single-connection path)."""
        if not self.conn or self.conn.closed:
            self.conn = psycopg2.connect(self.database_url)

    def close(self):
        """Close the legacy connection and the connection pool."""
        if self.conn and not self.conn.closed:
            self.conn.close()
        with self._pool_lock:
            if self._pool and not self._pool.closed:
                self._pool.closeall()

    def initialize_schema(self):
        """Create database tables, columns and indexes if they don't exist.

        All DDL runs on the legacy single connection (``self.conn``) and is
        committed once at the end. If any statement fails, the transaction is
        rolled back so ``self.conn`` is not left in an aborted state, and the
        original error is re-raised.
        """
        self.connect()
        try:
            with self.conn.cursor() as cursor:
                self._create_schema(cursor)
            self.conn.commit()
        except Exception:
            if not self.conn.closed:
                self.conn.rollback()
            raise

    @staticmethod
    def _create_schema(cursor) -> None:
        """Execute the idempotent DDL for every table this client uses, in order."""
        # Create messages table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS llm_messages (
                id SERIAL PRIMARY KEY,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                company_name VARCHAR(255),
                analysis_type VARCHAR(50),
                model VARCHAR(100),
                prompt TEXT NOT NULL,
                prompt_hash VARCHAR(64),
                response TEXT,
                metadata JSONB,
                token_usage JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                is_cached BOOLEAN DEFAULT FALSE
            )
        """)

        # Create index on timestamp for analytics queries
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_timestamp
            ON llm_messages(timestamp)
        """)

        # Create index on company_name for filtering
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_company
            ON llm_messages(company_name)
        """)

        # Add prompt_hash and is_cached to tables created before those columns
        # existed. This must run BEFORE creating the index on prompt_hash.
        # ADD COLUMN IF NOT EXISTS (PostgreSQL 9.6+) inspects the table actually
        # being altered, unlike an information_schema lookup that is not
        # schema-qualified and could match a same-named table elsewhere.
        cursor.execute("""
            ALTER TABLE llm_messages
                ADD COLUMN IF NOT EXISTS prompt_hash VARCHAR(64),
                ADD COLUMN IF NOT EXISTS is_cached BOOLEAN DEFAULT FALSE
        """)

        # Create index on prompt_hash for cache lookups
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_prompt_hash
            ON llm_messages(prompt_hash)
        """)

        # Create users table for authentication
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id SERIAL PRIMARY KEY,
                email VARCHAR(255) UNIQUE NOT NULL,
                password_hash VARCHAR(255) NOT NULL,
                role VARCHAR(20) DEFAULT 'user' CHECK (role IN ('admin', 'user')),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Create index on email for fast lookups.
        # NOTE(review): likely redundant with the index backing the UNIQUE
        # constraint on email — confirm before dropping.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_users_email
            ON users(email)
        """)

        # Create patents cache table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS patents (
                patent_id VARCHAR(64) PRIMARY KEY,
                company_name VARCHAR(255),
                pdf_link TEXT,
                raw_sections JSONB,
                minimized_content TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_patents_company
            ON patents(company_name)
        """)

        # Create SERP query cache table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS serp_queries (
                id SERIAL PRIMARY KEY,
                company_name VARCHAR(255),
                query_hash VARCHAR(64) UNIQUE,
                result_patent_ids TEXT[],
                expires_at TIMESTAMP NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_serp_queries_hash
            ON serp_queries(query_hash)
        """)

        # Create jobs table for persisting async batch job state
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS jobs (
                job_id VARCHAR(128) PRIMARY KEY,
                status VARCHAR(20) NOT NULL DEFAULT 'pending',
                progress INTEGER NOT NULL DEFAULT 0,
                total_companies INTEGER NOT NULL DEFAULT 0,
                completed_companies INTEGER NOT NULL DEFAULT 0,
                result_json JSONB,
                error TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_jobs_status
            ON jobs(status)
        """)

        # Create tracked companies table for scheduled analysis
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS tracked_companies (
                id SERIAL PRIMARY KEY,
                company_name VARCHAR(255) UNIQUE NOT NULL,
                last_patent_count INTEGER DEFAULT 0,
                last_analysis_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Create alerts table for significant changes
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS alerts (
                id SERIAL PRIMARY KEY,
                company_name VARCHAR(255) NOT NULL,
                alert_type VARCHAR(50) NOT NULL,
                message TEXT NOT NULL,
                old_value NUMERIC,
                new_value NUMERIC,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_alerts_company
            ON alerts(company_name)
        """)

    @staticmethod
    def hash_prompt(prompt: str) -> str:
        """Generate a hash of the prompt for cache lookups.

        Args:
            prompt: The prompt text to hash

        Returns:
            SHA-256 hex digest of the UTF-8 encoded prompt
        """
        return hashlib.sha256(prompt.encode()).hexdigest()

    def get_cached_response(
        self,
        prompt: str,
        company_name: Optional[str] = None,
        analysis_type: Optional[str] = None,
    ) -> Optional[Dict]:
        """Look up a cached response for a given prompt.

        Rows with no response, or whose response starts with one of the
        placeholder markers ``[DATABASE MODE]``, ``[TEST MODE]`` or
        ``[NO API]``, are never returned.

        Args:
            prompt: The prompt to look up
            company_name: Optional company name filter (ignored if empty)
            analysis_type: Optional analysis type filter (ignored if empty)

        Returns:
            The most recent matching message dict if found, None otherwise
        """
        prompt_hash = self.hash_prompt(prompt)

        # '%%' is a literal '%' once psycopg2 substitutes the parameters.
        query = """
            SELECT * FROM llm_messages
            WHERE prompt_hash = %s
            AND response IS NOT NULL
            AND response NOT LIKE '[DATABASE MODE]%%'
            AND response NOT LIKE '[TEST MODE]%%'
            AND response NOT LIKE '[NO API]%%'
        """
        params = [prompt_hash]

        if company_name:
            query += " AND company_name = %s"
            params.append(company_name)

        if analysis_type:
            query += " AND analysis_type = %s"
            params.append(analysis_type)

        query += " ORDER BY timestamp DESC LIMIT 1"

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, params)
                result = cursor.fetchone()
                return dict(result) if result else None

    def store_message(
        self,
        prompt: str,
        response: str,
        company_name: Optional[str] = None,
        analysis_type: Optional[str] = None,
        model: Optional[str] = None,
        metadata: Optional[Dict] = None,
        token_usage: Optional[Dict] = None,
        is_cached: bool = False,
    ) -> int:
        """Store an LLM message exchange in the database.

        Args:
            prompt: The prompt sent to the LLM
            response: The response from the LLM
            company_name: Name of company being analyzed
            analysis_type: Type of analysis (e.g., 'single_patent', 'portfolio')
            model: Model identifier used
            metadata: Additional metadata as dict (empty/None stored as NULL)
            token_usage: Token usage information (empty/None stored as NULL)
            is_cached: Whether this response was served from cache

        Returns:
            The ID of the inserted record
        """
        prompt_hash = self.hash_prompt(prompt)
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO llm_messages
                    (prompt, prompt_hash, response, company_name, analysis_type,
                     model, metadata, token_usage, is_cached)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                    """,
                    (
                        prompt,
                        prompt_hash,
                        response,
                        company_name,
                        analysis_type,
                        model,
                        json.dumps(metadata) if metadata else None,
                        json.dumps(token_usage) if token_usage else None,
                        is_cached,
                    ),
                )
                message_id = cursor.fetchone()[0]
            conn.commit()
            return message_id

    def get_messages(
        self,
        company_name: Optional[str] = None,
        analysis_type: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """Retrieve messages from the database, newest first.

        Args:
            company_name: Filter by company name (ignored if empty)
            analysis_type: Filter by analysis type (ignored if empty)
            limit: Maximum number of records to return
            offset: Number of records to skip

        Returns:
            List of message dictionaries
        """
        query = "SELECT * FROM llm_messages WHERE 1=1"
        params = []

        if company_name:
            query += " AND company_name = %s"
            params.append(company_name)

        if analysis_type:
            query += " AND analysis_type = %s"
            params.append(analysis_type)

        query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
        params.extend([limit, offset])

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, params)
                return [dict(row) for row in cursor.fetchall()]

    def get_analytics(self, days: int = 30) -> Dict:
        """Get analytics on message usage.

        Args:
            days: Number of days to look back

        Returns:
            Dictionary with ``total_messages``, ``by_company`` (top 10),
            ``by_type`` and ``period_days``
        """
        # Send the window as a real interval parameter (psycopg2 adapts
        # timedelta to interval). The previous INTERVAL '%s days' placed a
        # placeholder inside a quoted literal, which only worked for ints.
        lookback = timedelta(days=days)
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                # Total messages
                cursor.execute(
                    """
                    SELECT COUNT(*) as total_messages
                    FROM llm_messages
                    WHERE timestamp >= NOW() - %s
                    """,
                    (lookback,),
                )
                total = cursor.fetchone()["total_messages"]

                # Messages by company
                cursor.execute(
                    """
                    SELECT company_name, COUNT(*) as count
                    FROM llm_messages
                    WHERE timestamp >= NOW() - %s
                    GROUP BY company_name
                    ORDER BY count DESC
                    LIMIT 10
                    """,
                    (lookback,),
                )
                by_company = cursor.fetchall()

                # Messages by type
                cursor.execute(
                    """
                    SELECT analysis_type, COUNT(*) as count
                    FROM llm_messages
                    WHERE timestamp >= NOW() - %s
                    GROUP BY analysis_type
                    ORDER BY count DESC
                    """,
                    (lookback,),
                )
                by_type = cursor.fetchall()

        return {
            "total_messages": total,
            "by_company": [dict(row) for row in by_company],
            "by_type": [dict(row) for row in by_type],
            "period_days": days,
        }

    # Patent Cache Methods

    def get_cached_patent(self, patent_id: str) -> Optional[Dict]:
        """Look up a cached patent by ID.

        Returns:
            Dict of the full ``patents`` row (including raw_sections and
            minimized_content), or None.
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT * FROM patents WHERE patent_id = %s",
                    (patent_id,),
                )
                row = cursor.fetchone()
                return dict(row) if row else None

    def store_patent(
        self,
        patent_id: str,
        company_name: str,
        pdf_link: str,
        raw_sections: Dict,
        minimized_content: str,
    ) -> None:
        """Store a processed patent in the cache.

        On an existing patent_id only raw_sections and minimized_content are
        overwritten; company_name and pdf_link keep their first-stored values.
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (patent_id) DO UPDATE SET
                        raw_sections = EXCLUDED.raw_sections,
                        minimized_content = EXCLUDED.minimized_content
                    """,
                    (patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content),
                )
            conn.commit()

    def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
        """Look up cached SERP query results.

        Returns:
            List of patent IDs if cache hit and not expired, None otherwise.
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    SELECT result_patent_ids FROM serp_queries
                    WHERE query_hash = %s AND expires_at > NOW()
                    """,
                    (query_hash,),
                )
                row = cursor.fetchone()
                return row["result_patent_ids"] if row else None

    def store_serp_query(
        self,
        company_name: str,
        query_hash: str,
        patent_ids: List[str],
        ttl_hours: int = 24,
    ) -> None:
        """Store SERP query results in the cache (upsert on query_hash).

        Args:
            company_name: Company the query was issued for
            query_hash: Cache key for the query
            patent_ids: Patent IDs returned by the query
            ttl_hours: Hours until the cached entry expires
        """
        # Compute the expiry on the database clock: get_cached_serp_query
        # compares against the DB's NOW(), so using the app server's local
        # datetime.now() would skew the TTL by any clock/timezone difference.
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO serp_queries (company_name, query_hash, result_patent_ids, expires_at)
                    VALUES (%s, %s, %s, NOW() + %s)
                    ON CONFLICT (query_hash) DO UPDATE SET
                        result_patent_ids = EXCLUDED.result_patent_ids,
                        expires_at = EXCLUDED.expires_at
                    """,
                    (company_name, query_hash, patent_ids, timedelta(hours=ttl_hours)),
                )
            conn.commit()

    # Job Persistence Methods

    def create_job(
        self,
        job_id: str,
        total_companies: int,
    ) -> Dict:
        """Create a new job record in 'pending' state with zero progress.

        Args:
            job_id: Unique job identifier
            total_companies: Number of companies in the batch

        Returns:
            Job dict
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
                    VALUES (%s, 'pending', 0, %s, 0)
                    RETURNING *
                    """,
                    (job_id, total_companies),
                )
                job = cursor.fetchone()
            conn.commit()
            return dict(job)

    def update_job(
        self,
        job_id: str,
        status: Optional[str] = None,
        progress: Optional[int] = None,
        completed_companies: Optional[int] = None,
        result_json: Optional[str] = None,
        error: Optional[str] = None,
    ) -> Optional[Dict]:
        """Update a job's state. Only non-None fields are updated.

        Returns:
            The updated job dict, or None if no job has that ID. When no
            fields are given, the current job is returned unchanged.
        """
        updates = []
        params = []

        if status is not None:
            updates.append("status = %s")
            params.append(status)
        if progress is not None:
            updates.append("progress = %s")
            params.append(progress)
        if completed_companies is not None:
            updates.append("completed_companies = %s")
            params.append(completed_companies)
        if result_json is not None:
            updates.append("result_json = %s")
            params.append(result_json)
        if error is not None:
            updates.append("error = %s")
            params.append(error)

        if not updates:
            return self.get_job(job_id)

        updates.append("updated_at = CURRENT_TIMESTAMP")
        params.append(job_id)

        # Only the fixed column fragments above are interpolated into the SQL;
        # every value is passed as a parameter.
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    f"UPDATE jobs SET {', '.join(updates)} WHERE job_id = %s RETURNING *",
                    params,
                )
                job = cursor.fetchone()
            conn.commit()
            return dict(job) if job else None

    def get_job(self, job_id: str) -> Optional[Dict]:
        """Get a job by ID."""
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
                job = cursor.fetchone()
                return dict(job) if job else None

    def list_jobs(
        self,
        status: Optional[str] = None,
        limit: int = 10,
        cursor: Optional[str] = None,
    ) -> List[Dict]:
        """List jobs with optional status filter and cursor-based pagination.

        Args:
            status: Optional status filter (pending, running, completed, failed).
            limit: Maximum number of jobs to return.
            cursor: Opaque cursor (``created_at|job_id``) from a previous
                response. When provided, only jobs older than the cursor are
                returned. A malformed cursor is ignored (listing starts from
                the newest job).

        Returns:
            List of job dicts ordered by created_at descending.
        """
        conditions: list[str] = []
        params: list = []

        if status:
            conditions.append("status = %s")
            params.append(status)

        if cursor:
            try:
                ts_str, cursor_job_id = cursor.rsplit("|", 1)
                # Validate the timestamp here: an unparseable value would
                # otherwise reach PostgreSQL and raise instead of being ignored.
                datetime.fromisoformat(ts_str)
            except ValueError:
                pass  # Ignore malformed cursors; return from start
            else:
                conditions.append("(created_at, job_id) < (%s, %s)")
                params.extend([ts_str, cursor_job_id])

        query = "SELECT * FROM jobs"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY created_at DESC, job_id DESC LIMIT %s"
        params.append(limit)

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(query, params)
                return [dict(row) for row in cur.fetchall()]

    def mark_stale_jobs_failed(self) -> int:
        """Mark any jobs in 'running' or 'pending' state as 'failed'.

        Called at startup to clean up jobs that were interrupted by a restart.

        Returns:
            Number of jobs marked as failed.
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    UPDATE jobs
                    SET status = 'failed',
                        error = 'Interrupted by server restart',
                        updated_at = CURRENT_TIMESTAMP
                    WHERE status IN ('running', 'pending')
                    """
                )
                count = cursor.rowcount
            conn.commit()
            return count

    # User Authentication Methods

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a password using bcrypt with a freshly generated salt.

        Args:
            password: Plain text password

        Returns:
            Hashed password string
        """
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()

    @staticmethod
    def verify_password(password: str, password_hash: str) -> bool:
        """Verify a password against its bcrypt hash.

        Args:
            password: Plain text password
            password_hash: Stored hash

        Returns:
            True if password matches
        """
        return bcrypt.checkpw(password.encode(), password_hash.encode())

    def create_user(
        self,
        email: str,
        password: str,
        role: str = "user",
    ) -> Optional[Dict]:
        """Create a new user.

        Args:
            email: User email
            password: Plain text password
            role: User role ('admin' or 'user')

        Returns:
            Created user dict (id, email, role, created_at) or None if the
            email already exists
        """
        password_hash = self.hash_password(password)
        try:
            with self.get_conn() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                    cursor.execute(
                        """
                        INSERT INTO users (email, password_hash, role)
                        VALUES (%s, %s, %s)
                        RETURNING id, email, role, created_at
                        """,
                        (email, password_hash, role),
                    )
                    user = cursor.fetchone()
                conn.commit()
                return dict(user) if user else None
        except psycopg2.errors.UniqueViolation:
            return None

    def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
        """Authenticate a user by email and password.

        Args:
            email: User email
            password: Plain text password

        Returns:
            User dict (id, email, role, created_at) if authenticated,
            None otherwise
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT id, email, role, created_at, password_hash FROM users WHERE email = %s",
                    (email,),
                )
                user = cursor.fetchone()

        # bcrypt verification is deliberately slow; run it after the pooled
        # connection has been returned so logins don't tie up the pool.
        if user is None or not self.verify_password(password, user["password_hash"]):
            return None
        return {
            "id": user["id"],
            "email": user["email"],
            "role": user["role"],
            "created_at": user["created_at"],
        }

    def get_user_by_id(self, user_id: int) -> Optional[Dict]:
        """Get a user by ID.

        Args:
            user_id: User ID

        Returns:
            User dict (id, email, role, created_at) or None
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT id, email, role, created_at FROM users WHERE id = %s",
                    (user_id,),
                )
                user = cursor.fetchone()
                return dict(user) if user else None

    def get_user_by_email(self, email: str) -> Optional[Dict]:
        """Get a user by email.

        Args:
            email: User email

        Returns:
            User dict (id, email, role, created_at) or None
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT id, email, role, created_at FROM users WHERE email = %s",
                    (email,),
                )
                user = cursor.fetchone()
                return dict(user) if user else None

    def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
        """Get all users, newest first (admin only).

        Args:
            limit: Maximum number of users
            offset: Offset for pagination

        Returns:
            List of user dicts
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    SELECT id, email, role, created_at
                    FROM users
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                    """,
                    (limit, offset),
                )
                return [dict(row) for row in cursor.fetchall()]

    def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
        """Update a user's role (admin only).

        Args:
            user_id: User ID
            role: New role ('admin' or 'user')

        Returns:
            Updated user dict or None if no user has that ID
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    UPDATE users
                    SET role = %s, updated_at = CURRENT_TIMESTAMP
                    WHERE id = %s
                    RETURNING id, email, role, created_at
                    """,
                    (role, user_id),
                )
                user = cursor.fetchone()
            conn.commit()
            return dict(user) if user else None

    def delete_user(self, user_id: int) -> bool:
        """Delete a user (admin only).

        Args:
            user_id: User ID

        Returns:
            True if deleted
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
                deleted = cursor.rowcount > 0
            conn.commit()
            return deleted

    def get_user_count(self) -> int:
        """Get total user count.

        Returns:
            Number of users
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT COUNT(*) FROM users")
                return cursor.fetchone()[0]

    # Tracked Companies Methods

    def add_tracked_company(self, company_name: str) -> Optional[Dict]:
        """Add a company to the tracking list.

        Returns:
            The inserted row as a dict, or None if the insert violated an
            integrity constraint (e.g. the name is already tracked).

        NOTE(review): the UNIQUE constraint on company_name is case-sensitive,
        while remove/update match case-insensitively via LOWER().
        """
        with self.get_conn() as conn:
            try:
                with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                    cursor.execute(
                        "INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
                        (company_name,),
                    )
                    row = cursor.fetchone()
                conn.commit()
            except psycopg2.IntegrityError:
                # Expected for duplicates; other DB errors (connection loss,
                # programming errors) still propagate instead of being hidden.
                conn.rollback()
                return None
            return dict(row) if row else None

    def remove_tracked_company(self, company_name: str) -> bool:
        """Remove a company (matched case-insensitively) from the tracking list.

        Returns:
            True if at least one row was deleted
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
                    (company_name,),
                )
                removed = cursor.rowcount > 0
            conn.commit()
            return removed

    def list_tracked_companies(self) -> List[Dict]:
        """List all tracked companies, ordered by name."""
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name")
                return [dict(row) for row in cursor.fetchall()]

    def update_tracked_company(
        self, company_name: str, patent_count: int
    ) -> None:
        """Update the last analysis stats for a tracked company.

        The company is matched case-insensitively and last_analysis_at is set
        to the database's CURRENT_TIMESTAMP.
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """UPDATE tracked_companies
                       SET last_patent_count = %s, last_analysis_at = CURRENT_TIMESTAMP
                       WHERE LOWER(company_name) = LOWER(%s)""",
                    (patent_count, company_name),
                )
            conn.commit()

    def store_alert(
        self,
        company_name: str,
        alert_type: str,
        message: str,
        old_value: float | None = None,
        new_value: float | None = None,
    ) -> None:
        """Record an alert for a significant change."""
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """INSERT INTO alerts (company_name, alert_type, message, old_value, new_value)
                       VALUES (%s, %s, %s, %s, %s)""",
                    (company_name, alert_type, message, old_value, new_value),
                )
            conn.commit()

    def list_alerts(self, limit: int = 50) -> List[Dict]:
        """List recent alerts, newest first."""
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT * FROM alerts ORDER BY created_at DESC LIMIT %s",
                    (limit,),
                )
                return [dict(row) for row in cursor.fetchall()]