forked from 0xWheatyz/SPARC
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0b4d712fc5 | |||
| 55c131cb32 | |||
| fbb72fe2a5 | |||
| e484baaf5f | |||
| 069f1c343c | |||
| d366443b38 |
@@ -9,7 +9,43 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install system dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||
|
||||
- name: Checkout code
|
||||
shell: sh
|
||||
run: |
|
||||
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||
git checkout ${{ gitea.sha }}
|
||||
|
||||
- name: Install Python dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
pip3 install --break-system-packages -r requirements.txt ruff
|
||||
|
||||
- name: Run ruff linter
|
||||
shell: sh
|
||||
run: |
|
||||
ruff check SPARC/ tests/
|
||||
|
||||
- name: Run pytest
|
||||
shell: sh
|
||||
env:
|
||||
DATABASE_URL: "sqlite://"
|
||||
API_KEY: "test-key"
|
||||
OPENROUTER_API_KEY: "test-key"
|
||||
JWT_SECRET: "test-secret-for-ci"
|
||||
APP_ENV: "development"
|
||||
run: |
|
||||
python3 -m pytest tests/ -v --tb=short -x
|
||||
|
||||
build-api:
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
@@ -81,6 +117,7 @@ jobs:
|
||||
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
|
||||
|
||||
build-frontend:
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
name: Test and Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install system dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||
|
||||
- name: Checkout code
|
||||
shell: sh
|
||||
run: |
|
||||
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||
git checkout ${{ gitea.sha }}
|
||||
|
||||
- name: Install Python dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
pip3 install --break-system-packages -r requirements.txt ruff
|
||||
|
||||
- name: Run ruff linter
|
||||
shell: sh
|
||||
run: |
|
||||
ruff check SPARC/ tests/
|
||||
|
||||
- name: Run pytest
|
||||
shell: sh
|
||||
env:
|
||||
DATABASE_URL: "sqlite://"
|
||||
API_KEY: "test-key"
|
||||
OPENROUTER_API_KEY: "test-key"
|
||||
JWT_SECRET: "test-secret-for-ci"
|
||||
APP_ENV: "development"
|
||||
run: |
|
||||
python3 -m pytest tests/ -v --tb=short -x
|
||||
+3
-2
@@ -1,3 +1,4 @@
|
||||
from .types import Patents, Patent
|
||||
from .types import Patent as Patent
|
||||
from .types import Patents as Patents
|
||||
|
||||
all = ["Patents", "Patent"]
|
||||
__all__ = ["Patents", "Patent"]
|
||||
|
||||
+2
-2
@@ -13,9 +13,9 @@ from SPARC import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from SPARC.database import DatabaseClient
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents
|
||||
|
||||
|
||||
class CompanyAnalyzer:
|
||||
|
||||
+5
-1
@@ -21,11 +21,13 @@ from SPARC.auth import (
|
||||
TokenResponse,
|
||||
UserResponse,
|
||||
check_jwt_secret,
|
||||
close_db_client,
|
||||
create_tokens,
|
||||
decode_token,
|
||||
get_current_admin,
|
||||
get_current_user,
|
||||
get_db_client,
|
||||
init_db_client,
|
||||
)
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
@@ -155,6 +157,7 @@ async def lifespan(app: FastAPI):
|
||||
"""Initialize resources on startup, clean up on shutdown."""
|
||||
global _analyzer
|
||||
check_jwt_secret()
|
||||
init_db_client()
|
||||
_analyzer = CompanyAnalyzer()
|
||||
# Mark any jobs that were running/pending before the restart as failed
|
||||
from SPARC.database import DatabaseClient
|
||||
@@ -167,8 +170,9 @@ async def lifespan(app: FastAPI):
|
||||
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
||||
_db.close()
|
||||
yield
|
||||
# Cleanup if needed
|
||||
# Cleanup
|
||||
_analyzer = None
|
||||
close_db_client()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
|
||||
+29
-4
@@ -146,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
|
||||
return None
|
||||
|
||||
|
||||
# Shared database client singleton, initialized at startup via init_db_client()
|
||||
_db_client: DatabaseClient | None = None
|
||||
|
||||
|
||||
def init_db_client() -> None:
|
||||
"""Initialize the shared database client. Call once at app startup."""
|
||||
global _db_client
|
||||
_db_client = DatabaseClient(config.database_url)
|
||||
_db_client.connect()
|
||||
|
||||
|
||||
def close_db_client() -> None:
|
||||
"""Close the shared database client. Call at app shutdown."""
|
||||
global _db_client
|
||||
if _db_client:
|
||||
_db_client.close()
|
||||
_db_client = None
|
||||
|
||||
|
||||
def get_db_client() -> DatabaseClient:
|
||||
"""Get database client for auth operations."""
|
||||
client = DatabaseClient(config.database_url)
|
||||
client.connect()
|
||||
return client
|
||||
"""Get the shared pooled database client for auth operations.
|
||||
|
||||
Returns the module-level singleton DatabaseClient. If not yet initialized
|
||||
(e.g., during tests), creates a new instance as a fallback.
|
||||
"""
|
||||
global _db_client
|
||||
if _db_client is None:
|
||||
_db_client = DatabaseClient(config.database_url)
|
||||
_db_client.connect()
|
||||
return _db_client
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
|
||||
+159
-171
@@ -1,14 +1,15 @@
|
||||
"""Database client for storing and retrieving LLM messages and user authentication."""
|
||||
|
||||
import contextlib
|
||||
import psycopg2
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import bcrypt
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
|
||||
class DatabaseClient:
|
||||
@@ -221,8 +222,6 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Cached message dict if found, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
query = """
|
||||
@@ -245,10 +244,11 @@ class DatabaseClient:
|
||||
|
||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
def store_message(
|
||||
self,
|
||||
@@ -276,33 +276,32 @@ class DatabaseClient:
|
||||
Returns:
|
||||
The ID of the inserted record
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
prompt_hash = self.hash_prompt(prompt)
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO llm_messages
|
||||
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
prompt,
|
||||
prompt_hash,
|
||||
response,
|
||||
company_name,
|
||||
analysis_type,
|
||||
model,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(token_usage) if token_usage else None,
|
||||
is_cached,
|
||||
),
|
||||
)
|
||||
|
||||
message_id = cursor.fetchone()[0]
|
||||
self.conn.commit()
|
||||
message_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
return message_id
|
||||
|
||||
@@ -324,8 +323,6 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
query = "SELECT * FROM llm_messages WHERE 1=1"
|
||||
params = []
|
||||
|
||||
@@ -340,9 +337,10 @@ class DatabaseClient:
|
||||
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
||||
params.extend([limit, offset])
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_analytics(self, days: int = 30) -> Dict:
|
||||
"""Get analytics on message usage.
|
||||
@@ -353,53 +351,52 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Dictionary with analytics data
|
||||
"""
|
||||
self.connect()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
# Total messages
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total_messages
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
total = cursor.fetchone()["total_messages"]
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
|
||||
# Messages by company
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT company_name, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY company_name
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_company = cursor.fetchall()
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
# Messages by type
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT analysis_type, COUNT(*) as count
|
||||
FROM llm_messages
|
||||
WHERE timestamp >= NOW() - INTERVAL '%s days'
|
||||
GROUP BY analysis_type
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
(days,),
|
||||
)
|
||||
by_type = cursor.fetchall()
|
||||
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
return {
|
||||
"total_messages": total,
|
||||
"by_company": [dict(row) for row in by_company],
|
||||
"by_type": [dict(row) for row in by_type],
|
||||
"period_days": days,
|
||||
}
|
||||
|
||||
# Patent Cache Methods
|
||||
|
||||
@@ -650,25 +647,23 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Created user dict or None if email exists
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
password_hash = self.hash_password(password)
|
||||
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (email, password_hash, role)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(email, password_hash, role),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
return dict(user) if user else None
|
||||
except psycopg2.errors.UniqueViolation:
|
||||
self.conn.rollback()
|
||||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||
@@ -681,23 +676,22 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict if authenticated, None otherwise
|
||||
"""
|
||||
self.connect()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
if user and self.verify_password(password, user["password_hash"]):
|
||||
return {
|
||||
"id": user["id"],
|
||||
"email": user["email"],
|
||||
"role": user["role"],
|
||||
"created_at": user["created_at"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
|
||||
"""Get a user by ID.
|
||||
@@ -708,15 +702,14 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[Dict]:
|
||||
"""Get a user by email.
|
||||
@@ -727,15 +720,14 @@ class DatabaseClient:
|
||||
Returns:
|
||||
User dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||
(email,),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
|
||||
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
|
||||
"""Get all users (admin only).
|
||||
@@ -747,19 +739,18 @@ class DatabaseClient:
|
||||
Returns:
|
||||
List of user dicts
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, email, role, created_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
|
||||
"""Update a user's role (admin only).
|
||||
@@ -771,20 +762,19 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Updated user dict or None
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET role = %s, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
RETURNING id, email, role, created_at
|
||||
""",
|
||||
(role, user_id),
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
conn.commit()
|
||||
return dict(user) if user else None
|
||||
|
||||
def delete_user(self, user_id: int) -> bool:
|
||||
@@ -796,12 +786,11 @@ class DatabaseClient:
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
self.conn.commit()
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
@@ -810,8 +799,7 @@ class DatabaseClient:
|
||||
Returns:
|
||||
Number of users
|
||||
"""
|
||||
self.connect()
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
+21
-6
@@ -1,12 +1,18 @@
|
||||
import logging
|
||||
import os
|
||||
import serpapi
|
||||
from SPARC import config
|
||||
import re
|
||||
import pdfplumber # pip install pdfplumber
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict
|
||||
from SPARC.types import Patents, Patent
|
||||
|
||||
import pdfplumber # pip install pdfplumber
|
||||
import requests
|
||||
import serpapi
|
||||
|
||||
from SPARC import config
|
||||
from SPARC.types import Patent, Patents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SERP:
|
||||
def query(company: str, days_back: int = None) -> Patents:
|
||||
@@ -41,6 +47,7 @@ class SERP:
|
||||
"tbs": date_filter,
|
||||
"api_key": config.api_key,
|
||||
}
|
||||
logger.info("Querying Google Patents for '%s' (last %d days)", company, days_back)
|
||||
search = serpapi.search(params)
|
||||
# Convert results to Patent objects, skipping any without PDF links
|
||||
patent_ids = []
|
||||
@@ -49,8 +56,10 @@ class SERP:
|
||||
pdf_link = patent.get("pdf")
|
||||
if pdf_link:
|
||||
patent_ids.append(Patent(patent_id=patent["publication_number"], pdf_link=pdf_link, summary=None))
|
||||
# Patents without PDF links are skipped (see docstring for details)
|
||||
else:
|
||||
logger.debug("Skipping patent %s (no PDF link)", patent.get("publication_number", "unknown"))
|
||||
|
||||
logger.info("Found %d patents with PDF links for '%s'", len(patent_ids), company)
|
||||
return Patents(patents=patent_ids)
|
||||
|
||||
def save_patents(patent: Patent) -> Patent:
|
||||
@@ -67,9 +76,13 @@ class SERP:
|
||||
os.makedirs("patents", exist_ok=True)
|
||||
|
||||
if not (os.path.exists(pdf_path) and os.path.getsize(pdf_path) > 0):
|
||||
logger.info("Downloading PDF for %s", patent.patent_id)
|
||||
response = requests.get(patent.pdf_link)
|
||||
with open(pdf_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.debug("Saved %d bytes to %s", len(response.content), pdf_path)
|
||||
else:
|
||||
logger.debug("Using cached PDF for %s at %s", patent.patent_id, pdf_path)
|
||||
|
||||
patent.pdf_path = pdf_path
|
||||
return patent
|
||||
@@ -87,11 +100,13 @@ class SERP:
|
||||
Dictionary containing all extracted sections
|
||||
"""
|
||||
|
||||
logger.debug("Parsing patent PDF: %s", pdf_path)
|
||||
with pdfplumber.open(pdf_path) as pdf:
|
||||
# Extract all text
|
||||
full_text = ""
|
||||
for page in pdf.pages:
|
||||
full_text += page.extract_text() + "\n"
|
||||
logger.debug("Extracted text from %d pages (%d chars)", len(pdf.pages), len(full_text))
|
||||
|
||||
# Define section patterns (common in patents)
|
||||
sections = {
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
[lint]
|
||||
select = ["E", "F", "I"]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"tests/*" = ["E402", "F841"] # allow import not at top of file, unused vars (mocks) in tests
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for the high-level company analyzer orchestration."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, call, MagicMock
|
||||
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
||||
from SPARC.types import BatchAnalysisResult, Patent, Patents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -24,7 +26,7 @@ class TestCompanyAnalyzer:
|
||||
"""Test analyzer initialization with API key."""
|
||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
analyzer = CompanyAnalyzer(openrouter_api_key="test-key")
|
||||
_analyzer = CompanyAnalyzer(openrouter_api_key="test-key") # noqa: F841
|
||||
|
||||
mock_llm.assert_called_once_with(api_key="test-key")
|
||||
|
||||
|
||||
+4
-3
@@ -1,12 +1,13 @@
|
||||
"""Tests for FastAPI web service endpoints."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
+3
-1
@@ -1,7 +1,9 @@
|
||||
"""Tests for LLM analysis functionality."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Tests for SERP API patent retrieval and parsing functionality."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.types import Patent
|
||||
|
||||
|
||||
Reference in New Issue
Block a user