Fix ruff lint, Pyright type ambiguity, and MCP tool SQL queries

- Fix 28 ruff errors: E501 line length, B904 raise-from, F401 unused import
- Fix SQLAlchemy Row.count() ambiguity with tuple indexing (Pyright)
- Replace composite column notation with accessor functions in MCP tools
  (topocentric/equatorial/pass_event are C-level base types, not composites)
- Fix satellite_pass: use time window (start + end) not count parameter
  to match predict_passes(tle, observer, start_ts, end_ts, min_el) signature
This commit is contained in:
Ryan Malloy 2026-03-01 16:41:07 -07:00
parent a40ae9437d
commit 33787e03da
6 changed files with 128 additions and 51 deletions

View File

@ -3,7 +3,6 @@
Run: docker compose exec api-dev python -m orrery_search.ingest
"""
import asyncio
import sys
from pathlib import Path

View File

@ -17,7 +17,8 @@ logger = logging.getLogger("orrery_search")
SYSTEM_PROMPT = (
"/no_think\n"
"You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing "
"celestial mechanics types and functions. Answer questions using the provided context "
"celestial mechanics types and functions. "
"Answer questions using the provided context "
"from the documentation.\n\n"
"Key domain knowledge:\n"
"- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, "
@ -125,7 +126,10 @@ async def _chat_completion_stream_gpu(
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
"content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
},
]
@ -196,7 +200,10 @@ async def _chat_completion_stream_anthropic(
messages=[
{
"role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
"content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
},
],
) as stream:
@ -232,7 +239,10 @@ async def _chat_completion(context: str, question: str) -> str:
messages=[
{
"role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
"content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
},
],
)
@ -243,7 +253,10 @@ async def _chat_completion(context: str, question: str) -> str:
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
"content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
},
]
resp = await client.post(
@ -284,11 +297,16 @@ async def ask_orrery(
context, sources = await _build_context(question)
except Exception:
logger.exception("Failed to build RAG context")
raise RuntimeError("Search failed while building context")
raise RuntimeError(
"Search failed while building context"
) from None
if not context:
return json.dumps({
"answer": "No relevant documents found in the pg_orrery docs for this question.",
"answer": (
"No relevant documents found in the "
"pg_orrery docs for this question."
),
"sources": [],
})
@ -296,10 +314,18 @@ async def ask_orrery(
answer = await _chat_completion(context, question)
except httpx.HTTPStatusError as exc:
logger.error("Chat completion failed: HTTP %s", exc.response.status_code)
raise RuntimeError(f"Chat model returned HTTP {exc.response.status_code}")
status = exc.response.status_code
raise RuntimeError(
f"Chat model returned HTTP {status}"
) from None
except httpx.TimeoutException:
logger.warning("Chat completion timed out (limit: %.0fs)", settings.chat_timeout)
raise RuntimeError("Chat model timed out — try a simpler question")
logger.warning(
"Chat completion timed out (limit: %.0fs)",
settings.chat_timeout,
)
raise RuntimeError(
"Chat model timed out — try a simpler question"
) from None
result: dict = {"answer": answer}
if include_sources:

View File

@ -63,7 +63,10 @@ def _validate_sql(sql: str) -> None:
raise ValueError("Only SELECT statements are allowed")
if _FORBIDDEN_RE.search(stripped):
raise ValueError("Statement contains forbidden keywords (INSERT/UPDATE/DELETE/DDL)")
raise ValueError(
"Statement contains forbidden keywords"
" (INSERT/UPDATE/DELETE/DDL)"
)
def _serialize_row(row: asyncpg.Record) -> dict:
@ -98,12 +101,12 @@ async def run_query(sql: str) -> str:
_validate_sql(sql)
pool = await _get_pool()
async with pool.acquire() as conn:
async with conn.transaction(readonly=True):
await conn.execute(
f"SET LOCAL statement_timeout = '{int(settings.run_query_timeout * 1000)}'"
)
rows = await conn.fetch(sql)
async with pool.acquire() as conn, conn.transaction(readonly=True):
timeout_ms = int(settings.run_query_timeout * 1000)
await conn.execute(
f"SET LOCAL statement_timeout = '{timeout_ms}'"
)
rows = await conn.fetch(sql)
if len(rows) > MAX_ROWS:
rows = rows[:MAX_ROWS]
@ -142,8 +145,13 @@ async def planet_position(
sql = f"""
SELECT
(t).azimuth_deg, (t).elevation_deg, (t).range_km, (t).range_rate_km_s,
(e).ra_hours, (e).dec_degrees, (e).distance_km
topo_azimuth(t) AS az_deg,
topo_elevation(t) AS el_deg,
topo_range(t) AS range_km,
topo_range_rate(t) AS range_rate_km_s,
eq_ra(e) AS ra_hours,
eq_dec(e) AS dec_deg,
eq_distance(e) AS distance_km
FROM (
SELECT
planet_observe({body_id}, {obs}, {ts}) AS t,
@ -178,11 +186,11 @@ async def sky_survey(
sql = f"""
SELECT
body_name,
(topo).azimuth_deg AS az,
(topo).elevation_deg AS el,
(topo).range_km AS range_km,
(eq).ra_hours AS ra,
(eq).dec_degrees AS dec
topo_azimuth(topo) AS az,
topo_elevation(topo) AS el,
topo_range(topo) AS range_km,
eq_ra(eq) AS ra,
eq_dec(eq) AS dec
FROM (
SELECT 'Sun' AS body_name,
sun_observe({obs}, {ts}) AS topo,
@ -192,9 +200,14 @@ async def sky_survey(
moon_observe({obs}, {ts}),
moon_equatorial({ts})
UNION ALL
SELECT unnest(ARRAY['Mercury','Venus','Mars','Jupiter','Saturn','Uranus','Neptune']),
planet_observe(unnest(ARRAY[1,2,4,5,6,7,8]), {obs}, {ts}),
planet_equatorial(unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
SELECT unnest(ARRAY[
'Mercury','Venus','Mars','Jupiter',
'Saturn','Uranus','Neptune']),
planet_observe(
unnest(ARRAY[1,2,4,5,6,7,8]),
{obs}, {ts}),
planet_equatorial(
unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
) survey
ORDER BY el DESC
"""
@ -210,11 +223,13 @@ async def satellite_pass(
longitude: float,
altitude: float = 0.0,
time: str = "NOW()",
count: int = 5,
hours: int = 48,
min_elevation: float = 0.0,
) -> str:
"""Predict upcoming satellite passes over an observer location.
Uses SGP4/SDP4 propagation with the provided Two-Line Element set.
Searches a time window from the start time forward.
Args:
tle_line1: First line of the TLE (69 chars, starts with "1 ")
@ -223,9 +238,10 @@ async def satellite_pass(
longitude: Observer longitude in degrees
altitude: Observer altitude in meters (default 0)
time: Start time for pass search (default "NOW()")
count: Number of passes to predict (default 5, max 20)
hours: Hours to search forward from start time (default 48)
min_elevation: Minimum peak elevation in degrees (default 0)
"""
count = max(1, min(count, 20))
hours = max(1, min(hours, 168))
obs = f"'({latitude},{longitude},{altitude})'::observer"
ts = f"'{time}'::timestamptz" if time != "NOW()" else "NOW()"
@ -235,12 +251,17 @@ async def satellite_pass(
sql = f"""
SELECT
(p).aos_time, (p).los_time,
(p).max_elevation_deg,
(p).aos_azimuth_deg, (p).los_azimuth_deg
pass_aos_time(p) AS aos_time,
pass_los_time(p) AS los_time,
pass_max_elevation(p) AS max_el_deg,
pass_aos_azimuth(p) AS aos_az_deg,
pass_los_azimuth(p) AS los_az_deg,
pass_duration(p) AS duration
FROM predict_passes(
tle_from_lines('{safe_l1}', '{safe_l2}'),
{obs}, {ts}, {count}
{obs}, {ts},
{ts} + interval '{hours} hours',
{min_elevation}
) AS p
"""

View File

@ -71,7 +71,8 @@ async def get_document(slug: str) -> str:
Returns title, body text, metadata, and URL.
Args:
slug: Document identifier (from search results, e.g. "guides/tracking-satellites")
slug: Document identifier from search results
(e.g. "guides/tracking-satellites")
"""
async with async_session() as db:
result = await db.execute(
@ -122,10 +123,17 @@ async def list_content(
.order_by(Document.section)
)
groups = [
{"section": row.section, "content_type": row.content_type, "count": row.count}
{
"section": row.section,
"content_type": row.content_type,
"count": row.count,
}
for row in result
]
return json.dumps({"summary": groups, "total": sum(g["count"] for g in groups)})
total = sum(g["count"] for g in groups)
return json.dumps(
{"summary": groups, "total": total}
)
limit = max(1, min(limit, 500))
stmt = (
@ -197,8 +205,9 @@ def orrery_expert(topic: str = "general") -> str:
)
topic_context = {
"general": (
"Help with any pg_orrery topic — satellite tracking, planetary observation, "
"rise/set prediction, constellation identification, or Lagrange points."
"Help with any pg_orrery topic — satellite tracking, "
"planetary observation, rise/set prediction, "
"constellation identification, or Lagrange points."
),
"satellites": (
"Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, "
@ -206,11 +215,13 @@ def orrery_expert(topic: str = "general") -> str:
),
"planets": (
"Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, "
"planet/sun/moon observation, equatorial coordinates, and apparent positions."
"planet/sun/moon observation, equatorial coordinates, "
"and apparent positions."
),
"navigation": (
"Focus on observational astronomy: rise/set prediction, twilight computation, "
"constellation identification, lunar phase, planet magnitude, and refraction."
"Focus on observational astronomy: rise/set prediction, "
"twilight computation, constellation identification, "
"lunar phase, planet magnitude, and refraction."
),
"transfers": (
"Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium "

View File

@ -48,7 +48,10 @@ async def chat(req: ChatRequest):
context, sources = await _build_context(req.question)
except Exception:
logger.exception("Chat context build failed")
raise HTTPException(status_code=502, detail="Search service unavailable")
raise HTTPException(
status_code=502,
detail="Search service unavailable",
) from None
if not context:
return ChatResponse(
@ -62,7 +65,10 @@ async def chat(req: ChatRequest):
answer = await _chat_completion(context, req.question)
except Exception:
logger.exception("Chat completion failed")
raise HTTPException(status_code=502, detail="Chat completion unavailable")
raise HTTPException(
status_code=502,
detail="Chat completion unavailable",
) from None
return ChatResponse(
answer=answer,
@ -105,7 +111,12 @@ async def chat_stream(req: ChatStreamRequest):
async def generate():
question = req.question
if req.page and req.page.title:
page_context = f'[The user is currently reading: "{req.page.title}" ({req.page.path})'
title = req.page.title
path = req.page.path
page_context = (
f'[The user is currently reading: '
f'"{title}" ({path})'
)
if req.page.description:
page_context += f"\nDocument description: {req.page.description}"
page_context += "]\n\n"
@ -127,7 +138,8 @@ async def chat_stream(req: ChatStreamRequest):
{
"text": "I couldn't find any relevant documentation "
"for that question. Try asking about satellite tracking, "
"planetary observation, rise/set prediction, or other pg_orrery functions."
"planetary observation, rise/set prediction, "
"or other pg_orrery functions."
},
)
yield _sse_event("done", {})
@ -136,7 +148,13 @@ async def chat_stream(req: ChatStreamRequest):
n = len(sources)
yield _sse_event(
"status",
{"text": f"Found {n} relevant page{'s' if n != 1 else ''}, generating answer\u2026"},
{
"text": (
f"Found {n} relevant "
f"page{'s' if n != 1 else ''}, "
"generating answer\u2026"
),
},
)
yield _sse_event("sources", sources)

View File

@ -19,7 +19,9 @@ router = APIRouter()
async def search(
q: str = Query(..., min_length=1, max_length=500),
content_type: str | None = Query(None, description="Filter by content type"),
section: str | None = Query(None, max_length=200, description="Section prefix filter"),
section: str | None = Query(
None, max_length=200, description="Section prefix filter"
),
limit: int = Query(10, ge=1, le=50),
mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"),
db: AsyncSession = Depends(get_db),
@ -67,8 +69,8 @@ async def list_sections(
)
result = await db.execute(stmt)
sections = [
SectionCount(section=row.section, count=row.count)
SectionCount(section=row[0], count=row[1])
for row in result
if row.section
if row[0]
]
return SectionsResponse(sections=sections)