feat(backend): add response caching and user management
Replace USE_DATABASE toggle with USE_CACHE for smarter LLM response handling: - Add prompt hashing for efficient cache lookups - Cache API responses in database to reduce token usage - Always store responses for analytics (cache or fresh) Add user authentication infrastructure: - User table with bcrypt password hashing - CRUD operations for user management - Role-based access control (admin/user) Dependencies: add bcrypt and PyJWT for auth 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
+324
-4
@@ -1,10 +1,12 @@
|
||||
"""Database client for storing and retrieving LLM messages."""
|
||||
"""Database client for storing and retrieving LLM messages and user authentication."""
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
import hashlib
|
||||
import bcrypt
|
||||
|
||||
|
||||
class DatabaseClient:
|
||||
@@ -43,10 +45,12 @@ class DatabaseClient:
|
||||
analysis_type VARCHAR(50),
|
||||
model VARCHAR(100),
|
||||
prompt TEXT NOT NULL,
|
||||
prompt_hash VARCHAR(64),
|
||||
response TEXT,
|
||||
metadata JSONB,
|
||||
token_usage JSONB,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
is_cached BOOLEAN DEFAULT FALSE
|
||||
)
|
||||
""")
|
||||
|
||||
@@ -62,8 +66,109 @@ class DatabaseClient:
|
||||
ON llm_messages(company_name)
|
||||
""")
|
||||
|
||||
# Add prompt_hash and is_cached columns if they don't exist (for existing tables)
|
||||
# This must run BEFORE creating the index on prompt_hash
|
||||
cursor.execute("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'llm_messages' AND column_name = 'prompt_hash'
|
||||
) THEN
|
||||
ALTER TABLE llm_messages ADD COLUMN prompt_hash VARCHAR(64);
|
||||
END IF;
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'llm_messages' AND column_name = 'is_cached'
|
||||
) THEN
|
||||
ALTER TABLE llm_messages ADD COLUMN is_cached BOOLEAN DEFAULT FALSE;
|
||||
END IF;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
# Create index on prompt_hash for cache lookups
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_prompt_hash
|
||||
ON llm_messages(prompt_hash)
|
||||
""")
|
||||
|
||||
# Create users table for authentication
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(20) DEFAULT 'user' CHECK (role IN ('admin', 'user')),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Create index on email for fast lookups
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email
|
||||
ON users(email)
|
||||
""")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
def hash_prompt(prompt: str) -> str:
|
||||
"""Generate a hash of the prompt for cache lookups.
|
||||
|
||||
Args:
|
||||
prompt: The prompt text to hash
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of the prompt
|
||||
"""
|
||||
return hashlib.sha256(prompt.encode()).hexdigest()
|
||||
|
||||
def get_cached_response(
|
||||
self,
|
||||
prompt: str,
|
||||
company_name: Optional[str] = None,
|
||||
analysis_type: Optional[str] = None,
|
||||
) -> Optional[Dict]:
|
||||
"""Look up a cached response for a given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to look up
|
||||
company_name: Optional company name filter
|
||||
analysis_type: Optional analysis type filter
|
||||
|
||||
Returns:
|
||||
Cached message dict if found, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
query = """
|
||||
SELECT * FROM llm_messages
|
||||
WHERE prompt_hash = %s
|
||||
AND response IS NOT NULL
|
||||
AND response NOT LIKE '[DATABASE MODE]%%'
|
||||
AND response NOT LIKE '[TEST MODE]%%'
|
||||
AND response NOT LIKE '[NO API]%%'
|
||||
"""
|
||||
params = [prompt_hash]
|
||||
|
||||
if company_name:
|
||||
query += " AND company_name = %s"
|
||||
params.append(company_name)
|
||||
|
||||
if analysis_type:
|
||||
query += " AND analysis_type = %s"
|
||||
params.append(analysis_type)
|
||||
|
||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
def store_message(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -73,6 +178,7 @@ class DatabaseClient:
|
||||
model: Optional[str] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
token_usage: Optional[Dict] = None,
|
||||
is_cached: bool = False,
|
||||
) -> int:
|
||||
"""Store an LLM message exchange in the database.
|
||||
|
||||
@@ -84,28 +190,33 @@ class DatabaseClient:
|
||||
model: Model identifier used
|
||||
metadata: Additional metadata as dict
|
||||
token_usage: Token usage information
|
||||
is_cached: Whether this response was served from cache
|
||||
|
||||
Returns:
|
||||
The ID of the inserted record
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, response, company_name, analysis_type, model, metadata, token_usage)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -208,3 +319,212 @@ class DatabaseClient:
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
# User Authentication Methods
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt.
|
||||
|
||||
Args:
|
||||
password: Plain text password
|
||||
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
@staticmethod
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a password against its hash.
|
||||
|
||||
Args:
|
||||
password: Plain text password
|
||||
password_hash: Stored hash
|
||||
|
||||
Returns:
|
||||
True if password matches
|
||||
"""
|
||||
return bcrypt.checkpw(password.encode(), password_hash.encode())
|
||||
|
||||
def create_user(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
role: str = "user",
|
||||
) -> Optional[Dict]:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
password: Plain text password
|
||||
role: User role ('admin' or 'user')
|
||||
|
||||
Returns:
|
||||
Created user dict or None if email exists
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
return dict(user) if user else None
|
||||
except psycopg2.errors.UniqueViolation:
|
||||
self.conn.rollback()
|
||||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||
"""Authenticate a user by email and password.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
password: Plain text password
|
||||
|
||||
Returns:
|
||||
User dict if authenticated, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
|
||||
"""Get a user by ID.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[Dict]:
|
||||
"""Get a user by email.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
|
||||
"""Get all users (admin only).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of users
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of user dicts
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
|
||||
"""Update a user's role (admin only).
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
role: New role ('admin' or 'user')
|
||||
|
||||
Returns:
|
||||
Updated user dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
return dict(user) if user else None
|
||||
|
||||
def delete_user(self, user_id: int) -> bool:
|
||||
"""Delete a user (admin only).
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
self.conn.commit()
|
||||
return deleted
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
"""Get total user count.
|
||||
|
||||
Returns:
|
||||
Number of users
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
Reference in New Issue
Block a user