refactor(db): use shared pooled DatabaseClient singleton instead of per-call instances

- Replace get_db_client() creating new DatabaseClient on every call with a
  module-level singleton initialized once at startup via init_db_client()
- Add init_db_client() and close_db_client() lifecycle functions called
  from FastAPI lifespan handler
- Migrate all DatabaseClient methods from legacy self.connect()/self.conn
  to pooled self.get_conn() context manager for thread-safe connection reuse
- Pool is properly torn down on application shutdown

Closes leeworks-agents/SPARC#7

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
agent-company
2026-03-26 04:15:03 +00:00
parent 6105ba7793
commit 349bb4d073
3 changed files with 187 additions and 171 deletions
+6 -2
View File
@@ -16,11 +16,13 @@ from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import ( from SPARC.auth import (
TokenResponse, TokenResponse,
UserResponse, UserResponse,
close_db_client,
create_tokens, create_tokens,
decode_token, decode_token,
get_current_admin, get_current_admin,
get_current_user, get_current_user,
get_db_client, get_db_client,
init_db_client,
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -148,12 +150,14 @@ _analyzer: CompanyAnalyzer | None = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Initialize resources on startup.""" """Initialize resources on startup, clean up on shutdown."""
global _analyzer global _analyzer
init_db_client()
_analyzer = CompanyAnalyzer() _analyzer = CompanyAnalyzer()
yield yield
# Cleanup if needed # Cleanup
_analyzer = None _analyzer = None
close_db_client()
app = FastAPI( app = FastAPI(
+29 -4
View File
@@ -132,11 +132,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient: def get_db_client() -> DatabaseClient:
"""Get database client for auth operations.""" """Get the shared pooled database client for auth operations.
client = DatabaseClient(config.database_url)
client.connect() Returns the module-level singleton DatabaseClient. If not yet initialized
return client (e.g., during tests), creates a new instance as a fallback.
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
async def get_current_user( async def get_current_user(
+28 -41
View File
@@ -201,8 +201,6 @@ class DatabaseClient:
Returns: Returns:
Cached message dict if found, None otherwise Cached message dict if found, None otherwise
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
query = """ query = """
@@ -225,7 +223,8 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1" query += " ORDER BY timestamp DESC LIMIT 1"
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params) cursor.execute(query, params)
result = cursor.fetchone() result = cursor.fetchone()
return dict(result) if result else None return dict(result) if result else None
@@ -256,11 +255,10 @@ class DatabaseClient:
Returns: Returns:
The ID of the inserted record The ID of the inserted record
""" """
self.connect()
prompt_hash = self.hash_prompt(prompt) prompt_hash = self.hash_prompt(prompt)
with self.conn.cursor() as cursor: with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO llm_messages INSERT INTO llm_messages
@@ -282,7 +280,7 @@ class DatabaseClient:
) )
message_id = cursor.fetchone()[0] message_id = cursor.fetchone()[0]
self.conn.commit() conn.commit()
return message_id return message_id
@@ -304,8 +302,6 @@ class DatabaseClient:
Returns: Returns:
List of message dictionaries List of message dictionaries
""" """
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1" query = "SELECT * FROM llm_messages WHERE 1=1"
params = [] params = []
@@ -320,7 +316,8 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s" query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset]) params.extend([limit, offset])
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params) cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
@@ -333,9 +330,8 @@ class DatabaseClient:
Returns: Returns:
Dictionary with analytics data Dictionary with analytics data
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages # Total messages
cursor.execute( cursor.execute(
""" """
@@ -505,12 +501,11 @@ class DatabaseClient:
Returns: Returns:
Created user dict or None if email exists Created user dict or None if email exists
""" """
self.connect()
password_hash = self.hash_password(password) password_hash = self.hash_password(password)
try: try:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO users (email, password_hash, role) INSERT INTO users (email, password_hash, role)
@@ -520,10 +515,9 @@ class DatabaseClient:
(email, password_hash, role), (email, password_hash, role),
) )
user = cursor.fetchone() user = cursor.fetchone()
self.conn.commit() conn.commit()
return dict(user) if user else None return dict(user) if user else None
except psycopg2.errors.UniqueViolation: except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]: def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -536,9 +530,8 @@ class DatabaseClient:
Returns: Returns:
User dict if authenticated, None otherwise User dict if authenticated, None otherwise
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
"SELECT * FROM users WHERE email = %s", "SELECT * FROM users WHERE email = %s",
(email,), (email,),
@@ -563,9 +556,8 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE id = %s", "SELECT id, email, role, created_at FROM users WHERE id = %s",
(user_id,), (user_id,),
@@ -582,9 +574,8 @@ class DatabaseClient:
Returns: Returns:
User dict or None User dict or None
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE email = %s", "SELECT id, email, role, created_at FROM users WHERE email = %s",
(email,), (email,),
@@ -602,9 +593,8 @@ class DatabaseClient:
Returns: Returns:
List of user dicts List of user dicts
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
SELECT id, email, role, created_at SELECT id, email, role, created_at
@@ -626,9 +616,8 @@ class DatabaseClient:
Returns: Returns:
Updated user dict or None Updated user dict or None
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute( cursor.execute(
""" """
UPDATE users UPDATE users
@@ -639,7 +628,7 @@ class DatabaseClient:
(role, user_id), (role, user_id),
) )
user = cursor.fetchone() user = cursor.fetchone()
self.conn.commit() conn.commit()
return dict(user) if user else None return dict(user) if user else None
def delete_user(self, user_id: int) -> bool: def delete_user(self, user_id: int) -> bool:
@@ -651,12 +640,11 @@ class DatabaseClient:
Returns: Returns:
True if deleted True if deleted
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor() as cursor:
with self.conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,)) cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
deleted = cursor.rowcount > 0 deleted = cursor.rowcount > 0
self.conn.commit() conn.commit()
return deleted return deleted
def get_user_count(self) -> int: def get_user_count(self) -> int:
@@ -665,8 +653,7 @@ class DatabaseClient:
Returns: Returns:
Number of users Number of users
""" """
self.connect() with self.get_conn() as conn:
with conn.cursor() as cursor:
with self.conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users") cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0] return cursor.fetchone()[0]