Files
SPARC/SPARC/database.py
T

938 lines
32 KiB
Python

"""Database client for storing and retrieving LLM messages and user authentication."""
import contextlib
import hashlib
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import bcrypt
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool
class DatabaseClient:
"""Handles database operations for message storage and retrieval."""
def __init__(self, database_url: str, minconn: int = 2, maxconn: int = 10):
    """Record connection settings; no connection is opened here.

    Args:
        database_url: PostgreSQL connection string.
        minconn: Minimum connections in the pool.
        maxconn: Maximum connections in the pool.
    """
    # Pool bounds are only stored here; the pool itself is built elsewhere.
    self._minconn, self._maxconn = minconn, maxconn
    self.database_url = database_url
    self._pool: ThreadedConnectionPool | None = None
    # Legacy single connection kept for backwards compatibility
    self.conn = None
def _ensure_pool(self):
    """Build the connection pool when none exists or the current one is closed."""
    current = self._pool
    if current is not None and not current.closed:
        # A usable pool is already in place.
        return
    self._pool = ThreadedConnectionPool(
        self._minconn,
        self._maxconn,
        self.database_url,
    )
@contextlib.contextmanager
def get_conn(self):
    """Lend out a pooled connection for the duration of a ``with`` block.

    Yields:
        A connection obtained from the pool. It is handed back to the pool
        when the block exits, whether normally or through an exception.
    """
    self._ensure_pool()
    borrowed = self._pool.getconn()
    try:
        yield borrowed
    finally:
        # Always return the connection, even if the caller's block raised.
        self._pool.putconn(borrowed)
def connect(self):
    """Open the legacy single connection unless a live one already exists."""
    existing = self.conn
    if existing and not existing.closed:
        return
    self.conn = psycopg2.connect(self.database_url)
def close(self):
    """Shut down the legacy connection and all pooled connections, if open."""
    legacy, pool = self.conn, self._pool
    if legacy and not legacy.closed:
        legacy.close()
    if pool and not pool.closed:
        pool.closeall()
def initialize_schema(self):
    """Create database tables and indexes if they don't exist.

    Runs on the legacy single connection (``self.conn``), not the pool.
    Covers the ``llm_messages``, ``users``, ``patents``, ``serp_queries``,
    ``jobs``, ``tracked_companies`` and ``alerts`` tables plus their indexes.
    All statements run in one transaction that is committed at the end.

    Raises:
        Exception: Whatever the driver raises for a failing statement. The
            transaction is rolled back first so the long-lived connection
            is not left in an aborted state for later calls.
    """
    ddl_statements = (
        # Create messages table
        """
        CREATE TABLE IF NOT EXISTS llm_messages (
            id SERIAL PRIMARY KEY,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            company_name VARCHAR(255),
            analysis_type VARCHAR(50),
            model VARCHAR(100),
            prompt TEXT NOT NULL,
            prompt_hash VARCHAR(64),
            response TEXT,
            metadata JSONB,
            token_usage JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            is_cached BOOLEAN DEFAULT FALSE
        )
        """,
        # Create index on timestamp for analytics queries
        """
        CREATE INDEX IF NOT EXISTS idx_messages_timestamp
        ON llm_messages(timestamp)
        """,
        # Create index on company_name for filtering
        """
        CREATE INDEX IF NOT EXISTS idx_messages_company
        ON llm_messages(company_name)
        """,
        # Add prompt_hash and is_cached columns if they don't exist (for existing tables)
        # This must run BEFORE creating the index on prompt_hash
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'llm_messages' AND column_name = 'prompt_hash'
            ) THEN
                ALTER TABLE llm_messages ADD COLUMN prompt_hash VARCHAR(64);
            END IF;
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'llm_messages' AND column_name = 'is_cached'
            ) THEN
                ALTER TABLE llm_messages ADD COLUMN is_cached BOOLEAN DEFAULT FALSE;
            END IF;
        END $$;
        """,
        # Create index on prompt_hash for cache lookups
        """
        CREATE INDEX IF NOT EXISTS idx_messages_prompt_hash
        ON llm_messages(prompt_hash)
        """,
        # Create users table for authentication
        """
        CREATE TABLE IF NOT EXISTS users (
            id SERIAL PRIMARY KEY,
            email VARCHAR(255) UNIQUE NOT NULL,
            password_hash VARCHAR(255) NOT NULL,
            role VARCHAR(20) DEFAULT 'user' CHECK (role IN ('admin', 'user')),
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # Create index on email for fast lookups
        """
        CREATE INDEX IF NOT EXISTS idx_users_email
        ON users(email)
        """,
        # Create patents cache table
        """
        CREATE TABLE IF NOT EXISTS patents (
            patent_id VARCHAR(64) PRIMARY KEY,
            company_name VARCHAR(255),
            pdf_link TEXT,
            raw_sections JSONB,
            minimized_content TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_patents_company
        ON patents(company_name)
        """,
        # Create SERP query cache table
        """
        CREATE TABLE IF NOT EXISTS serp_queries (
            id SERIAL PRIMARY KEY,
            company_name VARCHAR(255),
            query_hash VARCHAR(64) UNIQUE,
            result_patent_ids TEXT[],
            expires_at TIMESTAMP NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_serp_queries_hash
        ON serp_queries(query_hash)
        """,
        # Create jobs table for persisting async batch job state
        """
        CREATE TABLE IF NOT EXISTS jobs (
            job_id VARCHAR(128) PRIMARY KEY,
            status VARCHAR(20) NOT NULL DEFAULT 'pending',
            progress INTEGER NOT NULL DEFAULT 0,
            total_companies INTEGER NOT NULL DEFAULT 0,
            completed_companies INTEGER NOT NULL DEFAULT 0,
            result_json JSONB,
            error TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_jobs_status
        ON jobs(status)
        """,
        # Create tracked companies table for scheduled analysis
        """
        CREATE TABLE IF NOT EXISTS tracked_companies (
            id SERIAL PRIMARY KEY,
            company_name VARCHAR(255) UNIQUE NOT NULL,
            last_patent_count INTEGER DEFAULT 0,
            last_analysis_at TIMESTAMP,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # Create alerts table for significant changes
        """
        CREATE TABLE IF NOT EXISTS alerts (
            id SERIAL PRIMARY KEY,
            company_name VARCHAR(255) NOT NULL,
            alert_type VARCHAR(50) NOT NULL,
            message TEXT NOT NULL,
            old_value NUMERIC,
            new_value NUMERIC,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_alerts_company
        ON alerts(company_name)
        """,
    )

    self.connect()
    conn = self.conn
    try:
        with conn.cursor() as cursor:
            # Order matters: tables (and added columns) precede their indexes.
            for statement in ddl_statements:
                cursor.execute(statement)
        conn.commit()
    except Exception:
        # self.conn is reused across calls; without a rollback a failed
        # statement would leave it stuck in an aborted transaction.
        if not conn.closed:
            conn.rollback()
        raise
@staticmethod
def hash_prompt(prompt: str) -> str:
"""Generate a hash of the prompt for cache lookups.
Args:
prompt: The prompt text to hash
Returns:
SHA-256 hash of the prompt
"""
return hashlib.sha256(prompt.encode()).hexdigest()
def get_cached_response(
    self,
    prompt: str,
    company_name: Optional[str] = None,
    analysis_type: Optional[str] = None,
) -> Optional[Dict]:
    """Find the most recent stored response for a prompt.

    Rows are matched on the prompt's hash. Rows with no response, or whose
    response starts with one of the placeholder markers
    (``[DATABASE MODE]``, ``[TEST MODE]``, ``[NO API]``), are skipped.

    Args:
        prompt: The prompt to look up.
        company_name: Optional company name filter.
        analysis_type: Optional analysis type filter.

    Returns:
        The newest matching row as a dict, or None when nothing matches.
    """
    sql = """
        SELECT * FROM llm_messages
        WHERE prompt_hash = %s
        AND response IS NOT NULL
        AND response NOT LIKE '[DATABASE MODE]%%'
        AND response NOT LIKE '[TEST MODE]%%'
        AND response NOT LIKE '[NO API]%%'
    """
    args = [self.hash_prompt(prompt)]
    optional_filters = (
        (" AND company_name = %s", company_name),
        (" AND analysis_type = %s", analysis_type),
    )
    # Each filter is only appended when its value is truthy.
    for fragment, value in optional_filters:
        if value:
            sql += fragment
            args.append(value)
    sql += " ORDER BY timestamp DESC LIMIT 1"

    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(sql, args)
        hit = cur.fetchone()
    if not hit:
        return None
    return dict(hit)
def store_message(
    self,
    prompt: str,
    response: str,
    company_name: Optional[str] = None,
    analysis_type: Optional[str] = None,
    model: Optional[str] = None,
    metadata: Optional[Dict] = None,
    token_usage: Optional[Dict] = None,
    is_cached: bool = False,
) -> int:
    """Insert one prompt/response exchange into ``llm_messages``.

    Args:
        prompt: The prompt sent to the LLM.
        response: The response from the LLM.
        company_name: Name of company being analyzed.
        analysis_type: Type of analysis (e.g., 'single_patent', 'portfolio').
        model: Model identifier used.
        metadata: Additional metadata; serialized to JSON. A falsy value
            (None or an empty dict) is stored as NULL.
        token_usage: Token usage information; serialized the same way.
        is_cached: Whether this response was served from cache.

    Returns:
        The ID of the inserted record.
    """
    metadata_json = json.dumps(metadata) if metadata else None
    usage_json = json.dumps(token_usage) if token_usage else None
    row_values = (
        prompt,
        self.hash_prompt(prompt),
        response,
        company_name,
        analysis_type,
        model,
        metadata_json,
        usage_json,
        is_cached,
    )
    insert_sql = """
        INSERT INTO llm_messages
        (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        RETURNING id
    """
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(insert_sql, row_values)
            new_id = cur.fetchone()[0]
        conn.commit()
    return new_id
def get_messages(
    self,
    company_name: Optional[str] = None,
    analysis_type: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[Dict]:
    """Fetch stored messages, newest first, with optional filters and paging.

    Args:
        company_name: Filter by company name.
        analysis_type: Filter by analysis type.
        limit: Maximum number of records to return.
        offset: Number of records to skip.

    Returns:
        List of message dictionaries.
    """
    sql = "SELECT * FROM llm_messages WHERE 1=1"
    args = []
    optional_filters = (
        (" AND company_name = %s", company_name),
        (" AND analysis_type = %s", analysis_type),
    )
    for fragment, value in optional_filters:
        if value:
            sql += fragment
            args.append(value)
    sql += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
    args += [limit, offset]

    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(sql, args)
        rows = cur.fetchall()
    return [dict(row) for row in rows]
def get_analytics(self, days: int = 30) -> Dict:
    """Get analytics on message usage.

    Args:
        days: Number of days to look back.

    Returns:
        Dictionary with ``total_messages`` (count in the window),
        ``by_company`` (top 10 companies by message count), ``by_type``
        (counts per analysis type) and ``period_days`` (the ``days``
        argument echoed back).
    """
    # The look-back is bound as a multiplier of a fixed one-day interval.
    # A placeholder inside a quoted literal (INTERVAL '%s days') only works
    # by accident for bare ints and breaks for any value the driver quotes.
    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            # Total messages
            cursor.execute(
                """
                SELECT COUNT(*) as total_messages
                FROM llm_messages
                WHERE timestamp >= NOW() - %s * INTERVAL '1 day'
                """,
                (days,),
            )
            total = cursor.fetchone()["total_messages"]

            # Messages by company
            cursor.execute(
                """
                SELECT company_name, COUNT(*) as count
                FROM llm_messages
                WHERE timestamp >= NOW() - %s * INTERVAL '1 day'
                GROUP BY company_name
                ORDER BY count DESC
                LIMIT 10
                """,
                (days,),
            )
            by_company = cursor.fetchall()

            # Messages by type
            cursor.execute(
                """
                SELECT analysis_type, COUNT(*) as count
                FROM llm_messages
                WHERE timestamp >= NOW() - %s * INTERVAL '1 day'
                GROUP BY analysis_type
                ORDER BY count DESC
                """,
                (days,),
            )
            by_type = cursor.fetchall()

    return {
        "total_messages": total,
        "by_company": [dict(row) for row in by_company],
        "by_type": [dict(row) for row in by_type],
        "period_days": days,
    }
# Patent Cache Methods
def get_cached_patent(self, patent_id: str) -> Optional[Dict]:
    """Fetch the cached row for a patent, if one exists.

    Args:
        patent_id: Primary key of the row in ``patents``.

    Returns:
        The full row as a dict, or None when the patent is not cached.
    """
    lookup_sql = "SELECT * FROM patents WHERE patent_id = %s"
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(lookup_sql, (patent_id,))
        found = cur.fetchone()
    if not found:
        return None
    return dict(found)
def store_patent(
    self,
    patent_id: str,
    company_name: str,
    pdf_link: str,
    raw_sections: Dict,
    minimized_content: str,
) -> None:
    """Insert a processed patent into the cache, or refresh an existing row.

    When ``patent_id`` already exists, only ``raw_sections`` and
    ``minimized_content`` are overwritten.

    Args:
        patent_id: Primary key for the cached patent.
        company_name: Company the patent is stored under.
        pdf_link: Link to the patent PDF.
        raw_sections: Section data; serialized to JSON before storage.
        minimized_content: Condensed text of the patent.
    """
    upsert_sql = """
        INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
        VALUES (%s, %s, %s, %s, %s)
        ON CONFLICT (patent_id) DO UPDATE SET
            raw_sections = EXCLUDED.raw_sections,
            minimized_content = EXCLUDED.minimized_content
    """
    sections_json = json.dumps(raw_sections)
    row_values = (patent_id, company_name, pdf_link, sections_json, minimized_content)
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(upsert_sql, row_values)
        conn.commit()
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
    """Return the cached patent IDs for a SERP query that has not expired.

    Args:
        query_hash: Hash identifying the cached query.

    Returns:
        List of patent IDs on a cache hit whose ``expires_at`` is still in
        the future according to the database clock; None otherwise.
    """
    lookup_sql = """
        SELECT result_patent_ids FROM serp_queries
        WHERE query_hash = %s AND expires_at > NOW()
    """
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(lookup_sql, (query_hash,))
        hit = cur.fetchone()
    if not hit:
        return None
    return hit["result_patent_ids"]
def store_serp_query(
    self,
    company_name: str,
    query_hash: str,
    patent_ids: List[str],
    ttl_hours: int = 24,
) -> None:
    """Store SERP query results in the cache.

    An existing row with the same ``query_hash`` has its patent IDs and
    expiry replaced.

    Args:
        company_name: Company the query was run for.
        query_hash: Hash identifying the query (unique key of the cache).
        patent_ids: Patent IDs returned by the query.
        ttl_hours: Hours from now until the entry expires.
    """
    # Expiry is computed by the database (NOW() + ttl), not from the
    # application's local clock. get_cached_serp_query compares expires_at
    # against the database's NOW(), so using the same clock on both sides
    # keeps the TTL exact even if the app and DB hosts differ in clock or
    # time zone.
    with self.get_conn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                INSERT INTO serp_queries (company_name, query_hash, result_patent_ids, expires_at)
                VALUES (%s, %s, %s, NOW() + %s * INTERVAL '1 hour')
                ON CONFLICT (query_hash) DO UPDATE SET
                    result_patent_ids = EXCLUDED.result_patent_ids,
                    expires_at = EXCLUDED.expires_at
                """,
                (company_name, query_hash, patent_ids, ttl_hours),
            )
        conn.commit()
# Job Persistence Methods
def create_job(
    self,
    job_id: str,
    total_companies: int,
) -> Dict:
    """Insert a fresh job row in the 'pending' state with zeroed counters.

    Args:
        job_id: Unique job identifier.
        total_companies: Number of companies in the batch.

    Returns:
        The inserted row as a dict.
    """
    insert_sql = """
        INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
        VALUES (%s, 'pending', 0, %s, 0)
        RETURNING *
    """
    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(insert_sql, (job_id, total_companies))
            created = cur.fetchone()
        conn.commit()
    return dict(created)
def update_job(
    self,
    job_id: str,
    status: Optional[str] = None,
    progress: Optional[int] = None,
    completed_companies: Optional[int] = None,
    result_json: Optional[str] = None,
    error: Optional[str] = None,
) -> Optional[Dict]:
    """Apply a partial update to a job row.

    Only arguments that are not None are written. ``updated_at`` is bumped
    whenever at least one field is written. With nothing to write, the
    current row is fetched and returned unchanged.

    Args:
        job_id: Job to update.
        status: New status, if any.
        progress: New progress value, if any.
        completed_companies: New completed count, if any.
        result_json: New result payload, if any.
        error: New error text, if any.

    Returns:
        The row after the update as a dict, or None when no such job exists.
    """
    candidates = (
        ("status", status),
        ("progress", progress),
        ("completed_companies", completed_companies),
        ("result_json", result_json),
        ("error", error),
    )
    supplied = [(column, value) for column, value in candidates if value is not None]
    if not supplied:
        return self.get_job(job_id)

    # Column names come from the fixed tuple above, never from callers.
    assignments = [f"{column} = %s" for column, _ in supplied]
    assignments.append("updated_at = CURRENT_TIMESTAMP")
    values = [value for _, value in supplied]
    values.append(job_id)
    sql = f"UPDATE jobs SET {', '.join(assignments)} WHERE job_id = %s RETURNING *"

    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(sql, values)
            updated = cur.fetchone()
        conn.commit()
    if not updated:
        return None
    return dict(updated)
def get_job(self, job_id: str) -> Optional[Dict]:
    """Fetch a single job row by its ID, or None when it does not exist."""
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
        found = cur.fetchone()
    if not found:
        return None
    return dict(found)
def list_jobs(
    self,
    status: Optional[str] = None,
    limit: int = 10,
    cursor: Optional[str] = None,
) -> List[Dict]:
    """List jobs with optional status filter and cursor-based pagination.

    Args:
        status: Optional status filter (pending, running, completed, failed).
        limit: Maximum number of jobs to return.
        cursor: Opaque cursor (``created_at|job_id``) from a previous
            response. When provided, only jobs older than the cursor are
            returned. A cursor without a ``|`` separator or with an empty
            timestamp part is treated as malformed and ignored, so
            listing restarts from the newest job.

    Returns:
        List of job dicts ordered by created_at descending.
    """
    conditions: list[str] = []
    params: list = []
    if status:
        conditions.append("status = %s")
        params.append(status)
    if cursor:
        # Split on the LAST "|" so the timestamp part is everything before it.
        ts_str, separator, cursor_job_id = cursor.rpartition("|")
        # An empty timestamp would reach the database as '' and make the
        # query fail, so it is rejected here like a missing separator.
        if separator and ts_str:
            conditions.append("(created_at, job_id) < (%s, %s)")
            params.extend([ts_str, cursor_job_id])

    query = "SELECT * FROM jobs"
    if conditions:
        query += " WHERE " + " AND ".join(conditions)
    query += " ORDER BY created_at DESC, job_id DESC LIMIT %s"
    params.append(limit)

    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(query, params)
            return [dict(row) for row in cur.fetchall()]
def mark_stale_jobs_failed(self) -> int:
    """Flip every 'running' or 'pending' job to 'failed'.

    Called at startup to clean up jobs that were interrupted by a restart.
    Affected rows get the error text 'Interrupted by server restart' and a
    fresh ``updated_at``.

    Returns:
        Number of jobs marked as failed.
    """
    sweep_sql = """
        UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart',
        updated_at = CURRENT_TIMESTAMP
        WHERE status IN ('running', 'pending')
    """
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sweep_sql)
            affected = cur.rowcount
        conn.commit()
    return affected
# User Authentication Methods
@staticmethod
def hash_password(password: str) -> str:
    """Produce a bcrypt hash for a plain-text password.

    Args:
        password: Plain text password.

    Returns:
        The bcrypt hash decoded to a string; a new salt is generated per call.
    """
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(password.encode(), salt)
    return hashed.decode()
@staticmethod
def verify_password(password: str, password_hash: str) -> bool:
    """Check a plain-text password against a stored bcrypt hash.

    Args:
        password: Plain text password.
        password_hash: Stored hash.

    Returns:
        True if password matches.
    """
    candidate = password.encode()
    stored = password_hash.encode()
    return bcrypt.checkpw(candidate, stored)
def create_user(
    self,
    email: str,
    password: str,
    role: str = "user",
) -> Optional[Dict]:
    """Insert a new user with a hashed password.

    Args:
        email: User email.
        password: Plain text password; only its hash is stored.
        role: User role ('admin' or 'user').

    Returns:
        Dict with id, email, role and created_at of the new user, or None
        when the insert hits a unique-constraint violation (email exists).
    """
    hashed = self.hash_password(password)
    insert_sql = """
        INSERT INTO users (email, password_hash, role)
        VALUES (%s, %s, %s)
        RETURNING id, email, role, created_at
    """
    try:
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(insert_sql, (email, hashed, role))
                created = cur.fetchone()
            conn.commit()
    except psycopg2.errors.UniqueViolation:
        return None
    if not created:
        return None
    return dict(created)
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
    """Check an email/password pair against the users table.

    Args:
        email: User email.
        password: Plain text password.

    Returns:
        Dict with id, email, role and created_at when the user exists and
        the password verifies; None otherwise. The password hash is never
        included in the result.
    """
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(
            "SELECT * FROM users WHERE email = %s",
            (email,),
        )
        record = cur.fetchone()

    if not record:
        return None
    if not self.verify_password(password, record["password_hash"]):
        return None
    public_fields = ("id", "email", "role", "created_at")
    return {field: record[field] for field in public_fields}
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
    """Look up a user's public fields by primary key.

    Args:
        user_id: User ID.

    Returns:
        Dict with id, email, role and created_at, or None if not found.
    """
    lookup_sql = "SELECT id, email, role, created_at FROM users WHERE id = %s"
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(lookup_sql, (user_id,))
        found = cur.fetchone()
    if not found:
        return None
    return dict(found)
def get_user_by_email(self, email: str) -> Optional[Dict]:
    """Look up a user's public fields by email address.

    Args:
        email: User email.

    Returns:
        Dict with id, email, role and created_at, or None if not found.
    """
    lookup_sql = "SELECT id, email, role, created_at FROM users WHERE email = %s"
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(lookup_sql, (email,))
        found = cur.fetchone()
    if not found:
        return None
    return dict(found)
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
    """Return one page of users, newest first (admin only).

    Args:
        limit: Maximum number of users.
        offset: Offset for pagination.

    Returns:
        List of dicts with id, email, role and created_at.
    """
    page_sql = """
        SELECT id, email, role, created_at
        FROM users
        ORDER BY created_at DESC
        LIMIT %s OFFSET %s
    """
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(page_sql, (limit, offset))
        rows = cur.fetchall()
    return [dict(row) for row in rows]
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
    """Change a user's role and refresh ``updated_at`` (admin only).

    Args:
        user_id: User ID.
        role: New role ('admin' or 'user').

    Returns:
        Dict with id, email, role and created_at after the change, or None
        when no user has that ID.
    """
    update_sql = """
        UPDATE users
        SET role = %s, updated_at = CURRENT_TIMESTAMP
        WHERE id = %s
        RETURNING id, email, role, created_at
    """
    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(update_sql, (role, user_id))
            updated = cur.fetchone()
        conn.commit()
    if not updated:
        return None
    return dict(updated)
def delete_user(self, user_id: int) -> bool:
    """Remove a user row by ID (admin only).

    Args:
        user_id: User ID.

    Returns:
        True if at least one row was deleted.
    """
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("DELETE FROM users WHERE id = %s", (user_id,))
            affected = cur.rowcount
        conn.commit()
    return affected > 0
def get_user_count(self) -> int:
    """Count the rows in the users table.

    Returns:
        Number of users.
    """
    with self.get_conn() as conn, conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM users")
        first_row = cur.fetchone()
    return first_row[0]
# Tracked Companies Methods
def add_tracked_company(self, company_name: str) -> Optional[Dict]:
    """Add a company to the tracking list.

    Args:
        company_name: Name to track; must satisfy the table's UNIQUE and
            NOT NULL constraints.

    Returns:
        The inserted row as a dict, or None when the insert violates an
        integrity constraint (for example the name is already tracked).

    Raises:
        psycopg2.Error: For failures other than integrity violations
            (connectivity, missing table, ...). These are no longer
            silently reported as None.
    """
    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            try:
                cursor.execute(
                    "INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
                    (company_name,),
                )
                row = cursor.fetchone()
                conn.commit()
                return dict(row) if row else None
            except psycopg2.IntegrityError:
                # Only constraint violations mean "could not add"; anything
                # else is a real failure and must reach the caller.
                conn.rollback()
                return None
def remove_tracked_company(self, company_name: str) -> bool:
    """Delete a company from the tracking list, matching case-insensitively.

    Args:
        company_name: Name to remove; compared via LOWER() on both sides.

    Returns:
        True if at least one row was deleted.
    """
    delete_sql = "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)"
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(delete_sql, (company_name,))
            conn.commit()
            affected = cur.rowcount
    return affected > 0
def list_tracked_companies(self) -> List[Dict]:
    """Return every tracked company row, ordered by company name."""
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute("SELECT * FROM tracked_companies ORDER BY company_name")
        rows = cur.fetchall()
    return [dict(row) for row in rows]
def update_tracked_company(
    self, company_name: str, patent_count: int
) -> None:
    """Record the latest patent count and analysis time for a tracked company.

    Args:
        company_name: Company to update; matched case-insensitively.
        patent_count: Value written to ``last_patent_count``.
    """
    update_sql = """UPDATE tracked_companies
        SET last_patent_count = %s, last_analysis_at = CURRENT_TIMESTAMP
        WHERE LOWER(company_name) = LOWER(%s)"""
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(update_sql, (patent_count, company_name))
        conn.commit()
def store_alert(
    self,
    company_name: str,
    alert_type: str,
    message: str,
    old_value: float | None = None,
    new_value: float | None = None,
) -> None:
    """Insert one alert row describing a significant change.

    Args:
        company_name: Company the alert concerns.
        alert_type: Category label for the alert.
        message: Alert text.
        old_value: Previous numeric value, if any.
        new_value: New numeric value, if any.
    """
    insert_sql = """INSERT INTO alerts (company_name, alert_type, message, old_value, new_value)
        VALUES (%s, %s, %s, %s, %s)"""
    row_values = (company_name, alert_type, message, old_value, new_value)
    with self.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(insert_sql, row_values)
        conn.commit()
def list_alerts(self, limit: int = 50) -> List[Dict]:
    """Return the most recent alerts, newest first.

    Args:
        limit: Maximum number of alerts to return.

    Returns:
        List of alert rows as dicts.
    """
    recent_sql = "SELECT * FROM alerts ORDER BY created_at DESC LIMIT %s"
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(recent_sql, (limit,))
        rows = cur.fetchall()
    return [dict(row) for row in rows]