feat(database): add patent/serp caching tables and connection pooling
- Add patents table (patent_id PK, raw_sections JSONB, minimized_content) - Add serp_queries table (query_hash unique, result_patent_ids, expires_at) - Add cache methods: get/store_patent, get/store_serp_query - Replace single connection with ThreadedConnectionPool (min=2, max=10) - Add get_conn() context manager for thread-safe connection checkout - Legacy single-connection path preserved for backwards compatibility
This commit is contained in:
+146
-4
@@ -1,9 +1,11 @@
|
|||||||
"""Database client for storing and retrieving LLM messages and user authentication."""
|
"""Database client for storing and retrieving LLM messages and user authentication."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
from psycopg2.pool import ThreadedConnectionPool
|
||||||
from psycopg2.extras import RealDictCursor
|
from psycopg2.extras import RealDictCursor
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import bcrypt
|
import bcrypt
|
||||||
@@ -12,24 +14,49 @@ import bcrypt
|
|||||||
class DatabaseClient:
|
class DatabaseClient:
|
||||||
"""Handles database operations for message storage and retrieval."""
|
"""Handles database operations for message storage and retrieval."""
|
||||||
|
|
||||||
def __init__(self, database_url: str):
|
def __init__(self, database_url: str, minconn: int = 2, maxconn: int = 10):
|
||||||
"""Initialize the database client.
|
"""Initialize the database client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
database_url: PostgreSQL connection string
|
database_url: PostgreSQL connection string
|
||||||
|
minconn: Minimum connections in the pool
|
||||||
|
maxconn: Maximum connections in the pool
|
||||||
"""
|
"""
|
||||||
self.database_url = database_url
|
self.database_url = database_url
|
||||||
|
self._pool: ThreadedConnectionPool | None = None
|
||||||
|
self._minconn = minconn
|
||||||
|
self._maxconn = maxconn
|
||||||
|
# Legacy single connection kept for backwards compatibility
|
||||||
self.conn = None
|
self.conn = None
|
||||||
|
|
||||||
|
def _ensure_pool(self):
|
||||||
|
"""Create the connection pool if it doesn't exist yet."""
|
||||||
|
if self._pool is None or self._pool.closed:
|
||||||
|
self._pool = ThreadedConnectionPool(
|
||||||
|
self._minconn, self._maxconn, self.database_url
|
||||||
|
)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def get_conn(self):
|
||||||
|
"""Check out a connection from the pool. Returns it on exit."""
|
||||||
|
self._ensure_pool()
|
||||||
|
conn = self._pool.getconn()
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
self._pool.putconn(conn)
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
"""Establish database connection."""
|
"""Establish database connection (legacy single-connection path)."""
|
||||||
if not self.conn or self.conn.closed:
|
if not self.conn or self.conn.closed:
|
||||||
self.conn = psycopg2.connect(self.database_url)
|
self.conn = psycopg2.connect(self.database_url)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close database connection."""
|
"""Close database connection and pool."""
|
||||||
if self.conn and not self.conn.closed:
|
if self.conn and not self.conn.closed:
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
if self._pool and not self._pool.closed:
|
||||||
|
self._pool.closeall()
|
||||||
|
|
||||||
def initialize_schema(self):
|
def initialize_schema(self):
|
||||||
"""Create database tables if they don't exist."""
|
"""Create database tables if they don't exist."""
|
||||||
@@ -110,6 +137,40 @@ class DatabaseClient:
|
|||||||
ON users(email)
|
ON users(email)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Create patents cache table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS patents (
|
||||||
|
patent_id VARCHAR(64) PRIMARY KEY,
|
||||||
|
company_name VARCHAR(255),
|
||||||
|
pdf_link TEXT,
|
||||||
|
raw_sections JSONB,
|
||||||
|
minimized_content TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patents_company
|
||||||
|
ON patents(company_name)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create SERP query cache table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS serp_queries (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
company_name VARCHAR(255),
|
||||||
|
query_hash VARCHAR(64) UNIQUE,
|
||||||
|
result_patent_ids TEXT[],
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_serp_queries_hash
|
||||||
|
ON serp_queries(query_hash)
|
||||||
|
""")
|
||||||
|
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -320,6 +381,87 @@ class DatabaseClient:
|
|||||||
"period_days": days,
|
"period_days": days,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Patent Cache Methods
|
||||||
|
|
||||||
|
def get_cached_patent(self, patent_id: str) -> Optional[Dict]:
|
||||||
|
"""Look up a cached patent by ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with raw_sections and minimized_content, or None.
|
||||||
|
"""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT * FROM patents WHERE patent_id = %s",
|
||||||
|
(patent_id,),
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
def store_patent(
|
||||||
|
self,
|
||||||
|
patent_id: str,
|
||||||
|
company_name: str,
|
||||||
|
pdf_link: str,
|
||||||
|
raw_sections: Dict,
|
||||||
|
minimized_content: str,
|
||||||
|
) -> None:
|
||||||
|
"""Store a processed patent in the cache."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
|
||||||
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (patent_id) DO UPDATE SET
|
||||||
|
raw_sections = EXCLUDED.raw_sections,
|
||||||
|
minimized_content = EXCLUDED.minimized_content
|
||||||
|
""",
|
||||||
|
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
|
||||||
|
"""Look up cached SERP query results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of patent IDs if cache hit and not expired, None otherwise.
|
||||||
|
"""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT result_patent_ids FROM serp_queries
|
||||||
|
WHERE query_hash = %s AND expires_at > NOW()
|
||||||
|
""",
|
||||||
|
(query_hash,),
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return row["result_patent_ids"] if row else None
|
||||||
|
|
||||||
|
def store_serp_query(
|
||||||
|
self,
|
||||||
|
company_name: str,
|
||||||
|
query_hash: str,
|
||||||
|
patent_ids: List[str],
|
||||||
|
ttl_hours: int = 24,
|
||||||
|
) -> None:
|
||||||
|
"""Store SERP query results in the cache."""
|
||||||
|
expires_at = datetime.now() + timedelta(hours=ttl_hours)
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO serp_queries (company_name, query_hash, result_patent_ids, expires_at)
|
||||||
|
VALUES (%s, %s, %s, %s)
|
||||||
|
ON CONFLICT (query_hash) DO UPDATE SET
|
||||||
|
result_patent_ids = EXCLUDED.result_patent_ids,
|
||||||
|
expires_at = EXCLUDED.expires_at
|
||||||
|
""",
|
||||||
|
(company_name, query_hash, patent_ids, expires_at),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
# User Authentication Methods
|
# User Authentication Methods
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user