forked from 0xWheatyz/SPARC
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f33447eef8 | |||
| 55c131cb32 | |||
| fbb72fe2a5 | |||
| e484baaf5f | |||
| 069f1c343c | |||
| d366443b38 |
@@ -9,7 +9,43 @@ on:
|
|||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install system dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||||
|
git checkout ${{ gitea.sha }}
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
pip3 install --break-system-packages -r requirements.txt ruff
|
||||||
|
|
||||||
|
- name: Run ruff linter
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
ruff check SPARC/ tests/
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
shell: sh
|
||||||
|
env:
|
||||||
|
DATABASE_URL: "sqlite://"
|
||||||
|
API_KEY: "test-key"
|
||||||
|
OPENROUTER_API_KEY: "test-key"
|
||||||
|
JWT_SECRET: "test-secret-for-ci"
|
||||||
|
APP_ENV: "development"
|
||||||
|
run: |
|
||||||
|
python3 -m pytest tests/ -v --tb=short -x
|
||||||
|
|
||||||
build-api:
|
build-api:
|
||||||
|
needs: test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@@ -81,6 +117,7 @@ jobs:
|
|||||||
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
|
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
|
||||||
|
|
||||||
build-frontend:
|
build-frontend:
|
||||||
|
needs: test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
name: Test and Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install system dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||||
|
git checkout ${{ gitea.sha }}
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
pip3 install --break-system-packages -r requirements.txt ruff
|
||||||
|
|
||||||
|
- name: Run ruff linter
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
ruff check SPARC/ tests/
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
shell: sh
|
||||||
|
env:
|
||||||
|
DATABASE_URL: "sqlite://"
|
||||||
|
API_KEY: "test-key"
|
||||||
|
OPENROUTER_API_KEY: "test-key"
|
||||||
|
JWT_SECRET: "test-secret-for-ci"
|
||||||
|
APP_ENV: "development"
|
||||||
|
run: |
|
||||||
|
python3 -m pytest tests/ -v --tb=short -x
|
||||||
+3
-2
@@ -1,3 +1,4 @@
|
|||||||
from .types import Patents, Patent
|
from .types import Patent as Patent
|
||||||
|
from .types import Patents as Patents
|
||||||
|
|
||||||
all = ["Patents", "Patent"]
|
__all__ = ["Patents", "Patent"]
|
||||||
|
|||||||
+2
-2
@@ -13,9 +13,9 @@ from SPARC import config
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
from SPARC.database import DatabaseClient
|
from SPARC.database import DatabaseClient
|
||||||
from SPARC.serp_api import SERP
|
|
||||||
from SPARC.llm import LLMAnalyzer
|
from SPARC.llm import LLMAnalyzer
|
||||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
from SPARC.serp_api import SERP
|
||||||
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents
|
||||||
|
|
||||||
|
|
||||||
class CompanyAnalyzer:
|
class CompanyAnalyzer:
|
||||||
|
|||||||
+62
-1
@@ -21,11 +21,13 @@ from SPARC.auth import (
|
|||||||
TokenResponse,
|
TokenResponse,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
check_jwt_secret,
|
check_jwt_secret,
|
||||||
|
close_db_client,
|
||||||
create_tokens,
|
create_tokens,
|
||||||
decode_token,
|
decode_token,
|
||||||
get_current_admin,
|
get_current_admin,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
get_db_client,
|
get_db_client,
|
||||||
|
init_db_client,
|
||||||
)
|
)
|
||||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||||
|
|
||||||
@@ -155,6 +157,7 @@ async def lifespan(app: FastAPI):
|
|||||||
"""Initialize resources on startup, clean up on shutdown."""
|
"""Initialize resources on startup, clean up on shutdown."""
|
||||||
global _analyzer
|
global _analyzer
|
||||||
check_jwt_secret()
|
check_jwt_secret()
|
||||||
|
init_db_client()
|
||||||
_analyzer = CompanyAnalyzer()
|
_analyzer = CompanyAnalyzer()
|
||||||
# Mark any jobs that were running/pending before the restart as failed
|
# Mark any jobs that were running/pending before the restart as failed
|
||||||
from SPARC.database import DatabaseClient
|
from SPARC.database import DatabaseClient
|
||||||
@@ -166,9 +169,13 @@ async def lifespan(app: FastAPI):
|
|||||||
import logging
|
import logging
|
||||||
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
||||||
_db.close()
|
_db.close()
|
||||||
|
# Start scheduled analysis if tracked companies are configured
|
||||||
|
from SPARC.scheduler import start_scheduler
|
||||||
|
start_scheduler()
|
||||||
yield
|
yield
|
||||||
# Cleanup if needed
|
# Cleanup
|
||||||
_analyzer = None
|
_analyzer = None
|
||||||
|
close_db_client()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -365,6 +372,60 @@ async def delete_user(
|
|||||||
return {"message": "User deleted"}
|
return {"message": "User deleted"}
|
||||||
|
|
||||||
|
|
||||||
|
# ============== Tracked Companies Endpoints ==============
|
||||||
|
|
||||||
|
|
||||||
|
class TrackCompanyRequest(BaseModel):
|
||||||
|
"""Request to add a company to tracking."""
|
||||||
|
|
||||||
|
company_name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/admin/tracked", tags=["Admin"])
|
||||||
|
async def list_tracked_companies(
|
||||||
|
_: UserResponse = Depends(get_current_admin),
|
||||||
|
):
|
||||||
|
"""List all tracked companies (admin only)."""
|
||||||
|
db = get_db_client()
|
||||||
|
return db.list_tracked_companies()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/admin/tracked", tags=["Admin"])
|
||||||
|
async def add_tracked_company(
|
||||||
|
request: TrackCompanyRequest,
|
||||||
|
_: UserResponse = Depends(get_current_admin),
|
||||||
|
):
|
||||||
|
"""Add a company to the tracked list (admin only)."""
|
||||||
|
db = get_db_client()
|
||||||
|
result = db.add_tracked_company(request.company_name)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(status_code=409, detail="Company already tracked")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
|
||||||
|
async def remove_tracked_company(
|
||||||
|
company_name: str,
|
||||||
|
_: UserResponse = Depends(get_current_admin),
|
||||||
|
):
|
||||||
|
"""Remove a company from the tracked list (admin only)."""
|
||||||
|
db = get_db_client()
|
||||||
|
removed = db.remove_tracked_company(company_name)
|
||||||
|
if not removed:
|
||||||
|
raise HTTPException(status_code=404, detail="Company not found in tracking list")
|
||||||
|
return {"message": f"Stopped tracking {company_name}"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/admin/alerts", tags=["Admin"])
|
||||||
|
async def list_alerts(
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
_: UserResponse = Depends(get_current_admin),
|
||||||
|
):
|
||||||
|
"""List recent alerts from scheduled analysis (admin only)."""
|
||||||
|
db = get_db_client()
|
||||||
|
return db.list_alerts(limit=limit)
|
||||||
|
|
||||||
|
|
||||||
# ============== Analytics Endpoint ==============
|
# ============== Analytics Endpoint ==============
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+29
-4
@@ -146,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Shared database client singleton, initialized at startup via init_db_client()
|
||||||
|
_db_client: DatabaseClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_db_client() -> None:
|
||||||
|
"""Initialize the shared database client. Call once at app startup."""
|
||||||
|
global _db_client
|
||||||
|
_db_client = DatabaseClient(config.database_url)
|
||||||
|
_db_client.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def close_db_client() -> None:
|
||||||
|
"""Close the shared database client. Call at app shutdown."""
|
||||||
|
global _db_client
|
||||||
|
if _db_client:
|
||||||
|
_db_client.close()
|
||||||
|
_db_client = None
|
||||||
|
|
||||||
|
|
||||||
def get_db_client() -> DatabaseClient:
|
def get_db_client() -> DatabaseClient:
|
||||||
"""Get database client for auth operations."""
|
"""Get the shared pooled database client for auth operations.
|
||||||
client = DatabaseClient(config.database_url)
|
|
||||||
client.connect()
|
Returns the module-level singleton DatabaseClient. If not yet initialized
|
||||||
return client
|
(e.g., during tests), creates a new instance as a fallback.
|
||||||
|
"""
|
||||||
|
global _db_client
|
||||||
|
if _db_client is None:
|
||||||
|
_db_client = DatabaseClient(config.database_url)
|
||||||
|
_db_client.connect()
|
||||||
|
return _db_client
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
|
|||||||
+142
-47
@@ -1,14 +1,15 @@
|
|||||||
"""Database client for storing and retrieving LLM messages and user authentication."""
|
"""Database client for storing and retrieving LLM messages and user authentication."""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import psycopg2
|
|
||||||
from psycopg2.pool import ThreadedConnectionPool
|
|
||||||
from psycopg2.extras import RealDictCursor
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2.extras import RealDictCursor
|
||||||
|
from psycopg2.pool import ThreadedConnectionPool
|
||||||
|
|
||||||
|
|
||||||
class DatabaseClient:
|
class DatabaseClient:
|
||||||
@@ -191,6 +192,35 @@ class DatabaseClient:
|
|||||||
ON jobs(status)
|
ON jobs(status)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Create tracked companies table for scheduled analysis
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS tracked_companies (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
company_name VARCHAR(255) UNIQUE NOT NULL,
|
||||||
|
last_patent_count INTEGER DEFAULT 0,
|
||||||
|
last_analysis_at TIMESTAMP,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create alerts table for significant changes
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS alerts (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
company_name VARCHAR(255) NOT NULL,
|
||||||
|
alert_type VARCHAR(50) NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
old_value NUMERIC,
|
||||||
|
new_value NUMERIC,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_alerts_company
|
||||||
|
ON alerts(company_name)
|
||||||
|
""")
|
||||||
|
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -221,8 +251,6 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Cached message dict if found, None otherwise
|
Cached message dict if found, None otherwise
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
prompt_hash = self.hash_prompt(prompt)
|
prompt_hash = self.hash_prompt(prompt)
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
@@ -245,7 +273,8 @@ class DatabaseClient:
|
|||||||
|
|
||||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||||
|
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
return dict(result) if result else None
|
return dict(result) if result else None
|
||||||
@@ -276,11 +305,10 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
The ID of the inserted record
|
The ID of the inserted record
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
prompt_hash = self.hash_prompt(prompt)
|
prompt_hash = self.hash_prompt(prompt)
|
||||||
|
|
||||||
with self.conn.cursor() as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO llm_messages
|
INSERT INTO llm_messages
|
||||||
@@ -302,7 +330,7 @@ class DatabaseClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
message_id = cursor.fetchone()[0]
|
message_id = cursor.fetchone()[0]
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
@@ -324,8 +352,6 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
List of message dictionaries
|
List of message dictionaries
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
query = "SELECT * FROM llm_messages WHERE 1=1"
|
query = "SELECT * FROM llm_messages WHERE 1=1"
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
@@ -340,7 +366,8 @@ class DatabaseClient:
|
|||||||
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
||||||
params.extend([limit, offset])
|
params.extend([limit, offset])
|
||||||
|
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
@@ -353,9 +380,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary with analytics data
|
Dictionary with analytics data
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
# Total messages
|
# Total messages
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
@@ -650,12 +676,11 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Created user dict or None if email exists
|
Created user dict or None if email exists
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
password_hash = self.hash_password(password)
|
password_hash = self.hash_password(password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO users (email, password_hash, role)
|
INSERT INTO users (email, password_hash, role)
|
||||||
@@ -665,10 +690,9 @@ class DatabaseClient:
|
|||||||
(email, password_hash, role),
|
(email, password_hash, role),
|
||||||
)
|
)
|
||||||
user = cursor.fetchone()
|
user = cursor.fetchone()
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return dict(user) if user else None
|
return dict(user) if user else None
|
||||||
except psycopg2.errors.UniqueViolation:
|
except psycopg2.errors.UniqueViolation:
|
||||||
self.conn.rollback()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||||
@@ -681,9 +705,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
User dict if authenticated, None otherwise
|
User dict if authenticated, None otherwise
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM users WHERE email = %s",
|
"SELECT * FROM users WHERE email = %s",
|
||||||
(email,),
|
(email,),
|
||||||
@@ -708,9 +731,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
User dict or None
|
User dict or None
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
@@ -727,9 +749,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
User dict or None
|
User dict or None
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||||
(email,),
|
(email,),
|
||||||
@@ -747,9 +768,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
List of user dicts
|
List of user dicts
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, email, role, created_at
|
SELECT id, email, role, created_at
|
||||||
@@ -771,9 +791,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Updated user dict or None
|
Updated user dict or None
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
@@ -784,7 +803,7 @@ class DatabaseClient:
|
|||||||
(role, user_id),
|
(role, user_id),
|
||||||
)
|
)
|
||||||
user = cursor.fetchone()
|
user = cursor.fetchone()
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return dict(user) if user else None
|
return dict(user) if user else None
|
||||||
|
|
||||||
def delete_user(self, user_id: int) -> bool:
|
def delete_user(self, user_id: int) -> bool:
|
||||||
@@ -796,12 +815,11 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
True if deleted
|
True if deleted
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||||
deleted = cursor.rowcount > 0
|
deleted = cursor.rowcount > 0
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return deleted
|
return deleted
|
||||||
|
|
||||||
def get_user_count(self) -> int:
|
def get_user_count(self) -> int:
|
||||||
@@ -810,8 +828,85 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of users
|
Number of users
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute("SELECT COUNT(*) FROM users")
|
cursor.execute("SELECT COUNT(*) FROM users")
|
||||||
return cursor.fetchone()[0]
|
return cursor.fetchone()[0]
|
||||||
|
|
||||||
|
# Tracked Companies Methods
|
||||||
|
|
||||||
|
def add_tracked_company(self, company_name: str) -> Optional[Dict]:
|
||||||
|
"""Add a company to the tracking list."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *",
|
||||||
|
(company_name,),
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
conn.commit()
|
||||||
|
return dict(row) if row else None
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def remove_tracked_company(self, company_name: str) -> bool:
|
||||||
|
"""Remove a company from the tracking list."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)",
|
||||||
|
(company_name,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
def list_tracked_companies(self) -> List[Dict]:
|
||||||
|
"""List all tracked companies."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
|
cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name")
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def update_tracked_company(
|
||||||
|
self, company_name: str, patent_count: int
|
||||||
|
) -> None:
|
||||||
|
"""Update the last analysis stats for a tracked company."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""UPDATE tracked_companies
|
||||||
|
SET last_patent_count = %s, last_analysis_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE LOWER(company_name) = LOWER(%s)""",
|
||||||
|
(patent_count, company_name),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def store_alert(
|
||||||
|
self,
|
||||||
|
company_name: str,
|
||||||
|
alert_type: str,
|
||||||
|
message: str,
|
||||||
|
old_value: float | None = None,
|
||||||
|
new_value: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Record an alert for a significant change."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""INSERT INTO alerts (company_name, alert_type, message, old_value, new_value)
|
||||||
|
VALUES (%s, %s, %s, %s, %s)""",
|
||||||
|
(company_name, alert_type, message, old_value, new_value),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def list_alerts(self, limit: int = 50) -> List[Dict]:
|
||||||
|
"""List recent alerts."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT * FROM alerts ORDER BY created_at DESC LIMIT %s",
|
||||||
|
(limit,),
|
||||||
|
)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
"""Scheduled patent analysis for tracked companies.
|
||||||
|
|
||||||
|
Uses APScheduler to periodically re-analyze tracked companies and
|
||||||
|
detect significant changes in patent counts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from SPARC import config
|
||||||
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
|
from SPARC.database import DatabaseClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configurable via environment variable (in hours, default 24)
|
||||||
|
SCHEDULE_INTERVAL_HOURS = int(os.getenv("SCHEDULE_INTERVAL_HOURS", "24"))
|
||||||
|
|
||||||
|
# Patent count change threshold (percentage) to trigger an alert
|
||||||
|
CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20"))
|
||||||
|
|
||||||
|
|
||||||
|
def run_scheduled_analysis() -> None:
|
||||||
|
"""Re-analyze all tracked companies and check for significant changes."""
|
||||||
|
db = DatabaseClient(config.database_url)
|
||||||
|
db.connect()
|
||||||
|
db.initialize_schema()
|
||||||
|
|
||||||
|
tracked = db.list_tracked_companies()
|
||||||
|
if not tracked:
|
||||||
|
logger.info("No tracked companies configured; skipping scheduled analysis")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Running scheduled analysis for %d tracked companies", len(tracked))
|
||||||
|
|
||||||
|
analyzer = CompanyAnalyzer(db_client=db)
|
||||||
|
|
||||||
|
for company_row in tracked:
|
||||||
|
name = company_row["company_name"]
|
||||||
|
old_count = company_row.get("last_patent_count", 0) or 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = analyzer._analyze_company_safe(name)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
new_count = result.patent_count
|
||||||
|
|
||||||
|
# Update tracking record
|
||||||
|
db.update_tracked_company(name, new_count)
|
||||||
|
|
||||||
|
# Check for significant change
|
||||||
|
if old_count > 0:
|
||||||
|
delta_pct = abs(new_count - old_count) / old_count * 100
|
||||||
|
if delta_pct >= CHANGE_THRESHOLD_PERCENT:
|
||||||
|
direction = "increased" if new_count > old_count else "decreased"
|
||||||
|
message = (
|
||||||
|
f"Patent count for {name} {direction} by {delta_pct:.0f}% "
|
||||||
|
f"({old_count} -> {new_count})"
|
||||||
|
)
|
||||||
|
logger.warning("ALERT: %s", message)
|
||||||
|
db.store_alert(
|
||||||
|
company_name=name,
|
||||||
|
alert_type="patent_count_change",
|
||||||
|
message=message,
|
||||||
|
old_value=old_count,
|
||||||
|
new_value=new_count,
|
||||||
|
)
|
||||||
|
elif new_count > 0:
|
||||||
|
# First analysis -- record baseline
|
||||||
|
logger.info("Baseline for %s: %d patents", name, new_count)
|
||||||
|
else:
|
||||||
|
logger.warning("Scheduled analysis failed for %s: %s", name, result.error)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error analyzing tracked company %s: %s", name, e)
|
||||||
|
|
||||||
|
db.close()
|
||||||
|
logger.info("Scheduled analysis complete")
|
||||||
|
|
||||||
|
|
||||||
|
def start_scheduler() -> None:
|
||||||
|
"""Start the APScheduler background scheduler.
|
||||||
|
|
||||||
|
Safe to call at application startup. If apscheduler is not installed,
|
||||||
|
the function logs a warning and returns without starting anything.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"apscheduler not installed; scheduled analysis disabled. "
|
||||||
|
"Install with: pip install apscheduler"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
scheduler = BackgroundScheduler()
|
||||||
|
scheduler.add_job(
|
||||||
|
run_scheduled_analysis,
|
||||||
|
"interval",
|
||||||
|
hours=SCHEDULE_INTERVAL_HOURS,
|
||||||
|
id="scheduled_patent_analysis",
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
scheduler.start()
|
||||||
|
logger.info(
|
||||||
|
"Scheduled patent analysis started (every %d hours, threshold %d%%)",
|
||||||
|
SCHEDULE_INTERVAL_HOURS,
|
||||||
|
CHANGE_THRESHOLD_PERCENT,
|
||||||
|
)
|
||||||
+8
-5
@@ -1,12 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
import serpapi
|
|
||||||
from SPARC import config
|
|
||||||
import re
|
import re
|
||||||
import pdfplumber # pip install pdfplumber
|
|
||||||
import requests
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from SPARC.types import Patents, Patent
|
|
||||||
|
import pdfplumber # pip install pdfplumber
|
||||||
|
import requests
|
||||||
|
import serpapi
|
||||||
|
|
||||||
|
from SPARC import config
|
||||||
|
from SPARC.types import Patent, Patents
|
||||||
|
|
||||||
|
|
||||||
class SERP:
|
class SERP:
|
||||||
def query(company: str, days_back: int = None) -> Patents:
|
def query(company: str, days_back: int = None) -> Patents:
|
||||||
|
|||||||
@@ -15,3 +15,4 @@ pandas
|
|||||||
bcrypt
|
bcrypt
|
||||||
PyJWT
|
PyJWT
|
||||||
slowapi
|
slowapi
|
||||||
|
apscheduler
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
[lint]
|
||||||
|
select = ["E", "F", "I"]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
]
|
||||||
|
|
||||||
|
[lint.per-file-ignores]
|
||||||
|
"tests/*" = ["E402", "F841"] # allow import not at top of file, unused vars (mocks) in tests
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
"""Tests for the high-level company analyzer orchestration."""
|
"""Tests for the high-level company analyzer orchestration."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch, call, MagicMock
|
|
||||||
from SPARC.analyzer import CompanyAnalyzer
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
from SPARC.types import BatchAnalysisResult, Patent, Patents
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -24,7 +26,7 @@ class TestCompanyAnalyzer:
|
|||||||
"""Test analyzer initialization with API key."""
|
"""Test analyzer initialization with API key."""
|
||||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||||
|
|
||||||
analyzer = CompanyAnalyzer(openrouter_api_key="test-key")
|
_analyzer = CompanyAnalyzer(openrouter_api_key="test-key") # noqa: F841
|
||||||
|
|
||||||
mock_llm.assert_called_once_with(api_key="test-key")
|
mock_llm.assert_called_once_with(api_key="test-key")
|
||||||
|
|
||||||
|
|||||||
+4
-3
@@ -1,12 +1,13 @@
|
|||||||
"""Tests for FastAPI web service endpoints."""
|
"""Tests for FastAPI web service endpoints."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from SPARC.api import app
|
from SPARC.api import app
|
||||||
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
+3
-1
@@ -1,7 +1,9 @@
|
|||||||
"""Tests for LLM analysis functionality."""
|
"""Tests for LLM analysis functionality."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, MagicMock, patch
|
|
||||||
from SPARC.llm import LLMAnalyzer
|
from SPARC.llm import LLMAnalyzer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""Tests for SERP API patent retrieval and parsing functionality."""
|
"""Tests for SERP API patent retrieval and parsing functionality."""
|
||||||
|
|
||||||
import os
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from SPARC.serp_api import SERP
|
from SPARC.serp_api import SERP
|
||||||
from SPARC.types import Patent
|
from SPARC.types import Patent
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user