diff --git a/SPARC/api.py b/SPARC/api.py index 01b103c..a78c132 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -21,11 +21,13 @@ from SPARC.auth import ( TokenResponse, UserResponse, check_jwt_secret, + close_db_client, create_tokens, decode_token, get_current_admin, get_current_user, get_db_client, + init_db_client, ) from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult @@ -155,6 +157,7 @@ async def lifespan(app: FastAPI): """Initialize resources on startup, clean up on shutdown.""" global _analyzer check_jwt_secret() + init_db_client() _analyzer = CompanyAnalyzer() # Mark any jobs that were running/pending before the restart as failed from SPARC.database import DatabaseClient @@ -167,8 +170,9 @@ async def lifespan(app: FastAPI): logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale) _db.close() yield - # Cleanup if needed + # Cleanup _analyzer = None + close_db_client() app = FastAPI( diff --git a/SPARC/auth.py b/SPARC/auth.py index d134ad8..890d286 100644 --- a/SPARC/auth.py +++ b/SPARC/auth.py @@ -146,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]: return None +# Shared database client singleton, initialized at startup via init_db_client() +_db_client: DatabaseClient | None = None + + +def init_db_client() -> None: + """Initialize the shared database client. Call once at app startup.""" + global _db_client + _db_client = DatabaseClient(config.database_url) + _db_client.connect() + + +def close_db_client() -> None: + """Close the shared database client. Call at app shutdown.""" + global _db_client + if _db_client: + _db_client.close() + _db_client = None + + def get_db_client() -> DatabaseClient: - """Get database client for auth operations.""" - client = DatabaseClient(config.database_url) - client.connect() - return client + """Get the shared pooled database client for auth operations. + + Returns the module-level singleton DatabaseClient. If not yet initialized + (e.g., during tests), creates a new instance as a fallback. + """ + global _db_client + if _db_client is None: + _db_client = DatabaseClient(config.database_url) + _db_client.connect() + return _db_client async def get_current_user( diff --git a/SPARC/database.py b/SPARC/database.py index cc55304..a22d8e9 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -221,8 +221,6 @@ class DatabaseClient: Returns: Cached message dict if found, None otherwise """ - self.connect() - prompt_hash = self.hash_prompt(prompt) query = """ @@ -245,10 +243,11 @@ class DatabaseClient: query += " ORDER BY timestamp DESC LIMIT 1" - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return dict(result) if result else None + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return dict(result) if result else None def store_message( self, @@ -276,33 +275,32 @@ class DatabaseClient: Returns: The ID of the inserted record """ - self.connect() - prompt_hash = self.hash_prompt(prompt) - with self.conn.cursor() as cursor: - cursor.execute( - """ - INSERT INTO llm_messages - (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) - RETURNING id - """, - ( - prompt, - prompt_hash, - response, - company_name, - analysis_type, - model, - json.dumps(metadata) if metadata else None, - json.dumps(token_usage) if token_usage else None, - is_cached, - ), - ) + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute( + """ + INSERT INTO llm_messages + (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + RETURNING id + """, + ( + prompt, + prompt_hash, + response, + company_name, + analysis_type, + model, + json.dumps(metadata) if metadata else None, + json.dumps(token_usage) if token_usage else None, + is_cached, + ), + ) - message_id = cursor.fetchone()[0] - self.conn.commit() + message_id = cursor.fetchone()[0] + conn.commit() return message_id @@ -324,8 +322,6 @@ class DatabaseClient: Returns: List of message dictionaries """ - self.connect() - query = "SELECT * FROM llm_messages WHERE 1=1" params = [] @@ -340,9 +336,10 @@ class DatabaseClient: query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s" params.extend([limit, offset]) - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute(query, params) - return [dict(row) for row in cursor.fetchall()] + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute(query, params) + return [dict(row) for row in cursor.fetchall()] def get_analytics(self, days: int = 30) -> Dict: """Get analytics on message usage. @@ -353,53 +350,52 @@ class DatabaseClient: Returns: Dictionary with analytics data """ - self.connect() + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + # Total messages + cursor.execute( + """ + SELECT COUNT(*) as total_messages + FROM llm_messages + WHERE timestamp >= NOW() - INTERVAL '%s days' + """, + (days,), + ) + total = cursor.fetchone()["total_messages"] - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - # Total messages - cursor.execute( - """ - SELECT COUNT(*) as total_messages - FROM llm_messages - WHERE timestamp >= NOW() - INTERVAL '%s days' - """, - (days,), - ) - total = cursor.fetchone()["total_messages"] + # Messages by company + cursor.execute( + """ + SELECT company_name, COUNT(*) as count + FROM llm_messages + WHERE timestamp >= NOW() - INTERVAL '%s days' + GROUP BY company_name + ORDER BY count DESC + LIMIT 10 + """, + (days,), + ) + by_company = cursor.fetchall() - # Messages by company - cursor.execute( - """ - SELECT company_name, COUNT(*) as count - FROM llm_messages - WHERE timestamp >= NOW() - INTERVAL '%s days' - GROUP BY company_name - ORDER BY count DESC - LIMIT 10 - """, - (days,), - ) - by_company = cursor.fetchall() + # Messages by type + cursor.execute( + """ + SELECT analysis_type, COUNT(*) as count + FROM llm_messages + WHERE timestamp >= NOW() - INTERVAL '%s days' + GROUP BY analysis_type + ORDER BY count DESC + """, + (days,), + ) + by_type = cursor.fetchall() - # Messages by type - cursor.execute( - """ - SELECT analysis_type, COUNT(*) as count - FROM llm_messages - WHERE timestamp >= NOW() - INTERVAL '%s days' - GROUP BY analysis_type - ORDER BY count DESC - """, - (days,), - ) - by_type = cursor.fetchall() - - return { - "total_messages": total, - "by_company": [dict(row) for row in by_company], - "by_type": [dict(row) for row in by_type], - "period_days": days, - } + return { + "total_messages": total, + "by_company": [dict(row) for row in by_company], + "by_type": [dict(row) for row in by_type], + "period_days": days, + } # Patent Cache Methods @@ -650,25 +646,23 @@ class DatabaseClient: Returns: Created user dict or None if email exists """ - self.connect() - password_hash = self.hash_password(password) try: - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute( - """ - INSERT INTO users (email, password_hash, role) - VALUES (%s, %s, %s) - RETURNING id, email, role, created_at - """, - (email, password_hash, role), - ) - user = cursor.fetchone() - self.conn.commit() + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + INSERT INTO users (email, password_hash, role) + VALUES (%s, %s, %s) + RETURNING id, email, role, created_at + """, + (email, password_hash, role), + ) + user = cursor.fetchone() + conn.commit() return dict(user) if user else None except psycopg2.errors.UniqueViolation: - self.conn.rollback() return None def authenticate_user(self, email: str, password: str) -> Optional[Dict]: @@ -681,23 +675,22 @@ class DatabaseClient: Returns: User dict if authenticated, None otherwise """ - self.connect() + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + "SELECT * FROM users WHERE email = %s", + (email,), + ) + user = cursor.fetchone() - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute( - "SELECT * FROM users WHERE email = %s", - (email,), - ) - user = cursor.fetchone() - - if user and self.verify_password(password, user["password_hash"]): - return { - "id": user["id"], - "email": user["email"], - "role": user["role"], - "created_at": user["created_at"], - } - return None + if user and self.verify_password(password, user["password_hash"]): + return { + "id": user["id"], + "email": user["email"], + "role": user["role"], + "created_at": user["created_at"], + } + return None def get_user_by_id(self, user_id: int) -> Optional[Dict]: """Get a user by ID. @@ -708,15 +701,14 @@ class DatabaseClient: Returns: User dict or None """ - self.connect() - - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute( - "SELECT id, email, role, created_at FROM users WHERE id = %s", - (user_id,), - ) - user = cursor.fetchone() - return dict(user) if user else None + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + "SELECT id, email, role, created_at FROM users WHERE id = %s", + (user_id,), + ) + user = cursor.fetchone() + return dict(user) if user else None def get_user_by_email(self, email: str) -> Optional[Dict]: """Get a user by email. @@ -727,15 +719,14 @@ class DatabaseClient: Returns: User dict or None """ - self.connect() - - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute( - "SELECT id, email, role, created_at FROM users WHERE email = %s", - (email,), - ) - user = cursor.fetchone() - return dict(user) if user else None + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + "SELECT id, email, role, created_at FROM users WHERE email = %s", + (email,), + ) + user = cursor.fetchone() + return dict(user) if user else None def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]: """Get all users (admin only). @@ -747,19 +738,18 @@ class DatabaseClient: Returns: List of user dicts """ - self.connect() - - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute( - """ - SELECT id, email, role, created_at - FROM users - ORDER BY created_at DESC - LIMIT %s OFFSET %s - """, - (limit, offset), - ) - return [dict(row) for row in cursor.fetchall()] + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + SELECT id, email, role, created_at + FROM users + ORDER BY created_at DESC + LIMIT %s OFFSET %s + """, + (limit, offset), + ) + return [dict(row) for row in cursor.fetchall()] def update_user_role(self, user_id: int, role: str) -> Optional[Dict]: """Update a user's role (admin only). @@ -771,20 +761,19 @@ class DatabaseClient: Returns: Updated user dict or None """ - self.connect() - - with self.conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute( - """ - UPDATE users - SET role = %s, updated_at = CURRENT_TIMESTAMP - WHERE id = %s - RETURNING id, email, role, created_at - """, - (role, user_id), - ) - user = cursor.fetchone() - self.conn.commit() + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + UPDATE users + SET role = %s, updated_at = CURRENT_TIMESTAMP + WHERE id = %s + RETURNING id, email, role, created_at + """, + (role, user_id), + ) + user = cursor.fetchone() + conn.commit() return dict(user) if user else None def delete_user(self, user_id: int) -> bool: @@ -796,12 +785,11 @@ class DatabaseClient: Returns: True if deleted """ - self.connect() - - with self.conn.cursor() as cursor: - cursor.execute("DELETE FROM users WHERE id = %s", (user_id,)) - deleted = cursor.rowcount > 0 - self.conn.commit() + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute("DELETE FROM users WHERE id = %s", (user_id,)) + deleted = cursor.rowcount > 0 + conn.commit() return deleted def get_user_count(self) -> int: @@ -810,8 +798,7 @@ class DatabaseClient: Returns: Number of users """ - self.connect() - - with self.conn.cursor() as cursor: - cursor.execute("SELECT COUNT(*) FROM users") - return cursor.fetchone()[0] + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT COUNT(*) FROM users") + return cursor.fetchone()[0]