Add Origin checking
This commit is contained in:
parent
1e737ed44b
commit
e462164321
@ -11,12 +11,16 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
# Allowed origins for CORS/CSRF protection
|
||||||
|
ALLOWED_ORIGINS = os.environ.get("GHIDRA_ALLOWED_ORIGINS", "http://localhost").split(",")
|
||||||
|
|
||||||
# Track active Ghidra instances (port -> info dict)
|
# Track active Ghidra instances (port -> info dict)
|
||||||
active_instances: Dict[int, dict] = {}
|
active_instances: Dict[int, dict] = {}
|
||||||
instances_lock = Lock()
|
instances_lock = Lock()
|
||||||
@ -51,12 +55,38 @@ def get_instance_url(port: int) -> str:
|
|||||||
|
|
||||||
return f"http://{ghidra_host}:{port}"
|
return f"http://{ghidra_host}:{port}"
|
||||||
|
|
||||||
|
def validate_origin(headers: dict) -> bool:
|
||||||
|
"""Validate request origin against allowed origins"""
|
||||||
|
origin = headers.get("Origin")
|
||||||
|
if not origin:
|
||||||
|
return True # No origin header - allow (browser same-origin policy applies)
|
||||||
|
|
||||||
|
# Parse origin to get scheme+hostname
|
||||||
|
try:
|
||||||
|
parsed = urlparse(origin)
|
||||||
|
origin_base = f"{parsed.scheme}://{parsed.hostname}"
|
||||||
|
if parsed.port:
|
||||||
|
origin_base += f":{parsed.port}"
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return origin_base in ALLOWED_ORIGINS
|
||||||
|
|
||||||
def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
|
def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
|
||||||
"""Perform a GET request to a specific Ghidra instance and return JSON response"""
|
"""Perform a GET request to a specific Ghidra instance and return JSON response"""
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
url = f"{get_instance_url(port)}/{endpoint}"
|
url = f"{get_instance_url(port)}/{endpoint}"
|
||||||
|
|
||||||
|
# Check origin if this is a state-changing request
|
||||||
|
if endpoint not in ["instances", "info"] and not validate_origin(params.get("headers", {})):
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Origin not allowed",
|
||||||
|
"status_code": 403,
|
||||||
|
"timestamp": int(time.time() * 1000)
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
@ -131,6 +161,14 @@ def safe_put(port: int, endpoint: str, data: dict) -> dict:
|
|||||||
"""Perform a PUT request to a specific Ghidra instance with JSON payload"""
|
"""Perform a PUT request to a specific Ghidra instance with JSON payload"""
|
||||||
try:
|
try:
|
||||||
url = f"{get_instance_url(port)}/{endpoint}"
|
url = f"{get_instance_url(port)}/{endpoint}"
|
||||||
|
|
||||||
|
# Always validate origin for PUT requests
|
||||||
|
if not validate_origin(data.get("headers", {})):
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Origin not allowed",
|
||||||
|
"status_code": 403
|
||||||
|
}
|
||||||
response = requests.put(
|
response = requests.put(
|
||||||
url,
|
url,
|
||||||
json=data,
|
json=data,
|
||||||
@ -183,6 +221,15 @@ def safe_post(port: int, endpoint: str, data: dict | str) -> dict:
|
|||||||
"""Perform a POST request to a specific Ghidra instance with JSON payload"""
|
"""Perform a POST request to a specific Ghidra instance with JSON payload"""
|
||||||
try:
|
try:
|
||||||
url = f"{get_instance_url(port)}/{endpoint}"
|
url = f"{get_instance_url(port)}/{endpoint}"
|
||||||
|
|
||||||
|
# Always validate origin for POST requests
|
||||||
|
headers = data.get("headers", {}) if isinstance(data, dict) else {}
|
||||||
|
if not validate_origin(headers):
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Origin not allowed",
|
||||||
|
"status_code": 403
|
||||||
|
}
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user