Fix ruff lint, Pyright type ambiguity, and MCP tool SQL queries
- Fix 28 ruff errors: E501 line length, B904 raise-from, F401 unused import - Fix SQLAlchemy Row.count() ambiguity with tuple indexing (Pyright) - Replace composite column notation with accessor functions in MCP tools (topocentric/equatorial/pass_event are C-level base types, not composites) - Fix satellite_pass: use time window (start + end) not count parameter to match predict_passes(tle, observer, start_ts, end_ts, min_el) signature
This commit is contained in:
parent
a40ae9437d
commit
33787e03da
@ -3,7 +3,6 @@
|
||||
Run: docker compose exec api-dev python -m orrery_search.ingest
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@ -17,7 +17,8 @@ logger = logging.getLogger("orrery_search")
|
||||
SYSTEM_PROMPT = (
|
||||
"/no_think\n"
|
||||
"You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing "
|
||||
"celestial mechanics types and functions. Answer questions using the provided context "
|
||||
"celestial mechanics types and functions. "
|
||||
"Answer questions using the provided context "
|
||||
"from the documentation.\n\n"
|
||||
"Key domain knowledge:\n"
|
||||
"- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, "
|
||||
@ -125,7 +126,10 @@ async def _chat_completion_stream_gpu(
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
||||
"content": (
|
||||
f"Documentation context:\n\n{context}"
|
||||
f"\n\n---\n\nQuestion: {question}"
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
@ -196,7 +200,10 @@ async def _chat_completion_stream_anthropic(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
||||
"content": (
|
||||
f"Documentation context:\n\n{context}"
|
||||
f"\n\n---\n\nQuestion: {question}"
|
||||
),
|
||||
},
|
||||
],
|
||||
) as stream:
|
||||
@ -232,7 +239,10 @@ async def _chat_completion(context: str, question: str) -> str:
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
||||
"content": (
|
||||
f"Documentation context:\n\n{context}"
|
||||
f"\n\n---\n\nQuestion: {question}"
|
||||
),
|
||||
},
|
||||
],
|
||||
)
|
||||
@ -243,7 +253,10 @@ async def _chat_completion(context: str, question: str) -> str:
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
|
||||
"content": (
|
||||
f"Documentation context:\n\n{context}"
|
||||
f"\n\n---\n\nQuestion: {question}"
|
||||
),
|
||||
},
|
||||
]
|
||||
resp = await client.post(
|
||||
@ -284,11 +297,16 @@ async def ask_orrery(
|
||||
context, sources = await _build_context(question)
|
||||
except Exception:
|
||||
logger.exception("Failed to build RAG context")
|
||||
raise RuntimeError("Search failed while building context")
|
||||
raise RuntimeError(
|
||||
"Search failed while building context"
|
||||
) from None
|
||||
|
||||
if not context:
|
||||
return json.dumps({
|
||||
"answer": "No relevant documents found in the pg_orrery docs for this question.",
|
||||
"answer": (
|
||||
"No relevant documents found in the "
|
||||
"pg_orrery docs for this question."
|
||||
),
|
||||
"sources": [],
|
||||
})
|
||||
|
||||
@ -296,10 +314,18 @@ async def ask_orrery(
|
||||
answer = await _chat_completion(context, question)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error("Chat completion failed: HTTP %s", exc.response.status_code)
|
||||
raise RuntimeError(f"Chat model returned HTTP {exc.response.status_code}")
|
||||
status = exc.response.status_code
|
||||
raise RuntimeError(
|
||||
f"Chat model returned HTTP {status}"
|
||||
) from None
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Chat completion timed out (limit: %.0fs)", settings.chat_timeout)
|
||||
raise RuntimeError("Chat model timed out — try a simpler question")
|
||||
logger.warning(
|
||||
"Chat completion timed out (limit: %.0fs)",
|
||||
settings.chat_timeout,
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Chat model timed out — try a simpler question"
|
||||
) from None
|
||||
|
||||
result: dict = {"answer": answer}
|
||||
if include_sources:
|
||||
|
||||
@ -63,7 +63,10 @@ def _validate_sql(sql: str) -> None:
|
||||
raise ValueError("Only SELECT statements are allowed")
|
||||
|
||||
if _FORBIDDEN_RE.search(stripped):
|
||||
raise ValueError("Statement contains forbidden keywords (INSERT/UPDATE/DELETE/DDL)")
|
||||
raise ValueError(
|
||||
"Statement contains forbidden keywords"
|
||||
" (INSERT/UPDATE/DELETE/DDL)"
|
||||
)
|
||||
|
||||
|
||||
def _serialize_row(row: asyncpg.Record) -> dict:
|
||||
@ -98,12 +101,12 @@ async def run_query(sql: str) -> str:
|
||||
_validate_sql(sql)
|
||||
|
||||
pool = await _get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction(readonly=True):
|
||||
await conn.execute(
|
||||
f"SET LOCAL statement_timeout = '{int(settings.run_query_timeout * 1000)}'"
|
||||
)
|
||||
rows = await conn.fetch(sql)
|
||||
async with pool.acquire() as conn, conn.transaction(readonly=True):
|
||||
timeout_ms = int(settings.run_query_timeout * 1000)
|
||||
await conn.execute(
|
||||
f"SET LOCAL statement_timeout = '{timeout_ms}'"
|
||||
)
|
||||
rows = await conn.fetch(sql)
|
||||
|
||||
if len(rows) > MAX_ROWS:
|
||||
rows = rows[:MAX_ROWS]
|
||||
@ -142,8 +145,13 @@ async def planet_position(
|
||||
|
||||
sql = f"""
|
||||
SELECT
|
||||
(t).azimuth_deg, (t).elevation_deg, (t).range_km, (t).range_rate_km_s,
|
||||
(e).ra_hours, (e).dec_degrees, (e).distance_km
|
||||
topo_azimuth(t) AS az_deg,
|
||||
topo_elevation(t) AS el_deg,
|
||||
topo_range(t) AS range_km,
|
||||
topo_range_rate(t) AS range_rate_km_s,
|
||||
eq_ra(e) AS ra_hours,
|
||||
eq_dec(e) AS dec_deg,
|
||||
eq_distance(e) AS distance_km
|
||||
FROM (
|
||||
SELECT
|
||||
planet_observe({body_id}, {obs}, {ts}) AS t,
|
||||
@ -178,11 +186,11 @@ async def sky_survey(
|
||||
sql = f"""
|
||||
SELECT
|
||||
body_name,
|
||||
(topo).azimuth_deg AS az,
|
||||
(topo).elevation_deg AS el,
|
||||
(topo).range_km AS range_km,
|
||||
(eq).ra_hours AS ra,
|
||||
(eq).dec_degrees AS dec
|
||||
topo_azimuth(topo) AS az,
|
||||
topo_elevation(topo) AS el,
|
||||
topo_range(topo) AS range_km,
|
||||
eq_ra(eq) AS ra,
|
||||
eq_dec(eq) AS dec
|
||||
FROM (
|
||||
SELECT 'Sun' AS body_name,
|
||||
sun_observe({obs}, {ts}) AS topo,
|
||||
@ -192,9 +200,14 @@ async def sky_survey(
|
||||
moon_observe({obs}, {ts}),
|
||||
moon_equatorial({ts})
|
||||
UNION ALL
|
||||
SELECT unnest(ARRAY['Mercury','Venus','Mars','Jupiter','Saturn','Uranus','Neptune']),
|
||||
planet_observe(unnest(ARRAY[1,2,4,5,6,7,8]), {obs}, {ts}),
|
||||
planet_equatorial(unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
|
||||
SELECT unnest(ARRAY[
|
||||
'Mercury','Venus','Mars','Jupiter',
|
||||
'Saturn','Uranus','Neptune']),
|
||||
planet_observe(
|
||||
unnest(ARRAY[1,2,4,5,6,7,8]),
|
||||
{obs}, {ts}),
|
||||
planet_equatorial(
|
||||
unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
|
||||
) survey
|
||||
ORDER BY el DESC
|
||||
"""
|
||||
@ -210,11 +223,13 @@ async def satellite_pass(
|
||||
longitude: float,
|
||||
altitude: float = 0.0,
|
||||
time: str = "NOW()",
|
||||
count: int = 5,
|
||||
hours: int = 48,
|
||||
min_elevation: float = 0.0,
|
||||
) -> str:
|
||||
"""Predict upcoming satellite passes over an observer location.
|
||||
|
||||
Uses SGP4/SDP4 propagation with the provided Two-Line Element set.
|
||||
Searches a time window from the start time forward.
|
||||
|
||||
Args:
|
||||
tle_line1: First line of the TLE (69 chars, starts with "1 ")
|
||||
@ -223,9 +238,10 @@ async def satellite_pass(
|
||||
longitude: Observer longitude in degrees
|
||||
altitude: Observer altitude in meters (default 0)
|
||||
time: Start time for pass search (default "NOW()")
|
||||
count: Number of passes to predict (default 5, max 20)
|
||||
hours: Hours to search forward from start time (default 48)
|
||||
min_elevation: Minimum peak elevation in degrees (default 0)
|
||||
"""
|
||||
count = max(1, min(count, 20))
|
||||
hours = max(1, min(hours, 168))
|
||||
obs = f"'({latitude},{longitude},{altitude})'::observer"
|
||||
ts = f"'{time}'::timestamptz" if time != "NOW()" else "NOW()"
|
||||
|
||||
@ -235,12 +251,17 @@ async def satellite_pass(
|
||||
|
||||
sql = f"""
|
||||
SELECT
|
||||
(p).aos_time, (p).los_time,
|
||||
(p).max_elevation_deg,
|
||||
(p).aos_azimuth_deg, (p).los_azimuth_deg
|
||||
pass_aos_time(p) AS aos_time,
|
||||
pass_los_time(p) AS los_time,
|
||||
pass_max_elevation(p) AS max_el_deg,
|
||||
pass_aos_azimuth(p) AS aos_az_deg,
|
||||
pass_los_azimuth(p) AS los_az_deg,
|
||||
pass_duration(p) AS duration
|
||||
FROM predict_passes(
|
||||
tle_from_lines('{safe_l1}', '{safe_l2}'),
|
||||
{obs}, {ts}, {count}
|
||||
{obs}, {ts},
|
||||
{ts} + interval '{hours} hours',
|
||||
{min_elevation}
|
||||
) AS p
|
||||
"""
|
||||
|
||||
|
||||
@ -71,7 +71,8 @@ async def get_document(slug: str) -> str:
|
||||
Returns title, body text, metadata, and URL.
|
||||
|
||||
Args:
|
||||
slug: Document identifier (from search results, e.g. "guides/tracking-satellites")
|
||||
slug: Document identifier from search results
|
||||
(e.g. "guides/tracking-satellites")
|
||||
"""
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
@ -122,10 +123,17 @@ async def list_content(
|
||||
.order_by(Document.section)
|
||||
)
|
||||
groups = [
|
||||
{"section": row.section, "content_type": row.content_type, "count": row.count}
|
||||
{
|
||||
"section": row.section,
|
||||
"content_type": row.content_type,
|
||||
"count": row.count,
|
||||
}
|
||||
for row in result
|
||||
]
|
||||
return json.dumps({"summary": groups, "total": sum(g["count"] for g in groups)})
|
||||
total = sum(g["count"] for g in groups)
|
||||
return json.dumps(
|
||||
{"summary": groups, "total": total}
|
||||
)
|
||||
|
||||
limit = max(1, min(limit, 500))
|
||||
stmt = (
|
||||
@ -197,8 +205,9 @@ def orrery_expert(topic: str = "general") -> str:
|
||||
)
|
||||
topic_context = {
|
||||
"general": (
|
||||
"Help with any pg_orrery topic — satellite tracking, planetary observation, "
|
||||
"rise/set prediction, constellation identification, or Lagrange points."
|
||||
"Help with any pg_orrery topic — satellite tracking, "
|
||||
"planetary observation, rise/set prediction, "
|
||||
"constellation identification, or Lagrange points."
|
||||
),
|
||||
"satellites": (
|
||||
"Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, "
|
||||
@ -206,11 +215,13 @@ def orrery_expert(topic: str = "general") -> str:
|
||||
),
|
||||
"planets": (
|
||||
"Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, "
|
||||
"planet/sun/moon observation, equatorial coordinates, and apparent positions."
|
||||
"planet/sun/moon observation, equatorial coordinates, "
|
||||
"and apparent positions."
|
||||
),
|
||||
"navigation": (
|
||||
"Focus on observational astronomy: rise/set prediction, twilight computation, "
|
||||
"constellation identification, lunar phase, planet magnitude, and refraction."
|
||||
"Focus on observational astronomy: rise/set prediction, "
|
||||
"twilight computation, constellation identification, "
|
||||
"lunar phase, planet magnitude, and refraction."
|
||||
),
|
||||
"transfers": (
|
||||
"Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium "
|
||||
|
||||
@ -48,7 +48,10 @@ async def chat(req: ChatRequest):
|
||||
context, sources = await _build_context(req.question)
|
||||
except Exception:
|
||||
logger.exception("Chat context build failed")
|
||||
raise HTTPException(status_code=502, detail="Search service unavailable")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Search service unavailable",
|
||||
) from None
|
||||
|
||||
if not context:
|
||||
return ChatResponse(
|
||||
@ -62,7 +65,10 @@ async def chat(req: ChatRequest):
|
||||
answer = await _chat_completion(context, req.question)
|
||||
except Exception:
|
||||
logger.exception("Chat completion failed")
|
||||
raise HTTPException(status_code=502, detail="Chat completion unavailable")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Chat completion unavailable",
|
||||
) from None
|
||||
|
||||
return ChatResponse(
|
||||
answer=answer,
|
||||
@ -105,7 +111,12 @@ async def chat_stream(req: ChatStreamRequest):
|
||||
async def generate():
|
||||
question = req.question
|
||||
if req.page and req.page.title:
|
||||
page_context = f'[The user is currently reading: "{req.page.title}" ({req.page.path})'
|
||||
title = req.page.title
|
||||
path = req.page.path
|
||||
page_context = (
|
||||
f'[The user is currently reading: '
|
||||
f'"{title}" ({path})'
|
||||
)
|
||||
if req.page.description:
|
||||
page_context += f"\nDocument description: {req.page.description}"
|
||||
page_context += "]\n\n"
|
||||
@ -127,7 +138,8 @@ async def chat_stream(req: ChatStreamRequest):
|
||||
{
|
||||
"text": "I couldn't find any relevant documentation "
|
||||
"for that question. Try asking about satellite tracking, "
|
||||
"planetary observation, rise/set prediction, or other pg_orrery functions."
|
||||
"planetary observation, rise/set prediction, "
|
||||
"or other pg_orrery functions."
|
||||
},
|
||||
)
|
||||
yield _sse_event("done", {})
|
||||
@ -136,7 +148,13 @@ async def chat_stream(req: ChatStreamRequest):
|
||||
n = len(sources)
|
||||
yield _sse_event(
|
||||
"status",
|
||||
{"text": f"Found {n} relevant page{'s' if n != 1 else ''}, generating answer\u2026"},
|
||||
{
|
||||
"text": (
|
||||
f"Found {n} relevant "
|
||||
f"page{'s' if n != 1 else ''}, "
|
||||
"generating answer\u2026"
|
||||
),
|
||||
},
|
||||
)
|
||||
yield _sse_event("sources", sources)
|
||||
|
||||
|
||||
@ -19,7 +19,9 @@ router = APIRouter()
|
||||
async def search(
|
||||
q: str = Query(..., min_length=1, max_length=500),
|
||||
content_type: str | None = Query(None, description="Filter by content type"),
|
||||
section: str | None = Query(None, max_length=200, description="Section prefix filter"),
|
||||
section: str | None = Query(
|
||||
None, max_length=200, description="Section prefix filter"
|
||||
),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@ -67,8 +69,8 @@ async def list_sections(
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
sections = [
|
||||
SectionCount(section=row.section, count=row.count)
|
||||
SectionCount(section=row[0], count=row[1])
|
||||
for row in result
|
||||
if row.section
|
||||
if row[0]
|
||||
]
|
||||
return SectionsResponse(sections=sections)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user