forked from 0xWheatyz/SPARC
fbb72fe2a5
- Add test job to build.yaml that runs pytest and ruff before building images - Add standalone test.yaml workflow for PRs - Add ruff.toml with E/F/I rules configured - Fix all ruff lint errors: sort imports, remove unused imports, fix re-exports - Build jobs now depend on test job passing (needs: test) Closes leeworks-agents/SPARC#18 Closes leeworks-agents/SPARC#19 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
806 lines
27 KiB
Python
806 lines
27 KiB
Python
"""Database client for storing and retrieving LLM messages and user authentication."""
|
|
|
|
import contextlib
import hashlib
import json
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

import bcrypt
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool
|
|
|
|
|
|
class DatabaseClient:
    """Handles database operations for message storage and retrieval.

    Covers the LLM message log / prompt cache, the patent and SERP-query
    caches, batch-job persistence, and user accounts. Query methods borrow a
    connection from a lazily created ``ThreadedConnectionPool`` through
    :meth:`get_conn`; :meth:`initialize_schema` still runs on the legacy
    single connection held in ``self.conn``.
    """

    def __init__(self, database_url: str, minconn: int = 2, maxconn: int = 10):
        """Initialize the database client.

        No connection is opened here; the pool is created on first use.

        Args:
            database_url: PostgreSQL connection string
            minconn: Minimum connections in the pool
            maxconn: Maximum connections in the pool
        """
        self.database_url = database_url
        self._pool: ThreadedConnectionPool | None = None
        self._minconn = minconn
        self._maxconn = maxconn
        # Legacy single connection kept for backwards compatibility
        self.conn = None

    def _ensure_pool(self):
        """Create the connection pool if it doesn't exist yet or was closed."""
        # NOTE(review): creation is not lock-protected; concurrent first calls
        # from several threads could each build a pool. Confirm this is only
        # reached from one thread at startup, or add a lock.
        if self._pool is None or self._pool.closed:
            self._pool = ThreadedConnectionPool(
                self._minconn, self._maxconn, self.database_url
            )

    @contextlib.contextmanager
    def get_conn(self):
        """Check out a connection from the pool. Returns it on exit.

        The connection is handed back even if the ``with`` body raises.
        Callers are responsible for calling ``commit()`` on their writes.

        Yields:
            A psycopg2 connection borrowed from the pool.
        """
        self._ensure_pool()
        # Remember the pool the connection came from: if ``self._pool`` is
        # replaced while the connection is in use, it must still be returned
        # to its original pool (a different pool rejects unknown connections).
        pool = self._pool
        conn = pool.getconn()
        try:
            yield conn
        finally:
            pool.putconn(conn)

    def connect(self):
        """Establish database connection (legacy single-connection path)."""
        if not self.conn or self.conn.closed:
            self.conn = psycopg2.connect(self.database_url)

    def close(self):
        """Close the legacy connection and the pool, if they are open."""
        if self.conn and not self.conn.closed:
            self.conn.close()
        if self._pool and not self._pool.closed:
            self._pool.closeall()

    def initialize_schema(self):
        """Create database tables and indexes if they don't exist.

        Runs on the legacy single connection (opened via :meth:`connect`).
        All statements are committed together at the end. On any error the
        transaction is rolled back before re-raising, so ``self.conn`` is not
        left in an aborted-transaction state for later calls.
        """
        self.connect()

        try:
            with self.conn.cursor() as cursor:
                # Create messages table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS llm_messages (
                        id SERIAL PRIMARY KEY,
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        company_name VARCHAR(255),
                        analysis_type VARCHAR(50),
                        model VARCHAR(100),
                        prompt TEXT NOT NULL,
                        prompt_hash VARCHAR(64),
                        response TEXT,
                        metadata JSONB,
                        token_usage JSONB,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        is_cached BOOLEAN DEFAULT FALSE
                    )
                """)

                # Create index on timestamp for analytics queries
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_messages_timestamp
                    ON llm_messages(timestamp)
                """)

                # Create index on company_name for filtering
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_messages_company
                    ON llm_messages(company_name)
                """)

                # Add prompt_hash and is_cached columns to tables created by
                # older versions. This must run BEFORE creating the index on
                # prompt_hash. ADD COLUMN IF NOT EXISTS resolves the table the
                # same way the ALTER does (via search_path), unlike a lookup in
                # information_schema by table name alone.
                cursor.execute("""
                    ALTER TABLE llm_messages
                    ADD COLUMN IF NOT EXISTS prompt_hash VARCHAR(64)
                """)
                cursor.execute("""
                    ALTER TABLE llm_messages
                    ADD COLUMN IF NOT EXISTS is_cached BOOLEAN DEFAULT FALSE
                """)

                # Create index on prompt_hash for cache lookups
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_messages_prompt_hash
                    ON llm_messages(prompt_hash)
                """)

                # Create users table for authentication
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS users (
                        id SERIAL PRIMARY KEY,
                        email VARCHAR(255) UNIQUE NOT NULL,
                        password_hash VARCHAR(255) NOT NULL,
                        role VARCHAR(20) DEFAULT 'user' CHECK (role IN ('admin', 'user')),
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # Index on email for fast lookups.
                # NOTE(review): the UNIQUE constraint on email already creates
                # an index; this one is likely redundant.
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_users_email
                    ON users(email)
                """)

                # Create patents cache table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS patents (
                        patent_id VARCHAR(64) PRIMARY KEY,
                        company_name VARCHAR(255),
                        pdf_link TEXT,
                        raw_sections JSONB,
                        minimized_content TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_patents_company
                    ON patents(company_name)
                """)

                # Create SERP query cache table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS serp_queries (
                        id SERIAL PRIMARY KEY,
                        company_name VARCHAR(255),
                        query_hash VARCHAR(64) UNIQUE,
                        result_patent_ids TEXT[],
                        expires_at TIMESTAMP NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # NOTE(review): query_hash is UNIQUE, which already implies an
                # index; this one is likely redundant.
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_serp_queries_hash
                    ON serp_queries(query_hash)
                """)

                # Create jobs table for persisting async batch job state
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS jobs (
                        job_id VARCHAR(128) PRIMARY KEY,
                        status VARCHAR(20) NOT NULL DEFAULT 'pending',
                        progress INTEGER NOT NULL DEFAULT 0,
                        total_companies INTEGER NOT NULL DEFAULT 0,
                        completed_companies INTEGER NOT NULL DEFAULT 0,
                        result_json JSONB,
                        error TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_jobs_status
                    ON jobs(status)
                """)

            self.conn.commit()
        except Exception:
            # Leave the legacy connection usable for the next call.
            if not self.conn.closed:
                self.conn.rollback()
            raise

    @staticmethod
    def hash_prompt(prompt: str) -> str:
        """Generate a hash of the prompt for cache lookups.

        Args:
            prompt: The prompt text to hash

        Returns:
            Hex-encoded SHA-256 hash of the UTF-8 encoded prompt
        """
        return hashlib.sha256(prompt.encode()).hexdigest()

    def get_cached_response(
        self,
        prompt: str,
        company_name: Optional[str] = None,
        analysis_type: Optional[str] = None,
    ) -> Optional[Dict]:
        """Look up a cached response for a given prompt.

        Only rows with a non-NULL response are considered, and placeholder
        responses starting with "[DATABASE MODE]", "[TEST MODE]" or
        "[NO API]" are skipped. The most recent match (by ``timestamp``) wins.

        Args:
            prompt: The prompt to look up
            company_name: Optional company name filter (ignored if falsy)
            analysis_type: Optional analysis type filter (ignored if falsy)

        Returns:
            Cached message dict if found, None otherwise
        """
        prompt_hash = self.hash_prompt(prompt)

        # "%%" is a literal "%" once psycopg2 substitutes the parameters.
        query = """
            SELECT * FROM llm_messages
            WHERE prompt_hash = %s
            AND response IS NOT NULL
            AND response NOT LIKE '[DATABASE MODE]%%'
            AND response NOT LIKE '[TEST MODE]%%'
            AND response NOT LIKE '[NO API]%%'
        """
        params = [prompt_hash]

        if company_name:
            query += " AND company_name = %s"
            params.append(company_name)

        if analysis_type:
            query += " AND analysis_type = %s"
            params.append(analysis_type)

        query += " ORDER BY timestamp DESC LIMIT 1"

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, params)
                result = cursor.fetchone()
                return dict(result) if result else None

    def store_message(
        self,
        prompt: str,
        response: str,
        company_name: Optional[str] = None,
        analysis_type: Optional[str] = None,
        model: Optional[str] = None,
        metadata: Optional[Dict] = None,
        token_usage: Optional[Dict] = None,
        is_cached: bool = False,
    ) -> int:
        """Store an LLM message exchange in the database.

        Args:
            prompt: The prompt sent to the LLM
            response: The response from the LLM
            company_name: Name of company being analyzed
            analysis_type: Type of analysis (e.g., 'single_patent', 'portfolio')
            model: Model identifier used
            metadata: Additional metadata as dict (falsy values stored as NULL)
            token_usage: Token usage information (falsy values stored as NULL)
            is_cached: Whether this response was served from cache

        Returns:
            The ID of the inserted record
        """
        prompt_hash = self.hash_prompt(prompt)

        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO llm_messages
                    (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                    """,
                    (
                        prompt,
                        prompt_hash,
                        response,
                        company_name,
                        analysis_type,
                        model,
                        json.dumps(metadata) if metadata else None,
                        json.dumps(token_usage) if token_usage else None,
                        is_cached,
                    ),
                )

                message_id = cursor.fetchone()[0]
                conn.commit()

                return message_id

    def get_messages(
        self,
        company_name: Optional[str] = None,
        analysis_type: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """Retrieve messages from the database, newest first.

        Args:
            company_name: Filter by company name (ignored if falsy)
            analysis_type: Filter by analysis type (ignored if falsy)
            limit: Maximum number of records to return
            offset: Number of records to skip

        Returns:
            List of message dictionaries
        """
        query = "SELECT * FROM llm_messages WHERE 1=1"
        params: list = []

        if company_name:
            query += " AND company_name = %s"
            params.append(company_name)

        if analysis_type:
            query += " AND analysis_type = %s"
            params.append(analysis_type)

        query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
        params.extend([limit, offset])

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, params)
                return [dict(row) for row in cursor.fetchall()]

    def get_analytics(self, days: int = 30) -> Dict:
        """Get analytics on message usage.

        Args:
            days: Number of days to look back

        Returns:
            Dictionary with ``total_messages``, the top 10 companies
            (``by_company``), counts per analysis type (``by_type``) and the
            ``period_days`` that was used.
        """
        # Pass the look-back window as a real parameter (psycopg2 adapts a
        # timedelta to an SQL interval) rather than substituting a
        # placeholder inside a quoted INTERVAL literal, which only works for
        # plain numbers and produces invalid SQL for anything else.
        window = timedelta(days=days)

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                # Total messages
                cursor.execute(
                    """
                    SELECT COUNT(*) as total_messages
                    FROM llm_messages
                    WHERE timestamp >= NOW() - %s
                    """,
                    (window,),
                )
                total = cursor.fetchone()["total_messages"]

                # Messages by company
                cursor.execute(
                    """
                    SELECT company_name, COUNT(*) as count
                    FROM llm_messages
                    WHERE timestamp >= NOW() - %s
                    GROUP BY company_name
                    ORDER BY count DESC
                    LIMIT 10
                    """,
                    (window,),
                )
                by_company = cursor.fetchall()

                # Messages by type
                cursor.execute(
                    """
                    SELECT analysis_type, COUNT(*) as count
                    FROM llm_messages
                    WHERE timestamp >= NOW() - %s
                    GROUP BY analysis_type
                    ORDER BY count DESC
                    """,
                    (window,),
                )
                by_type = cursor.fetchall()

                return {
                    "total_messages": total,
                    "by_company": [dict(row) for row in by_company],
                    "by_type": [dict(row) for row in by_type],
                    "period_days": days,
                }

    # Patent Cache Methods

    def get_cached_patent(self, patent_id: str) -> Optional[Dict]:
        """Look up a cached patent by ID.

        Returns:
            The full ``patents`` row as a dict (including raw_sections and
            minimized_content), or None.
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT * FROM patents WHERE patent_id = %s",
                    (patent_id,),
                )
                row = cursor.fetchone()
                return dict(row) if row else None

    def store_patent(
        self,
        patent_id: str,
        company_name: str,
        pdf_link: str,
        raw_sections: Dict,
        minimized_content: str,
    ) -> None:
        """Store a processed patent in the cache.

        On an existing ``patent_id`` only raw_sections and minimized_content
        are overwritten; company_name and pdf_link keep their stored values.
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (patent_id) DO UPDATE SET
                        raw_sections = EXCLUDED.raw_sections,
                        minimized_content = EXCLUDED.minimized_content
                    """,
                    (patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content),
                )
                conn.commit()

    def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
        """Look up cached SERP query results.

        Returns:
            List of patent IDs if cache hit and not expired, None otherwise.
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    SELECT result_patent_ids FROM serp_queries
                    WHERE query_hash = %s AND expires_at > NOW()
                    """,
                    (query_hash,),
                )
                row = cursor.fetchone()
                return row["result_patent_ids"] if row else None

    def store_serp_query(
        self,
        company_name: str,
        query_hash: str,
        patent_ids: List[str],
        ttl_hours: int = 24,
    ) -> None:
        """Store SERP query results in the cache, upserting on query_hash.

        Args:
            company_name: Company the query was run for
            query_hash: Cache key for the query
            patent_ids: Patent IDs returned by the query
            ttl_hours: Hours until the entry is treated as expired
        """
        # Use an aware UTC timestamp. PostgreSQL converts it to the session
        # time zone when storing into the TIMESTAMP column, which is the zone
        # get_cached_serp_query's ``expires_at > NOW()`` comparison uses. A
        # naive datetime.now() would bake in the app host's local zone and
        # shift the expiry whenever host and database zones differ.
        expires_at = datetime.now(timezone.utc) + timedelta(hours=ttl_hours)
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO serp_queries (company_name, query_hash, result_patent_ids, expires_at)
                    VALUES (%s, %s, %s, %s)
                    ON CONFLICT (query_hash) DO UPDATE SET
                        result_patent_ids = EXCLUDED.result_patent_ids,
                        expires_at = EXCLUDED.expires_at
                    """,
                    (company_name, query_hash, patent_ids, expires_at),
                )
                conn.commit()

    # Job Persistence Methods

    def create_job(
        self,
        job_id: str,
        total_companies: int,
    ) -> Dict:
        """Create a new job record in the 'pending' state.

        Args:
            job_id: Unique job identifier
            total_companies: Number of companies in the batch

        Returns:
            Job dict (the inserted row)
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
                    VALUES (%s, 'pending', 0, %s, 0)
                    RETURNING *
                    """,
                    (job_id, total_companies),
                )
                job = cursor.fetchone()
                conn.commit()
                return dict(job)

    def update_job(
        self,
        job_id: str,
        status: Optional[str] = None,
        progress: Optional[int] = None,
        completed_companies: Optional[int] = None,
        result_json: Optional[str] = None,
        error: Optional[str] = None,
    ) -> Optional[Dict]:
        """Update a job's state.

        Only non-None fields are updated; ``updated_at`` is bumped whenever
        anything changes. With no fields given, the job is just re-read.

        Returns:
            The updated job dict, or None if no job has that ID.
        """
        updates = []
        params = []
        # Column names come from this fixed list, never from callers, so
        # building the SET clause with an f-string below is safe.
        for column, value in (
            ("status", status),
            ("progress", progress),
            ("completed_companies", completed_companies),
            ("result_json", result_json),
            ("error", error),
        ):
            if value is not None:
                updates.append(f"{column} = %s")
                params.append(value)

        if not updates:
            return self.get_job(job_id)

        updates.append("updated_at = CURRENT_TIMESTAMP")
        params.append(job_id)

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    f"UPDATE jobs SET {', '.join(updates)} WHERE job_id = %s RETURNING *",
                    params,
                )
                job = cursor.fetchone()
                conn.commit()
                return dict(job) if job else None

    def get_job(self, job_id: str) -> Optional[Dict]:
        """Get a job by ID, or None if it does not exist."""
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
                job = cursor.fetchone()
                return dict(job) if job else None

    def list_jobs(
        self,
        status: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict]:
        """List jobs newest first, optionally filtered by status."""
        query = "SELECT * FROM jobs"
        params: list = []
        if status:
            query += " WHERE status = %s"
            params.append(status)
        query += " ORDER BY created_at DESC LIMIT %s"
        params.append(limit)

        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, params)
                return [dict(row) for row in cursor.fetchall()]

    def mark_stale_jobs_failed(self) -> int:
        """Mark any jobs in 'running' or 'pending' state as 'failed'.

        Called at startup to clean up jobs that were interrupted by a restart.

        Returns:
            Number of jobs marked as failed.
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart',
                        updated_at = CURRENT_TIMESTAMP
                    WHERE status IN ('running', 'pending')
                    """
                )
                count = cursor.rowcount
                conn.commit()
                return count

    # User Authentication Methods

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a password using bcrypt with a freshly generated salt.

        Args:
            password: Plain text password

        Returns:
            Hashed password string
        """
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()

    @staticmethod
    def verify_password(password: str, password_hash: str) -> bool:
        """Verify a password against its bcrypt hash.

        Args:
            password: Plain text password
            password_hash: Stored hash

        Returns:
            True if password matches
        """
        return bcrypt.checkpw(password.encode(), password_hash.encode())

    def create_user(
        self,
        email: str,
        password: str,
        role: str = "user",
    ) -> Optional[Dict]:
        """Create a new user.

        Args:
            email: User email
            password: Plain text password (stored only as a bcrypt hash)
            role: User role ('admin' or 'user')

        Returns:
            Created user dict (id, email, role, created_at) or None if the
            email already exists
        """
        password_hash = self.hash_password(password)

        try:
            with self.get_conn() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                    cursor.execute(
                        """
                        INSERT INTO users (email, password_hash, role)
                        VALUES (%s, %s, %s)
                        RETURNING id, email, role, created_at
                        """,
                        (email, password_hash, role),
                    )
                    user = cursor.fetchone()
                    conn.commit()
                    return dict(user) if user else None
        except psycopg2.errors.UniqueViolation:
            return None

    def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
        """Authenticate a user by email and password.

        Args:
            email: User email
            password: Plain text password

        Returns:
            User dict (id, email, role, created_at) if authenticated,
            None otherwise
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT * FROM users WHERE email = %s",
                    (email,),
                )
                user = cursor.fetchone()

                if user and self.verify_password(password, user["password_hash"]):
                    # Never hand the password hash back to callers.
                    return {
                        "id": user["id"],
                        "email": user["email"],
                        "role": user["role"],
                        "created_at": user["created_at"],
                    }
                return None

    def get_user_by_id(self, user_id: int) -> Optional[Dict]:
        """Get a user by ID.

        Args:
            user_id: User ID

        Returns:
            User dict (id, email, role, created_at) or None
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT id, email, role, created_at FROM users WHERE id = %s",
                    (user_id,),
                )
                user = cursor.fetchone()
                return dict(user) if user else None

    def get_user_by_email(self, email: str) -> Optional[Dict]:
        """Get a user by email.

        Args:
            email: User email

        Returns:
            User dict (id, email, role, created_at) or None
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    "SELECT id, email, role, created_at FROM users WHERE email = %s",
                    (email,),
                )
                user = cursor.fetchone()
                return dict(user) if user else None

    def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
        """Get all users, newest first (admin only).

        Args:
            limit: Maximum number of users
            offset: Offset for pagination

        Returns:
            List of user dicts (id, email, role, created_at)
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    SELECT id, email, role, created_at
                    FROM users
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                    """,
                    (limit, offset),
                )
                return [dict(row) for row in cursor.fetchall()]

    def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
        """Update a user's role (admin only).

        Args:
            user_id: User ID
            role: New role ('admin' or 'user')

        Returns:
            Updated user dict or None if no user has that ID
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(
                    """
                    UPDATE users
                    SET role = %s, updated_at = CURRENT_TIMESTAMP
                    WHERE id = %s
                    RETURNING id, email, role, created_at
                    """,
                    (role, user_id),
                )
                user = cursor.fetchone()
                conn.commit()
                return dict(user) if user else None

    def delete_user(self, user_id: int) -> bool:
        """Delete a user (admin only).

        Args:
            user_id: User ID

        Returns:
            True if a row was deleted
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
                deleted = cursor.rowcount > 0
                conn.commit()
                return deleted

    def get_user_count(self) -> int:
        """Get total user count.

        Returns:
            Number of users
        """
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT COUNT(*) FROM users")
                return cursor.fetchone()[0]
|