refactor(db): use shared pooled DatabaseClient singleton instead of per-call instances
- Replace get_db_client() creating new DatabaseClient on every call with a module-level singleton initialized once at startup via init_db_client() - Add init_db_client() and close_db_client() lifecycle functions called from FastAPI lifespan handler - Migrate all DatabaseClient methods from legacy self.connect()/self.conn to pooled self.get_conn() context manager for thread-safe connection reuse - Pool is properly torn down on application shutdown Closes leeworks-agents/SPARC#7 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+152
-165
@@ -221,8 +221,6 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Cached message dict if found, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
query = """
|
||||
@@ -245,10 +243,11 @@ class DatabaseClient:
|
||||
|
||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
def store_message(
|
||||
self,
|
||||
@@ -276,33 +275,32 @@ class DatabaseClient:
|
||||
Returns:
|
||||
The ID of the inserted record
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
|
||||
message_id = cursor.fetchone()[0]
|
||||
self.conn.commit()
|
||||
message_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
return message_id
|
||||
|
||||
@@ -324,8 +322,6 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
query = "SELECT * FROM llm_messages WHERE 1=1"
|
||||
params = []
|
||||
|
||||
@@ -340,9 +336,10 @@ class DatabaseClient:
|
||||
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
||||
params.extend([limit, offset])
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_analytics(self, days: int = 30) -> Dict:
|
||||
"""Get analytics on message usage.
|
||||
@@ -353,53 +350,52 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Dictionary with analytics data
|
||||
"""
|
||||
self.connect()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
# Patent Cache Methods
|
||||
|
||||
@@ -650,25 +646,23 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Created user dict or None if email exists
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
return dict(user) if user else None
|
||||
except psycopg2.errors.UniqueViolation:
|
||||
self.conn.rollback()
|
||||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||
@@ -681,23 +675,22 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict if authenticated, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
|
||||
"""Get a user by ID.
|
||||
@@ -708,15 +701,14 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[Dict]:
|
||||
"""Get a user by email.
|
||||
@@ -727,15 +719,14 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
|
||||
"""Get all users (admin only).
|
||||
@@ -747,19 +738,18 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of user dicts
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
|
||||
"""Update a user's role (admin only).
|
||||
@@ -771,20 +761,19 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Updated user dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
return dict(user) if user else None
|
||||
|
||||
def delete_user(self, user_id: int) -> bool:
|
||||
@@ -796,12 +785,11 @@ class DatabaseClient:
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
@@ -810,8 +798,7 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Number of users
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
Reference in New Issue
Block a user