feat(database): add patent/serp caching tables and connection pooling

- Add patents table (patent_id PK, raw_sections JSONB, minimized_content)
- Add serp_queries table (query_hash unique, result_patent_ids, expires_at)
- Add cache methods: get/store_patent, get/store_serp_query
- Replace single connection with ThreadedConnectionPool (min=2, max=10)
- Add get_conn() context manager for thread-safe connection checkout
- Legacy single-connection path preserved for backwards compatibility
This commit is contained in:
2026-03-24 14:34:33 -04:00
parent b9bb3dc1cd
commit 3154f6b732
+146 -4
View File
@@ -1,9 +1,11 @@
"""Database client for storing and retrieving LLM messages and user authentication.""" """Database client for storing and retrieving LLM messages and user authentication."""
import contextlib
import psycopg2 import psycopg2
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extras import RealDictCursor from psycopg2.extras import RealDictCursor
from typing import Dict, List, Optional from typing import Dict, List, Optional
from datetime import datetime from datetime import datetime, timedelta
import json import json
import hashlib import hashlib
import bcrypt import bcrypt
@@ -12,24 +14,49 @@ import bcrypt
class DatabaseClient: class DatabaseClient:
"""Handles database operations for message storage and retrieval.""" """Handles database operations for message storage and retrieval."""
def __init__(self, database_url: str): def __init__(self, database_url: str, minconn: int = 2, maxconn: int = 10):
"""Initialize the database client. """Initialize the database client.
Args: Args:
database_url: PostgreSQL connection string database_url: PostgreSQL connection string
minconn: Minimum connections in the pool
maxconn: Maximum connections in the pool
""" """
self.database_url = database_url self.database_url = database_url
self._pool: ThreadedConnectionPool | None = None
self._minconn = minconn
self._maxconn = maxconn
# Legacy single connection kept for backwards compatibility
self.conn = None self.conn = None
def _ensure_pool(self):
"""Create the connection pool if it doesn't exist yet."""
if self._pool is None or self._pool.closed:
self._pool = ThreadedConnectionPool(
self._minconn, self._maxconn, self.database_url
)
@contextlib.contextmanager
def get_conn(self):
"""Check out a connection from the pool. Returns it on exit."""
self._ensure_pool()
conn = self._pool.getconn()
try:
yield conn
finally:
self._pool.putconn(conn)
def connect(self): def connect(self):
"""Establish database connection.""" """Establish database connection (legacy single-connection path)."""
if not self.conn or self.conn.closed: if not self.conn or self.conn.closed:
self.conn = psycopg2.connect(self.database_url) self.conn = psycopg2.connect(self.database_url)
def close(self): def close(self):
"""Close database connection.""" """Close database connection and pool."""
if self.conn and not self.conn.closed: if self.conn and not self.conn.closed:
self.conn.close() self.conn.close()
if self._pool and not self._pool.closed:
self._pool.closeall()
def initialize_schema(self): def initialize_schema(self):
"""Create database tables if they don't exist.""" """Create database tables if they don't exist."""
@@ -110,6 +137,40 @@ class DatabaseClient:
ON users(email) ON users(email)
""") """)
# Create patents cache table
cursor.execute("""
CREATE TABLE IF NOT EXISTS patents (
patent_id VARCHAR(64) PRIMARY KEY,
company_name VARCHAR(255),
pdf_link TEXT,
raw_sections JSONB,
minimized_content TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_patents_company
ON patents(company_name)
""")
# Create SERP query cache table
cursor.execute("""
CREATE TABLE IF NOT EXISTS serp_queries (
id SERIAL PRIMARY KEY,
company_name VARCHAR(255),
query_hash VARCHAR(64) UNIQUE,
result_patent_ids TEXT[],
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_serp_queries_hash
ON serp_queries(query_hash)
""")
self.conn.commit() self.conn.commit()
@staticmethod @staticmethod
@@ -320,6 +381,87 @@ class DatabaseClient:
"period_days": days, "period_days": days,
} }
# Patent Cache Methods
def get_cached_patent(self, patent_id: str) -> Optional[Dict]:
"""Look up a cached patent by ID.
Returns:
Dict with raw_sections and minimized_content, or None.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM patents WHERE patent_id = %s",
(patent_id,),
)
row = cursor.fetchone()
return dict(row) if row else None
def store_patent(
self,
patent_id: str,
company_name: str,
pdf_link: str,
raw_sections: Dict,
minimized_content: str,
) -> None:
"""Store a processed patent in the cache."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (patent_id) DO UPDATE SET
raw_sections = EXCLUDED.raw_sections,
minimized_content = EXCLUDED.minimized_content
""",
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content),
)
conn.commit()
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
"""Look up cached SERP query results.
Returns:
List of patent IDs if cache hit and not expired, None otherwise.
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT result_patent_ids FROM serp_queries
WHERE query_hash = %s AND expires_at > NOW()
""",
(query_hash,),
)
row = cursor.fetchone()
return row["result_patent_ids"] if row else None
def store_serp_query(
self,
company_name: str,
query_hash: str,
patent_ids: List[str],
ttl_hours: int = 24,
) -> None:
"""Store SERP query results in the cache."""
expires_at = datetime.now() + timedelta(hours=ttl_hours)
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO serp_queries (company_name, query_hash, result_patent_ids, expires_at)
VALUES (%s, %s, %s, %s)
ON CONFLICT (query_hash) DO UPDATE SET
result_patent_ids = EXCLUDED.result_patent_ids,
expires_at = EXCLUDED.expires_at
""",
(company_name, query_hash, patent_ids, expires_at),
)
conn.commit()
# User Authentication Methods # User Authentication Methods
@staticmethod @staticmethod