Fix ruff lint, Pyright type ambiguity, and MCP tool SQL queries
- Fix 28 ruff errors: E501 line length, B904 raise-from, F401 unused import - Fix SQLAlchemy Row.count() ambiguity with tuple indexing (Pyright) - Replace composite column notation with accessor functions in MCP tools (topocentric/equatorial/pass_event are C-level base types, not composites) - Fix satellite_pass: use time window (start + end) not count parameter to match predict_passes(tle, observer, start_ts, end_ts, min_el) signature
This commit is contained in:
parent
a40ae9437d
commit
33787e03da
@ -3,7 +3,6 @@
|
|||||||
Run: docker compose exec api-dev python -m orrery_search.ingest
|
Run: docker compose exec api-dev python -m orrery_search.ingest
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,8 @@ logger = logging.getLogger("orrery_search")
|
|||||||
SYSTEM_PROMPT = (
|
SYSTEM_PROMPT = (
|
||||||
"/no_think\n"
|
"/no_think\n"
|
||||||
"You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing "
|
"You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing "
|
||||||
"celestial mechanics types and functions. Answer questions using the provided context "
|
"celestial mechanics types and functions. "
|
||||||
|
"Answer questions using the provided context "
|
||||||
"from the documentation.\n\n"
|
"from the documentation.\n\n"
|
||||||
"Key domain knowledge:\n"
|
"Key domain knowledge:\n"
|
||||||
"- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, "
|
"- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, "
|
||||||
@ -125,7 +126,10 @@ async def _chat_completion_stream_gpu(
|
|||||||
{"role": "system", "content": SYSTEM_PROMPT},
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
"content": (
|
||||||
|
f"Documentation context:\n\n{context}"
|
||||||
|
f"\n\n---\n\nQuestion: {question}"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -196,7 +200,10 @@ async def _chat_completion_stream_anthropic(
|
|||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
"content": (
|
||||||
|
f"Documentation context:\n\n{context}"
|
||||||
|
f"\n\n---\n\nQuestion: {question}"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
) as stream:
|
) as stream:
|
||||||
@ -232,7 +239,10 @@ async def _chat_completion(context: str, question: str) -> str:
|
|||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
"content": (
|
||||||
|
f"Documentation context:\n\n{context}"
|
||||||
|
f"\n\n---\n\nQuestion: {question}"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -243,7 +253,10 @@ async def _chat_completion(context: str, question: str) -> str:
|
|||||||
{"role": "system", "content": SYSTEM_PROMPT},
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
"content": (
|
||||||
|
f"Documentation context:\n\n{context}"
|
||||||
|
f"\n\n---\n\nQuestion: {question}"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
@ -284,11 +297,16 @@ async def ask_orrery(
|
|||||||
context, sources = await _build_context(question)
|
context, sources = await _build_context(question)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to build RAG context")
|
logger.exception("Failed to build RAG context")
|
||||||
raise RuntimeError("Search failed while building context")
|
raise RuntimeError(
|
||||||
|
"Search failed while building context"
|
||||||
|
) from None
|
||||||
|
|
||||||
if not context:
|
if not context:
|
||||||
return json.dumps({
|
return json.dumps({
|
||||||
"answer": "No relevant documents found in the pg_orrery docs for this question.",
|
"answer": (
|
||||||
|
"No relevant documents found in the "
|
||||||
|
"pg_orrery docs for this question."
|
||||||
|
),
|
||||||
"sources": [],
|
"sources": [],
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -296,10 +314,18 @@ async def ask_orrery(
|
|||||||
answer = await _chat_completion(context, question)
|
answer = await _chat_completion(context, question)
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
logger.error("Chat completion failed: HTTP %s", exc.response.status_code)
|
logger.error("Chat completion failed: HTTP %s", exc.response.status_code)
|
||||||
raise RuntimeError(f"Chat model returned HTTP {exc.response.status_code}")
|
status = exc.response.status_code
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Chat model returned HTTP {status}"
|
||||||
|
) from None
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
logger.warning("Chat completion timed out (limit: %.0fs)", settings.chat_timeout)
|
logger.warning(
|
||||||
raise RuntimeError("Chat model timed out — try a simpler question")
|
"Chat completion timed out (limit: %.0fs)",
|
||||||
|
settings.chat_timeout,
|
||||||
|
)
|
||||||
|
raise RuntimeError(
|
||||||
|
"Chat model timed out — try a simpler question"
|
||||||
|
) from None
|
||||||
|
|
||||||
result: dict = {"answer": answer}
|
result: dict = {"answer": answer}
|
||||||
if include_sources:
|
if include_sources:
|
||||||
|
|||||||
@ -63,7 +63,10 @@ def _validate_sql(sql: str) -> None:
|
|||||||
raise ValueError("Only SELECT statements are allowed")
|
raise ValueError("Only SELECT statements are allowed")
|
||||||
|
|
||||||
if _FORBIDDEN_RE.search(stripped):
|
if _FORBIDDEN_RE.search(stripped):
|
||||||
raise ValueError("Statement contains forbidden keywords (INSERT/UPDATE/DELETE/DDL)")
|
raise ValueError(
|
||||||
|
"Statement contains forbidden keywords"
|
||||||
|
" (INSERT/UPDATE/DELETE/DDL)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_row(row: asyncpg.Record) -> dict:
|
def _serialize_row(row: asyncpg.Record) -> dict:
|
||||||
@ -98,12 +101,12 @@ async def run_query(sql: str) -> str:
|
|||||||
_validate_sql(sql)
|
_validate_sql(sql)
|
||||||
|
|
||||||
pool = await _get_pool()
|
pool = await _get_pool()
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn, conn.transaction(readonly=True):
|
||||||
async with conn.transaction(readonly=True):
|
timeout_ms = int(settings.run_query_timeout * 1000)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
f"SET LOCAL statement_timeout = '{int(settings.run_query_timeout * 1000)}'"
|
f"SET LOCAL statement_timeout = '{timeout_ms}'"
|
||||||
)
|
)
|
||||||
rows = await conn.fetch(sql)
|
rows = await conn.fetch(sql)
|
||||||
|
|
||||||
if len(rows) > MAX_ROWS:
|
if len(rows) > MAX_ROWS:
|
||||||
rows = rows[:MAX_ROWS]
|
rows = rows[:MAX_ROWS]
|
||||||
@ -142,8 +145,13 @@ async def planet_position(
|
|||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
(t).azimuth_deg, (t).elevation_deg, (t).range_km, (t).range_rate_km_s,
|
topo_azimuth(t) AS az_deg,
|
||||||
(e).ra_hours, (e).dec_degrees, (e).distance_km
|
topo_elevation(t) AS el_deg,
|
||||||
|
topo_range(t) AS range_km,
|
||||||
|
topo_range_rate(t) AS range_rate_km_s,
|
||||||
|
eq_ra(e) AS ra_hours,
|
||||||
|
eq_dec(e) AS dec_deg,
|
||||||
|
eq_distance(e) AS distance_km
|
||||||
FROM (
|
FROM (
|
||||||
SELECT
|
SELECT
|
||||||
planet_observe({body_id}, {obs}, {ts}) AS t,
|
planet_observe({body_id}, {obs}, {ts}) AS t,
|
||||||
@ -178,11 +186,11 @@ async def sky_survey(
|
|||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
body_name,
|
body_name,
|
||||||
(topo).azimuth_deg AS az,
|
topo_azimuth(topo) AS az,
|
||||||
(topo).elevation_deg AS el,
|
topo_elevation(topo) AS el,
|
||||||
(topo).range_km AS range_km,
|
topo_range(topo) AS range_km,
|
||||||
(eq).ra_hours AS ra,
|
eq_ra(eq) AS ra,
|
||||||
(eq).dec_degrees AS dec
|
eq_dec(eq) AS dec
|
||||||
FROM (
|
FROM (
|
||||||
SELECT 'Sun' AS body_name,
|
SELECT 'Sun' AS body_name,
|
||||||
sun_observe({obs}, {ts}) AS topo,
|
sun_observe({obs}, {ts}) AS topo,
|
||||||
@ -192,9 +200,14 @@ async def sky_survey(
|
|||||||
moon_observe({obs}, {ts}),
|
moon_observe({obs}, {ts}),
|
||||||
moon_equatorial({ts})
|
moon_equatorial({ts})
|
||||||
UNION ALL
|
UNION ALL
|
||||||
SELECT unnest(ARRAY['Mercury','Venus','Mars','Jupiter','Saturn','Uranus','Neptune']),
|
SELECT unnest(ARRAY[
|
||||||
planet_observe(unnest(ARRAY[1,2,4,5,6,7,8]), {obs}, {ts}),
|
'Mercury','Venus','Mars','Jupiter',
|
||||||
planet_equatorial(unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
|
'Saturn','Uranus','Neptune']),
|
||||||
|
planet_observe(
|
||||||
|
unnest(ARRAY[1,2,4,5,6,7,8]),
|
||||||
|
{obs}, {ts}),
|
||||||
|
planet_equatorial(
|
||||||
|
unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
|
||||||
) survey
|
) survey
|
||||||
ORDER BY el DESC
|
ORDER BY el DESC
|
||||||
"""
|
"""
|
||||||
@ -210,11 +223,13 @@ async def satellite_pass(
|
|||||||
longitude: float,
|
longitude: float,
|
||||||
altitude: float = 0.0,
|
altitude: float = 0.0,
|
||||||
time: str = "NOW()",
|
time: str = "NOW()",
|
||||||
count: int = 5,
|
hours: int = 48,
|
||||||
|
min_elevation: float = 0.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Predict upcoming satellite passes over an observer location.
|
"""Predict upcoming satellite passes over an observer location.
|
||||||
|
|
||||||
Uses SGP4/SDP4 propagation with the provided Two-Line Element set.
|
Uses SGP4/SDP4 propagation with the provided Two-Line Element set.
|
||||||
|
Searches a time window from the start time forward.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tle_line1: First line of the TLE (69 chars, starts with "1 ")
|
tle_line1: First line of the TLE (69 chars, starts with "1 ")
|
||||||
@ -223,9 +238,10 @@ async def satellite_pass(
|
|||||||
longitude: Observer longitude in degrees
|
longitude: Observer longitude in degrees
|
||||||
altitude: Observer altitude in meters (default 0)
|
altitude: Observer altitude in meters (default 0)
|
||||||
time: Start time for pass search (default "NOW()")
|
time: Start time for pass search (default "NOW()")
|
||||||
count: Number of passes to predict (default 5, max 20)
|
hours: Hours to search forward from start time (default 48)
|
||||||
|
min_elevation: Minimum peak elevation in degrees (default 0)
|
||||||
"""
|
"""
|
||||||
count = max(1, min(count, 20))
|
hours = max(1, min(hours, 168))
|
||||||
obs = f"'({latitude},{longitude},{altitude})'::observer"
|
obs = f"'({latitude},{longitude},{altitude})'::observer"
|
||||||
ts = f"'{time}'::timestamptz" if time != "NOW()" else "NOW()"
|
ts = f"'{time}'::timestamptz" if time != "NOW()" else "NOW()"
|
||||||
|
|
||||||
@ -235,12 +251,17 @@ async def satellite_pass(
|
|||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
(p).aos_time, (p).los_time,
|
pass_aos_time(p) AS aos_time,
|
||||||
(p).max_elevation_deg,
|
pass_los_time(p) AS los_time,
|
||||||
(p).aos_azimuth_deg, (p).los_azimuth_deg
|
pass_max_elevation(p) AS max_el_deg,
|
||||||
|
pass_aos_azimuth(p) AS aos_az_deg,
|
||||||
|
pass_los_azimuth(p) AS los_az_deg,
|
||||||
|
pass_duration(p) AS duration
|
||||||
FROM predict_passes(
|
FROM predict_passes(
|
||||||
tle_from_lines('{safe_l1}', '{safe_l2}'),
|
tle_from_lines('{safe_l1}', '{safe_l2}'),
|
||||||
{obs}, {ts}, {count}
|
{obs}, {ts},
|
||||||
|
{ts} + interval '{hours} hours',
|
||||||
|
{min_elevation}
|
||||||
) AS p
|
) AS p
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -71,7 +71,8 @@ async def get_document(slug: str) -> str:
|
|||||||
Returns title, body text, metadata, and URL.
|
Returns title, body text, metadata, and URL.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
slug: Document identifier (from search results, e.g. "guides/tracking-satellites")
|
slug: Document identifier from search results
|
||||||
|
(e.g. "guides/tracking-satellites")
|
||||||
"""
|
"""
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@ -122,10 +123,17 @@ async def list_content(
|
|||||||
.order_by(Document.section)
|
.order_by(Document.section)
|
||||||
)
|
)
|
||||||
groups = [
|
groups = [
|
||||||
{"section": row.section, "content_type": row.content_type, "count": row.count}
|
{
|
||||||
|
"section": row.section,
|
||||||
|
"content_type": row.content_type,
|
||||||
|
"count": row.count,
|
||||||
|
}
|
||||||
for row in result
|
for row in result
|
||||||
]
|
]
|
||||||
return json.dumps({"summary": groups, "total": sum(g["count"] for g in groups)})
|
total = sum(g["count"] for g in groups)
|
||||||
|
return json.dumps(
|
||||||
|
{"summary": groups, "total": total}
|
||||||
|
)
|
||||||
|
|
||||||
limit = max(1, min(limit, 500))
|
limit = max(1, min(limit, 500))
|
||||||
stmt = (
|
stmt = (
|
||||||
@ -197,8 +205,9 @@ def orrery_expert(topic: str = "general") -> str:
|
|||||||
)
|
)
|
||||||
topic_context = {
|
topic_context = {
|
||||||
"general": (
|
"general": (
|
||||||
"Help with any pg_orrery topic — satellite tracking, planetary observation, "
|
"Help with any pg_orrery topic — satellite tracking, "
|
||||||
"rise/set prediction, constellation identification, or Lagrange points."
|
"planetary observation, rise/set prediction, "
|
||||||
|
"constellation identification, or Lagrange points."
|
||||||
),
|
),
|
||||||
"satellites": (
|
"satellites": (
|
||||||
"Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, "
|
"Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, "
|
||||||
@ -206,11 +215,13 @@ def orrery_expert(topic: str = "general") -> str:
|
|||||||
),
|
),
|
||||||
"planets": (
|
"planets": (
|
||||||
"Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, "
|
"Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, "
|
||||||
"planet/sun/moon observation, equatorial coordinates, and apparent positions."
|
"planet/sun/moon observation, equatorial coordinates, "
|
||||||
|
"and apparent positions."
|
||||||
),
|
),
|
||||||
"navigation": (
|
"navigation": (
|
||||||
"Focus on observational astronomy: rise/set prediction, twilight computation, "
|
"Focus on observational astronomy: rise/set prediction, "
|
||||||
"constellation identification, lunar phase, planet magnitude, and refraction."
|
"twilight computation, constellation identification, "
|
||||||
|
"lunar phase, planet magnitude, and refraction."
|
||||||
),
|
),
|
||||||
"transfers": (
|
"transfers": (
|
||||||
"Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium "
|
"Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium "
|
||||||
|
|||||||
@ -48,7 +48,10 @@ async def chat(req: ChatRequest):
|
|||||||
context, sources = await _build_context(req.question)
|
context, sources = await _build_context(req.question)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Chat context build failed")
|
logger.exception("Chat context build failed")
|
||||||
raise HTTPException(status_code=502, detail="Search service unavailable")
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail="Search service unavailable",
|
||||||
|
) from None
|
||||||
|
|
||||||
if not context:
|
if not context:
|
||||||
return ChatResponse(
|
return ChatResponse(
|
||||||
@ -62,7 +65,10 @@ async def chat(req: ChatRequest):
|
|||||||
answer = await _chat_completion(context, req.question)
|
answer = await _chat_completion(context, req.question)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Chat completion failed")
|
logger.exception("Chat completion failed")
|
||||||
raise HTTPException(status_code=502, detail="Chat completion unavailable")
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail="Chat completion unavailable",
|
||||||
|
) from None
|
||||||
|
|
||||||
return ChatResponse(
|
return ChatResponse(
|
||||||
answer=answer,
|
answer=answer,
|
||||||
@ -105,7 +111,12 @@ async def chat_stream(req: ChatStreamRequest):
|
|||||||
async def generate():
|
async def generate():
|
||||||
question = req.question
|
question = req.question
|
||||||
if req.page and req.page.title:
|
if req.page and req.page.title:
|
||||||
page_context = f'[The user is currently reading: "{req.page.title}" ({req.page.path})'
|
title = req.page.title
|
||||||
|
path = req.page.path
|
||||||
|
page_context = (
|
||||||
|
f'[The user is currently reading: '
|
||||||
|
f'"{title}" ({path})'
|
||||||
|
)
|
||||||
if req.page.description:
|
if req.page.description:
|
||||||
page_context += f"\nDocument description: {req.page.description}"
|
page_context += f"\nDocument description: {req.page.description}"
|
||||||
page_context += "]\n\n"
|
page_context += "]\n\n"
|
||||||
@ -127,7 +138,8 @@ async def chat_stream(req: ChatStreamRequest):
|
|||||||
{
|
{
|
||||||
"text": "I couldn't find any relevant documentation "
|
"text": "I couldn't find any relevant documentation "
|
||||||
"for that question. Try asking about satellite tracking, "
|
"for that question. Try asking about satellite tracking, "
|
||||||
"planetary observation, rise/set prediction, or other pg_orrery functions."
|
"planetary observation, rise/set prediction, "
|
||||||
|
"or other pg_orrery functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
yield _sse_event("done", {})
|
yield _sse_event("done", {})
|
||||||
@ -136,7 +148,13 @@ async def chat_stream(req: ChatStreamRequest):
|
|||||||
n = len(sources)
|
n = len(sources)
|
||||||
yield _sse_event(
|
yield _sse_event(
|
||||||
"status",
|
"status",
|
||||||
{"text": f"Found {n} relevant page{'s' if n != 1 else ''}, generating answer\u2026"},
|
{
|
||||||
|
"text": (
|
||||||
|
f"Found {n} relevant "
|
||||||
|
f"page{'s' if n != 1 else ''}, "
|
||||||
|
"generating answer\u2026"
|
||||||
|
),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
yield _sse_event("sources", sources)
|
yield _sse_event("sources", sources)
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,9 @@ router = APIRouter()
|
|||||||
async def search(
|
async def search(
|
||||||
q: str = Query(..., min_length=1, max_length=500),
|
q: str = Query(..., min_length=1, max_length=500),
|
||||||
content_type: str | None = Query(None, description="Filter by content type"),
|
content_type: str | None = Query(None, description="Filter by content type"),
|
||||||
section: str | None = Query(None, max_length=200, description="Section prefix filter"),
|
section: str | None = Query(
|
||||||
|
None, max_length=200, description="Section prefix filter"
|
||||||
|
),
|
||||||
limit: int = Query(10, ge=1, le=50),
|
limit: int = Query(10, ge=1, le=50),
|
||||||
mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"),
|
mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
@ -67,8 +69,8 @@ async def list_sections(
|
|||||||
)
|
)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
sections = [
|
sections = [
|
||||||
SectionCount(section=row.section, count=row.count)
|
SectionCount(section=row[0], count=row[1])
|
||||||
for row in result
|
for row in result
|
||||||
if row.section
|
if row[0]
|
||||||
]
|
]
|
||||||
return SectionsResponse(sections=sections)
|
return SectionsResponse(sections=sections)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user