forked from 0xWheatyz/SPARC
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ecc2c37bcd | |||
| 55c131cb32 | |||
| fbb72fe2a5 | |||
| e484baaf5f | |||
| 069f1c343c | |||
| d366443b38 |
@@ -9,7 +9,43 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install system dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||
|
||||
- name: Checkout code
|
||||
shell: sh
|
||||
run: |
|
||||
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||
git checkout ${{ gitea.sha }}
|
||||
|
||||
- name: Install Python dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
pip3 install --break-system-packages -r requirements.txt ruff
|
||||
|
||||
- name: Run ruff linter
|
||||
shell: sh
|
||||
run: |
|
||||
ruff check SPARC/ tests/
|
||||
|
||||
- name: Run pytest
|
||||
shell: sh
|
||||
env:
|
||||
DATABASE_URL: "sqlite://"
|
||||
API_KEY: "test-key"
|
||||
OPENROUTER_API_KEY: "test-key"
|
||||
JWT_SECRET: "test-secret-for-ci"
|
||||
APP_ENV: "development"
|
||||
run: |
|
||||
python3 -m pytest tests/ -v --tb=short -x
|
||||
|
||||
build-api:
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
@@ -81,6 +117,7 @@ jobs:
|
||||
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
|
||||
|
||||
build-frontend:
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
name: Test and Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install system dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||
|
||||
- name: Checkout code
|
||||
shell: sh
|
||||
run: |
|
||||
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||
git checkout ${{ gitea.sha }}
|
||||
|
||||
- name: Install Python dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
pip3 install --break-system-packages -r requirements.txt ruff
|
||||
|
||||
- name: Run ruff linter
|
||||
shell: sh
|
||||
run: |
|
||||
ruff check SPARC/ tests/
|
||||
|
||||
- name: Run pytest
|
||||
shell: sh
|
||||
env:
|
||||
DATABASE_URL: "sqlite://"
|
||||
API_KEY: "test-key"
|
||||
OPENROUTER_API_KEY: "test-key"
|
||||
JWT_SECRET: "test-secret-for-ci"
|
||||
APP_ENV: "development"
|
||||
run: |
|
||||
python3 -m pytest tests/ -v --tb=short -x
|
||||
+3
-2
@@ -1,3 +1,4 @@
|
||||
from .types import Patents, Patent
|
||||
from .types import Patent as Patent
|
||||
from .types import Patents as Patents
|
||||
|
||||
all = ["Patents", "Patent"]
|
||||
__all__ = ["Patents", "Patent"]
|
||||
|
||||
+23
-13
@@ -13,9 +13,9 @@ from SPARC import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from SPARC.database import DatabaseClient
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents
|
||||
|
||||
|
||||
class CompanyAnalyzer:
|
||||
@@ -108,12 +108,10 @@ class CompanyAnalyzer:
|
||||
def analyze_single_patent(self, patent_id: str, company_name: str) -> str:
|
||||
"""Analyze a single patent by ID.
|
||||
|
||||
Prerequisite:
|
||||
The patent PDF must already exist at ``patents/{patent_id}.pdf``
|
||||
before calling this method. PDFs are downloaded automatically when
|
||||
using the batch analysis pipeline (``analyze_company`` or the
|
||||
``/analyze/batch`` API endpoint). For standalone usage, download
|
||||
the PDF manually or call ``SERP.save_patents()`` first.
|
||||
If the patent PDF is not already on disk, this method attempts to
|
||||
download it automatically by looking up the PDF link in the database
|
||||
cache. If the link is not cached either, a ``FileNotFoundError`` is
|
||||
raised with instructions on how to obtain the PDF.
|
||||
|
||||
Args:
|
||||
patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
|
||||
@@ -123,7 +121,7 @@ class CompanyAnalyzer:
|
||||
Analysis of the specific patent's innovation quality
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the patent PDF is not found at the expected path.
|
||||
FileNotFoundError: If the patent PDF cannot be found or downloaded.
|
||||
"""
|
||||
import os
|
||||
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
|
||||
@@ -131,10 +129,22 @@ class CompanyAnalyzer:
|
||||
patent_path = f"patents/{patent_id}.pdf"
|
||||
|
||||
if not os.path.exists(patent_path):
|
||||
raise FileNotFoundError(
|
||||
f"Patent PDF not found at '{patent_path}'. "
|
||||
f"Download the PDF first using SERP.save_patents() or the batch analysis pipeline."
|
||||
)
|
||||
# Attempt to download the PDF automatically from cached metadata
|
||||
cached = self.db.get_cached_patent(patent_id)
|
||||
pdf_link = cached.get("pdf_link") if cached else None
|
||||
|
||||
if pdf_link:
|
||||
logger.info("PDF not on disk; downloading %s from cached link", patent_id)
|
||||
patent = SERP.save_patents(
|
||||
Patent(patent_id=patent_id, pdf_link=pdf_link)
|
||||
)
|
||||
patent_path = patent.pdf_path
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Patent PDF not found at '{patent_path}' and no download link is "
|
||||
f"cached for '{patent_id}'. Run a company analysis first to populate "
|
||||
f"the cache, or call SERP.save_patents() with the patent's PDF link."
|
||||
)
|
||||
|
||||
try:
|
||||
sections = SERP.parse_patent_pdf(patent_path)
|
||||
|
||||
+37
-1
@@ -21,11 +21,13 @@ from SPARC.auth import (
|
||||
TokenResponse,
|
||||
UserResponse,
|
||||
check_jwt_secret,
|
||||
close_db_client,
|
||||
create_tokens,
|
||||
decode_token,
|
||||
get_current_admin,
|
||||
get_current_user,
|
||||
get_db_client,
|
||||
init_db_client,
|
||||
)
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
@@ -155,6 +157,7 @@ async def lifespan(app: FastAPI):
|
||||
"""Initialize resources on startup, clean up on shutdown."""
|
||||
global _analyzer
|
||||
check_jwt_secret()
|
||||
init_db_client()
|
||||
_analyzer = CompanyAnalyzer()
|
||||
# Mark any jobs that were running/pending before the restart as failed
|
||||
from SPARC.database import DatabaseClient
|
||||
@@ -167,8 +170,9 @@ async def lifespan(app: FastAPI):
|
||||
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
||||
_db.close()
|
||||
yield
|
||||
# Cleanup if needed
|
||||
# Cleanup
|
||||
_analyzer = None
|
||||
close_db_client()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
@@ -425,6 +429,38 @@ async def analyze_company(
|
||||
return _convert_result(result)
|
||||
|
||||
|
||||
@app.get(
|
||||
"/analyze/patent/{patent_id}",
|
||||
tags=["Analysis"],
|
||||
)
|
||||
async def analyze_single_patent(
|
||||
patent_id: str,
|
||||
company_name: str = Query(description="Company name for analysis context"),
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
"""Analyze a single patent by its publication ID.
|
||||
|
||||
If the patent PDF is not already cached locally, the system will attempt
|
||||
to download it automatically from a previously cached link. If no link
|
||||
is available, a 404 error is returned.
|
||||
|
||||
Args:
|
||||
patent_id: Patent publication ID (e.g. "US-11234567-B2")
|
||||
company_name: Company name for analysis context
|
||||
|
||||
Returns:
|
||||
Analysis text for the patent
|
||||
"""
|
||||
if not _analyzer:
|
||||
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
||||
|
||||
try:
|
||||
analysis = _analyzer.analyze_single_patent(patent_id, company_name)
|
||||
return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@app.post(
|
||||
"/analyze/batch",
|
||||
response_model=BatchAnalysisResponse,
|
||||
|
||||
+29
-4
@@ -146,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
|
||||
return None
|
||||
|
||||
|
||||
# Shared database client singleton, initialized at startup via init_db_client()
|
||||
_db_client: DatabaseClient | None = None
|
||||
|
||||
|
||||
def init_db_client() -> None:
|
||||
"""Initialize the shared database client. Call once at app startup."""
|
||||
global _db_client
|
||||
_db_client = DatabaseClient(config.database_url)
|
||||
_db_client.connect()
|
||||
|
||||
|
||||
def close_db_client() -> None:
|
||||
"""Close the shared database client. Call at app shutdown."""
|
||||
global _db_client
|
||||
if _db_client:
|
||||
_db_client.close()
|
||||
_db_client = None
|
||||
|
||||
|
||||
def get_db_client() -> DatabaseClient:
|
||||
"""Get database client for auth operations."""
|
||||
client = DatabaseClient(config.database_url)
|
||||
client.connect()
|
||||
return client
|
||||
"""Get the shared pooled database client for auth operations.
|
||||
|
||||
Returns the module-level singleton DatabaseClient. If not yet initialized
|
||||
(e.g., during tests), creates a new instance as a fallback.
|
||||
"""
|
||||
global _db_client
|
||||
if _db_client is None:
|
||||
_db_client = DatabaseClient(config.database_url)
|
||||
_db_client.connect()
|
||||
return _db_client
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
|
||||
+159
-171
@@ -1,14 +1,15 @@
|
||||
"""Database client for storing and retrieving LLM messages and user authentication."""
|
||||
|
||||
import contextlib
|
||||
import psycopg2
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import bcrypt
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
|
||||
class DatabaseClient:
|
||||
@@ -221,8 +222,6 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Cached message dict if found, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
query = """
|
||||
@@ -245,10 +244,11 @@ class DatabaseClient:
|
||||
|
||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
def store_message(
|
||||
self,
|
||||
@@ -276,33 +276,32 @@ class DatabaseClient:
|
||||
Returns:
|
||||
The ID of the inserted record
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
|
||||
message_id = cursor.fetchone()[0]
|
||||
self.conn.commit()
|
||||
message_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
return message_id
|
||||
|
||||
@@ -324,8 +323,6 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
query = "SELECT * FROM llm_messages WHERE 1=1"
|
||||
params = []
|
||||
|
||||
@@ -340,9 +337,10 @@ class DatabaseClient:
|
||||
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
||||
params.extend([limit, offset])
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_analytics(self, days: int = 30) -> Dict:
|
||||
"""Get analytics on message usage.
|
||||
@@ -353,53 +351,52 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Dictionary with analytics data
|
||||
"""
|
||||
self.connect()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
# Patent Cache Methods
|
||||
|
||||
@@ -650,25 +647,23 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Created user dict or None if email exists
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
return dict(user) if user else None
|
||||
except psycopg2.errors.UniqueViolation:
|
||||
self.conn.rollback()
|
||||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||
@@ -681,23 +676,22 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict if authenticated, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
|
||||
"""Get a user by ID.
|
||||
@@ -708,15 +702,14 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[Dict]:
|
||||
"""Get a user by email.
|
||||
@@ -727,15 +720,14 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
|
||||
"""Get all users (admin only).
|
||||
@@ -747,19 +739,18 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of user dicts
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
|
||||
"""Update a user's role (admin only).
|
||||
@@ -771,20 +762,19 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Updated user dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
return dict(user) if user else None
|
||||
|
||||
def delete_user(self, user_id: int) -> bool:
|
||||
@@ -796,12 +786,11 @@ class DatabaseClient:
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
@@ -810,8 +799,7 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Number of users
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
+8
-5
@@ -1,12 +1,15 @@
|
||||
import os
|
||||
import serpapi
|
||||
from SPARC import config
|
||||
import re
|
||||
import pdfplumber # pip install pdfplumber
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict
|
||||
from SPARC.types import Patents, Patent
|
||||
|
||||
import pdfplumber # pip install pdfplumber
|
||||
import requests
|
||||
import serpapi
|
||||
|
||||
from SPARC import config
|
||||
from SPARC.types import Patent, Patents
|
||||
|
||||
|
||||
class SERP:
|
||||
def query(company: str, days_back: int = None) -> Patents:
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
[lint]
|
||||
select = ["E", "F", "I"]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"tests/*" = ["E402", "F841"] # allow import not at top of file, unused vars (mocks) in tests
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for the high-level company analyzer orchestration."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, call, MagicMock
|
||||
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
||||
from SPARC.types import BatchAnalysisResult, Patent, Patents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -24,7 +26,7 @@ class TestCompanyAnalyzer:
|
||||
"""Test analyzer initialization with API key."""
|
||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
analyzer = CompanyAnalyzer(openrouter_api_key="test-key")
|
||||
_analyzer = CompanyAnalyzer(openrouter_api_key="test-key") # noqa: F841
|
||||
|
||||
mock_llm.assert_called_once_with(api_key="test-key")
|
||||
|
||||
|
||||
+4
-3
@@ -1,12 +1,13 @@
|
||||
"""Tests for FastAPI web service endpoints."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
+3
-1
@@ -1,7 +1,9 @@
|
||||
"""Tests for LLM analysis functionality."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Tests for SERP API patent retrieval and parsing functionality."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.types import Patent
|
||||
|
||||
|
||||
Reference in New Issue
Block a user