forked from 0xWheatyz/SPARC
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c317632edb |
+16
-21
@@ -5,13 +5,10 @@ to provide company performance estimation based on patent portfolios.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Callable
|
||||
|
||||
from SPARC import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from SPARC.database import DatabaseClient
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
from SPARC.serp_api import SERP
|
||||
@@ -55,13 +52,13 @@ class CompanyAnalyzer:
|
||||
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
|
||||
cached_ids = self.db.get_cached_serp_query(query_hash)
|
||||
if cached_ids is not None:
|
||||
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
|
||||
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
|
||||
patents = Patents(patents=[
|
||||
Patent(patent_id=pid, pdf_link="")
|
||||
for pid in cached_ids
|
||||
])
|
||||
else:
|
||||
logger.info("Retrieving patents for %s...", company_name)
|
||||
print(f"Retrieving patents for {company_name}...")
|
||||
patents = SERP.query(company_name)
|
||||
# Cache the SERP results
|
||||
if patents.patents:
|
||||
@@ -69,13 +66,12 @@ class CompanyAnalyzer:
|
||||
company_name=company_name,
|
||||
query_hash=query_hash,
|
||||
patent_ids=[p.patent_id for p in patents.patents],
|
||||
ttl_hours=config.serp_cache_ttl_hours,
|
||||
)
|
||||
|
||||
if not patents.patents:
|
||||
return f"No patents found for {company_name}"
|
||||
|
||||
logger.info("Found %d patents. Processing...", len(patents.patents))
|
||||
print(f"Found {len(patents.patents)} patents. Processing...")
|
||||
|
||||
# Download, parse, and minimize patents in parallel
|
||||
processed_patents = []
|
||||
@@ -91,12 +87,12 @@ class CompanyAnalyzer:
|
||||
if result:
|
||||
processed_patents.append(result)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to process %s: %s", patent.patent_id, e)
|
||||
print(f"Warning: Failed to process {patent.patent_id}: {e}")
|
||||
|
||||
if not processed_patents:
|
||||
return f"Failed to process any patents for {company_name}"
|
||||
|
||||
logger.info("Analyzing portfolio with LLM...")
|
||||
print("Analyzing portfolio with LLM...")
|
||||
|
||||
# Analyze the full portfolio with LLM
|
||||
analysis = self.llm_analyzer.analyze_patent_portfolio(
|
||||
@@ -126,7 +122,6 @@ class CompanyAnalyzer:
|
||||
FileNotFoundError: If the patent PDF is not found at the expected path.
|
||||
"""
|
||||
import os
|
||||
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
|
||||
|
||||
patent_path = f"patents/{patent_id}.pdf"
|
||||
|
||||
@@ -188,7 +183,7 @@ class CompanyAnalyzer:
|
||||
|
||||
return {"patent_id": patent.patent_id, "content": minimized_content}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to process %s: %s", patent.patent_id, e)
|
||||
print(f"Warning: Failed to process {patent.patent_id}: {e}")
|
||||
return None
|
||||
|
||||
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
|
||||
@@ -259,7 +254,7 @@ class CompanyAnalyzer:
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
logger.info("Starting batch analysis of %d companies...", total)
|
||||
print(f"Starting batch analysis of {total} companies...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_company = {
|
||||
@@ -276,8 +271,8 @@ class CompanyAnalyzer:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
status = "OK" if result.success else "FAIL"
|
||||
logger.info("[%d/%d] %s %s", completed, total, status, company)
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{completed}/{total}] {status} {company}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(company, completed, total)
|
||||
@@ -292,12 +287,12 @@ class CompanyAnalyzer:
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
|
||||
print(f"[{completed}/{total}] ✗ {company}: {e}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
@@ -323,20 +318,20 @@ class CompanyAnalyzer:
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
logger.info("Starting sequential analysis of %d companies...", total)
|
||||
print(f"Starting sequential analysis of {total} companies...")
|
||||
|
||||
for idx, company in enumerate(companies, 1):
|
||||
logger.info("[%d/%d] Analyzing %s...", idx, total, company)
|
||||
print(f"\n[{idx}/{total}] Analyzing {company}...")
|
||||
result = self._analyze_company_safe(company)
|
||||
results.append(result)
|
||||
|
||||
status = "OK" if result.success else "FAIL"
|
||||
logger.info("[%d/%d] %s %s", idx, total, status, company)
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{idx}/{total}] {status} {company}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
|
||||
+6
-44
@@ -21,13 +21,11 @@ from SPARC.auth import (
|
||||
TokenResponse,
|
||||
UserResponse,
|
||||
check_jwt_secret,
|
||||
close_db_client,
|
||||
create_tokens,
|
||||
decode_token,
|
||||
get_current_admin,
|
||||
get_current_user,
|
||||
get_db_client,
|
||||
init_db_client,
|
||||
)
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
@@ -77,13 +75,6 @@ class JobStatus(BaseModel):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PaginatedJobsResponse(BaseModel):
|
||||
"""Paginated response for job listings."""
|
||||
|
||||
items: list["JobStatus"]
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
@@ -164,7 +155,6 @@ async def lifespan(app: FastAPI):
|
||||
"""Initialize resources on startup, clean up on shutdown."""
|
||||
global _analyzer
|
||||
check_jwt_secret()
|
||||
init_db_client()
|
||||
_analyzer = CompanyAnalyzer()
|
||||
# Mark any jobs that were running/pending before the restart as failed
|
||||
from SPARC.database import DatabaseClient
|
||||
@@ -177,9 +167,8 @@ async def lifespan(app: FastAPI):
|
||||
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
||||
_db.close()
|
||||
yield
|
||||
# Cleanup
|
||||
# Cleanup if needed
|
||||
_analyzer = None
|
||||
close_db_client()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
@@ -584,51 +573,24 @@ async def get_job_status(
|
||||
return _job_row_to_status(job_row)
|
||||
|
||||
|
||||
@app.get("/jobs", response_model=PaginatedJobsResponse, tags=["Jobs"])
|
||||
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
|
||||
async def list_jobs(
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status: pending, running, completed, failed"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 10,
|
||||
cursor: Annotated[
|
||||
str | None,
|
||||
Query(description="Opaque cursor from a previous response's next_cursor field"),
|
||||
] = None,
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
"""List analysis jobs with cursor-based pagination.
|
||||
|
||||
Pass ``limit`` to control page size. The response includes a ``next_cursor``
|
||||
field; pass it back as the ``cursor`` query parameter to fetch the next page.
|
||||
When ``next_cursor`` is ``null``, there are no more results.
|
||||
|
||||
Existing clients that use only ``limit`` (without ``cursor``) continue to
|
||||
work without modification.
|
||||
"""List all analysis jobs.
|
||||
|
||||
Args:
|
||||
status: Optional filter by job status
|
||||
limit: Maximum number of jobs to return (default 10, max 100)
|
||||
cursor: Opaque pagination cursor from a previous response
|
||||
|
||||
Returns:
|
||||
Paginated list of job statuses
|
||||
List of job statuses
|
||||
"""
|
||||
db = _get_job_db()
|
||||
# Fetch one extra to determine if there is a next page
|
||||
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
|
||||
|
||||
has_next = len(job_rows) > limit
|
||||
if has_next:
|
||||
job_rows = job_rows[:limit]
|
||||
|
||||
items = [_job_row_to_status(row) for row in job_rows]
|
||||
|
||||
next_cursor = None
|
||||
if has_next and job_rows:
|
||||
last = job_rows[-1]
|
||||
created = last["created_at"]
|
||||
ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
|
||||
next_cursor = f"{ts}|{last['job_id']}"
|
||||
|
||||
return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
|
||||
job_rows = db.list_jobs(status=status, limit=limit)
|
||||
return [_job_row_to_status(row) for row in job_rows]
|
||||
|
||||
+4
-29
@@ -146,36 +146,11 @@ def decode_token(token: str) -> Optional[TokenPayload]:
|
||||
return None
|
||||
|
||||
|
||||
# Shared database client singleton, initialized at startup via init_db_client()
|
||||
_db_client: DatabaseClient | None = None
|
||||
|
||||
|
||||
def init_db_client() -> None:
|
||||
"""Initialize the shared database client. Call once at app startup."""
|
||||
global _db_client
|
||||
_db_client = DatabaseClient(config.database_url)
|
||||
_db_client.connect()
|
||||
|
||||
|
||||
def close_db_client() -> None:
|
||||
"""Close the shared database client. Call at app shutdown."""
|
||||
global _db_client
|
||||
if _db_client:
|
||||
_db_client.close()
|
||||
_db_client = None
|
||||
|
||||
|
||||
def get_db_client() -> DatabaseClient:
|
||||
"""Get the shared pooled database client for auth operations.
|
||||
|
||||
Returns the module-level singleton DatabaseClient. If not yet initialized
|
||||
(e.g., during tests), creates a new instance as a fallback.
|
||||
"""
|
||||
global _db_client
|
||||
if _db_client is None:
|
||||
_db_client = DatabaseClient(config.database_url)
|
||||
_db_client.connect()
|
||||
return _db_client
|
||||
"""Get database client for auth operations."""
|
||||
client = DatabaseClient(config.database_url)
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
|
||||
@@ -2,20 +2,12 @@
|
||||
|
||||
Loads environment variables from .env file for API keys and other secrets.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Logging configuration
|
||||
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level, logging.INFO),
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
# SerpAPI key for patent search
|
||||
api_key = os.getenv("API_KEY")
|
||||
|
||||
@@ -39,12 +31,6 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
|
||||
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
|
||||
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
|
||||
|
||||
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
|
||||
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
|
||||
|
||||
# SERP cache TTL in hours (how long cached search results are considered fresh)
|
||||
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
|
||||
|
||||
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
|
||||
# This ensures OpenAPI docs work correctly when accessed via the proxy
|
||||
root_path = os.getenv("ROOT_PATH", "")
|
||||
|
||||
+174
-186
@@ -222,6 +222,8 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Cached message dict if found, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
query = """
|
||||
@@ -244,11 +246,10 @@ class DatabaseClient:
|
||||
|
||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
def store_message(
|
||||
self,
|
||||
@@ -276,32 +277,33 @@ class DatabaseClient:
|
||||
Returns:
|
||||
The ID of the inserted record
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
|
||||
message_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
message_id = cursor.fetchone()[0]
|
||||
self.conn.commit()
|
||||
|
||||
return message_id
|
||||
|
||||
@@ -323,6 +325,8 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
query = "SELECT * FROM llm_messages WHERE 1=1"
|
||||
params = []
|
||||
|
||||
@@ -337,10 +341,9 @@ class DatabaseClient:
|
||||
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
||||
params.extend([limit, offset])
|
||||
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_analytics(self, days: int = 30) -> Dict:
|
||||
"""Get analytics on message usage.
|
||||
@@ -351,52 +354,53 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Dictionary with analytics data
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
self.connect()
|
||||
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
# Patent Cache Methods
|
||||
|
||||
@@ -568,45 +572,20 @@ class DatabaseClient:
|
||||
self,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
cursor: Optional[str] = None,
|
||||
) -> List[Dict]:
|
||||
"""List jobs with optional status filter and cursor-based pagination.
|
||||
|
||||
Args:
|
||||
status: Optional status filter (pending, running, completed, failed).
|
||||
limit: Maximum number of jobs to return.
|
||||
cursor: Opaque cursor (``created_at|job_id``) from a previous
|
||||
response. When provided, only jobs older than the cursor are
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
List of job dicts ordered by created_at descending.
|
||||
"""
|
||||
conditions: list[str] = []
|
||||
params: list = []
|
||||
|
||||
if status:
|
||||
conditions.append("status = %s")
|
||||
params.append(status)
|
||||
|
||||
if cursor:
|
||||
try:
|
||||
ts_str, cursor_job_id = cursor.rsplit("|", 1)
|
||||
conditions.append("(created_at, job_id) < (%s, %s)")
|
||||
params.extend([ts_str, cursor_job_id])
|
||||
except ValueError:
|
||||
pass # Ignore malformed cursors; return from start
|
||||
|
||||
"""List jobs, optionally filtered by status."""
|
||||
query = "SELECT * FROM jobs"
|
||||
if conditions:
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
query += " ORDER BY created_at DESC, job_id DESC LIMIT %s"
|
||||
params: list = []
|
||||
if status:
|
||||
query += " WHERE status = %s"
|
||||
params.append(status)
|
||||
query += " ORDER BY created_at DESC LIMIT %s"
|
||||
params.append(limit)
|
||||
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
cur.execute(query, params)
|
||||
return [dict(row) for row in cur.fetchall()]
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def mark_stale_jobs_failed(self) -> int:
|
||||
"""Mark any jobs in 'running' or 'pending' state as 'failed'.
|
||||
@@ -672,23 +651,25 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Created user dict or None if email exists
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
try:
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
return dict(user) if user else None
|
||||
except psycopg2.errors.UniqueViolation:
|
||||
self.conn.rollback()
|
||||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||
@@ -701,22 +682,23 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict if authenticated, None otherwise
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.connect()
|
||||
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
|
||||
"""Get a user by ID.
|
||||
@@ -727,14 +709,15 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[Dict]:
|
||||
"""Get a user by email.
|
||||
@@ -745,14 +728,15 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
|
||||
"""Get all users (admin only).
|
||||
@@ -764,18 +748,19 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of user dicts
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
|
||||
"""Update a user's role (admin only).
|
||||
@@ -787,19 +772,20 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Updated user dict or None
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
return dict(user) if user else None
|
||||
|
||||
def delete_user(self, user_id: int) -> bool:
|
||||
@@ -811,11 +797,12 @@ class DatabaseClient:
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
self.conn.commit()
|
||||
return deleted
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
@@ -824,7 +811,8 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Number of users
|
||||
"""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
+7
-6
@@ -1,6 +1,5 @@
|
||||
"""LLM integration for patent analysis using OpenRouter."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from openai import OpenAI
|
||||
@@ -8,8 +7,6 @@ from openai import OpenAI
|
||||
from SPARC import config
|
||||
from SPARC.database import DatabaseClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMAnalyzer:
|
||||
"""Handles LLM-based analysis of patent content."""
|
||||
@@ -25,7 +22,7 @@ class LLMAnalyzer:
|
||||
"""
|
||||
self.test_mode = test_mode
|
||||
self.use_cache = use_cache if use_cache is not None else config.use_cache
|
||||
self.model = config.model
|
||||
self.model = "anthropic/claude-3.5-sonnet"
|
||||
|
||||
# Always initialize database client for storage and caching
|
||||
self.db_client = DatabaseClient(config.database_url)
|
||||
@@ -64,7 +61,11 @@ Patent Content:
|
||||
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
|
||||
|
||||
if self.test_mode:
|
||||
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
|
||||
print("=" * 80)
|
||||
print("TEST MODE - Prompt that would be sent to LLM:")
|
||||
print("=" * 80)
|
||||
print(prompt)
|
||||
print("=" * 80)
|
||||
return "[TEST MODE - No API call made]"
|
||||
|
||||
# Check cache first
|
||||
@@ -166,7 +167,7 @@ Patent Portfolio:
|
||||
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
|
||||
|
||||
if self.test_mode:
|
||||
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
|
||||
print(prompt)
|
||||
return "[TEST MODE]"
|
||||
|
||||
metadata = {
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
@dataclass
|
||||
class Patent:
|
||||
patent_id: str
|
||||
patent_id: int
|
||||
pdf_link: str
|
||||
pdf_path: str | None = None
|
||||
summary: dict | None = None
|
||||
|
||||
Reference in New Issue
Block a user