Add Origin checking
This commit is contained in:
parent
1e737ed44b
commit
e462164321
@ -11,12 +11,16 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Allowed origins for CORS/CSRF protection
|
||||
ALLOWED_ORIGINS = os.environ.get("GHIDRA_ALLOWED_ORIGINS", "http://localhost").split(",")
|
||||
|
||||
# Track active Ghidra instances (port -> info dict)
|
||||
active_instances: Dict[int, dict] = {}
|
||||
instances_lock = Lock()
|
||||
@ -51,6 +55,23 @@ def get_instance_url(port: int) -> str:
|
||||
|
||||
return f"http://{ghidra_host}:{port}"
|
||||
|
||||
def validate_origin(headers: dict) -> bool:
|
||||
"""Validate request origin against allowed origins"""
|
||||
origin = headers.get("Origin")
|
||||
if not origin:
|
||||
return True # No origin header - allow (browser same-origin policy applies)
|
||||
|
||||
# Parse origin to get scheme+hostname
|
||||
try:
|
||||
parsed = urlparse(origin)
|
||||
origin_base = f"{parsed.scheme}://{parsed.hostname}"
|
||||
if parsed.port:
|
||||
origin_base += f":{parsed.port}"
|
||||
except:
|
||||
return False
|
||||
|
||||
return origin_base in ALLOWED_ORIGINS
|
||||
|
||||
def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
|
||||
"""Perform a GET request to a specific Ghidra instance and return JSON response"""
|
||||
if params is None:
|
||||
@ -58,6 +79,15 @@ def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
|
||||
|
||||
url = f"{get_instance_url(port)}/{endpoint}"
|
||||
|
||||
# Check origin if this is a state-changing request
|
||||
if endpoint not in ["instances", "info"] and not validate_origin(params.get("headers", {})):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Origin not allowed",
|
||||
"status_code": 403,
|
||||
"timestamp": int(time.time() * 1000)
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
@ -131,6 +161,14 @@ def safe_put(port: int, endpoint: str, data: dict) -> dict:
|
||||
"""Perform a PUT request to a specific Ghidra instance with JSON payload"""
|
||||
try:
|
||||
url = f"{get_instance_url(port)}/{endpoint}"
|
||||
|
||||
# Always validate origin for PUT requests
|
||||
if not validate_origin(data.get("headers", {})):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Origin not allowed",
|
||||
"status_code": 403
|
||||
}
|
||||
response = requests.put(
|
||||
url,
|
||||
json=data,
|
||||
@ -184,6 +222,15 @@ def safe_post(port: int, endpoint: str, data: dict | str) -> dict:
|
||||
try:
|
||||
url = f"{get_instance_url(port)}/{endpoint}"
|
||||
|
||||
# Always validate origin for POST requests
|
||||
headers = data.get("headers", {}) if isinstance(data, dict) else {}
|
||||
if not validate_origin(headers):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Origin not allowed",
|
||||
"status_code": 403
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
response = requests.post(
|
||||
url,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user