Merge pull request 'refactor(db): shared pooled DatabaseClient singleton' (#30) from feature/db-client-pooling into main

This commit is contained in:
2026-03-26 07:02:46 +00:00
3 changed files with 186 additions and 170 deletions
+5 -1
View File
@@ -21,11 +21,13 @@ from SPARC.auth import (
TokenResponse,
UserResponse,
check_jwt_secret,
close_db_client,
create_tokens,
decode_token,
get_current_admin,
get_current_user,
get_db_client,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -155,6 +157,7 @@ async def lifespan(app: FastAPI):
"""Initialize resources on startup, clean up on shutdown."""
global _analyzer
check_jwt_secret()
init_db_client()
_analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient
@@ -167,8 +170,9 @@ async def lifespan(app: FastAPI):
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
yield
# Cleanup if needed
# Cleanup
_analyzer = None
close_db_client()
app = FastAPI(
+29 -4
View File
@@ -146,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient:
"""Get database client for auth operations."""
client = DatabaseClient(config.database_url)
client.connect()
return client
"""Get the shared pooled database client for auth operations.
Returns the module-level singleton DatabaseClient. If not yet initialized
(e.g., during tests), creates a new instance as a fallback.
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
async def get_current_user(
+28 -41
View File
@@ -221,8 +221,6 @@ class DatabaseClient:
Returns:
Cached message dict if found, None otherwise
"""
self.connect()
prompt_hash = self.hash_prompt(prompt)
query = """
@@ -245,7 +243,8 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1"
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
result = cursor.fetchone()
return dict(result) if result else None
@@ -276,11 +275,10 @@ class DatabaseClient:
Returns:
The ID of the inserted record
"""
self.connect()
prompt_hash = self.hash_prompt(prompt)
with self.conn.cursor() as cursor:
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO llm_messages
@@ -302,7 +300,7 @@ class DatabaseClient:
)
message_id = cursor.fetchone()[0]
self.conn.commit()
conn.commit()
return message_id
@@ -324,8 +322,6 @@ class DatabaseClient:
Returns:
List of message dictionaries
"""
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1"
params = []
@@ -340,7 +336,8 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset])
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
@@ -353,9 +350,8 @@ class DatabaseClient:
Returns:
Dictionary with analytics data
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
@@ -650,12 +646,11 @@ class DatabaseClient:
Returns:
Created user dict or None if email exists
"""
self.connect()
password_hash = self.hash_password(password)
try:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO users (email, password_hash, role)
@@ -665,10 +660,9 @@ class DatabaseClient:
(email, password_hash, role),
)
user = cursor.fetchone()
self.conn.commit()
conn.commit()
return dict(user) if user else None
except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -681,9 +675,8 @@ class DatabaseClient:
Returns:
User dict if authenticated, None otherwise
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM users WHERE email = %s",
(email,),
@@ -708,9 +701,8 @@ class DatabaseClient:
Returns:
User dict or None
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE id = %s",
(user_id,),
@@ -727,9 +719,8 @@ class DatabaseClient:
Returns:
User dict or None
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE email = %s",
(email,),
@@ -747,9 +738,8 @@ class DatabaseClient:
Returns:
List of user dicts
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, email, role, created_at
@@ -771,9 +761,8 @@ class DatabaseClient:
Returns:
Updated user dict or None
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
UPDATE users
@@ -784,7 +773,7 @@ class DatabaseClient:
(role, user_id),
)
user = cursor.fetchone()
self.conn.commit()
conn.commit()
return dict(user) if user else None
def delete_user(self, user_id: int) -> bool:
@@ -796,12 +785,11 @@ class DatabaseClient:
Returns:
True if deleted
"""
self.connect()
with self.conn.cursor() as cursor:
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
deleted = cursor.rowcount > 0
self.conn.commit()
conn.commit()
return deleted
def get_user_count(self) -> int:
@@ -810,8 +798,7 @@ class DatabaseClient:
Returns:
Number of users
"""
self.connect()
with self.conn.cursor() as cursor:
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]