mcilspy/src/mcilspy/metadata_reader.py

554 lines
21 KiB
Python

"""Direct .NET metadata reader using dnfile.
Provides access to all 34+ CLR metadata tables without requiring ilspycmd.
This enables searching for methods, fields, properties, events, and resources
that are not exposed via the ilspycmd CLI.
This module contains CPU-bound synchronous code for parsing .NET PE metadata.
For heavy workloads with many concurrent requests, consider running these
operations in a thread pool (e.g., asyncio.to_thread) to avoid blocking
the event loop.
Note: dnfile provides flag attributes as boolean properties (e.g., mdPublic, fdStatic)
rather than traditional IntFlag enums, so we use those directly.
"""
import hashlib
import logging
from pathlib import Path
from typing import Any

import dnfile
from dnfile.mdtable import TypeDefRow

from .models import (
    AssemblyMetadata,
    EventInfo,
    FieldInfo,
    MethodInfo,
    PropertyInfo,
    ResourceInfo,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Upper bound, in megabytes, on the size of an assembly file that will be loaded.
# Guards against memory exhaustion from extremely large or malicious assemblies.
MAX_ASSEMBLY_SIZE_MB: int = 500
class AssemblySizeError(ValueError):
    """Raised when an assembly exceeds the maximum allowed size.

    Subclasses ``ValueError`` so callers that already handle ``ValueError``
    keep working.
    """
class MetadataReader:
    """Read .NET assembly metadata directly using dnfile.

    The PE file is parsed lazily on the first call that needs it and is
    released by ``close()`` or by leaving a ``with`` block.
    """

    def __init__(self, assembly_path: str) -> None:
        """Initialize the metadata reader.

        Only the path and file size are checked here; the file is not parsed
        until metadata is first requested.

        Args:
            assembly_path: Path to the .NET assembly file

        Raises:
            FileNotFoundError: If the assembly file doesn't exist
            AssemblySizeError: If the assembly exceeds MAX_ASSEMBLY_SIZE_MB
        """
        self.assembly_path = Path(assembly_path)
        if not self.assembly_path.exists():
            raise FileNotFoundError(f"Assembly not found: {assembly_path}")

        # Check file size before loading to prevent memory exhaustion
        file_size_bytes = self.assembly_path.stat().st_size
        max_size_bytes = MAX_ASSEMBLY_SIZE_MB * 1024 * 1024
        if file_size_bytes > max_size_bytes:
            size_mb = file_size_bytes / (1024 * 1024)
            raise AssemblySizeError(
                f"Assembly file size ({size_mb:.1f} MB) exceeds maximum allowed "
                f"({MAX_ASSEMBLY_SIZE_MB} MB). This limit prevents memory exhaustion."
            )

        self._pe: dnfile.dnPE | None = None
        self._type_cache: dict[int, TypeDefRow] = {}

    # ------------------------------------------------------------------
    # Loading and small shared helpers
    # ------------------------------------------------------------------

    def _ensure_loaded(self) -> dnfile.dnPE:
        """Parse the PE file on first use and return the parsed object.

        Raises:
            ValueError: If dnfile fails to parse the assembly.
        """
        if self._pe is None:
            try:
                self._pe = dnfile.dnPE(str(self.assembly_path))
            except Exception as e:
                raise ValueError(f"Failed to parse assembly: {e}") from e

            # Build type cache for lookups
            type_defs = self._table(self._pe, "TypeDef")
            if type_defs:
                for i, td in enumerate(type_defs):
                    self._type_cache[i + 1] = td  # Metadata tokens are 1-indexed
        return self._pe

    @staticmethod
    def _table(pe: Any, table_name: str) -> Any:
        """Return the named metadata table, or None when it is absent or empty."""
        net = pe.net
        tables = net.mdtables if net else None
        if not tables:
            return None
        table = getattr(tables, table_name, None)
        return table if table else None

    @staticmethod
    def _row_count(pe: Any, table_name: str) -> int:
        """Return the number of rows in the named table (0 when absent)."""
        table = MetadataReader._table(pe, table_name)
        return len(table) if table else 0

    @staticmethod
    def _type_identity(td: Any) -> tuple[str, str | None]:
        """Return ``(type name, namespace)`` for a TypeDef row.

        A missing name becomes "Unknown"; a missing/empty namespace becomes None.
        """
        type_name = str(td.TypeName) if td.TypeName else "Unknown"
        namespace = str(td.TypeNamespace) if td.TypeNamespace else None
        return type_name, namespace

    @staticmethod
    def _passes_filters(
        declaring_type: str,
        namespace: str | None,
        type_filter: str | None,
        namespace_filter: str | None,
    ) -> bool:
        """Apply the case-insensitive substring filters shared by the list_* methods."""
        if type_filter and type_filter.lower() not in declaring_type.lower():
            return False
        if namespace_filter and (
            namespace is None or namespace_filter.lower() not in namespace.lower()
        ):
            return False
        return True

    @staticmethod
    def _qualified_name(namespace: str | None, declaring_type: str, member_name: str) -> str:
        """Join namespace, type and member into a dotted name (namespace optional)."""
        if namespace:
            return f"{namespace}.{declaring_type}.{member_name}"
        return f"{declaring_type}.{member_name}"

    @staticmethod
    def _flag(flags: Any, attr_name: str) -> Any:
        """Read a boolean flag property from a dnfile flags object, defaulting to False."""
        return getattr(flags, attr_name) if flags and hasattr(flags, attr_name) else False

    @staticmethod
    def _public_key_token(public_key: Any) -> str | None:
        """Compute the public key token (hex) for a full public key blob.

        The token is the last 8 bytes of the SHA-1 hash of the public key,
        in reverse byte order. Returns None when the key is empty or cannot be
        converted to bytes.
        """
        try:
            # Handle the various shapes dnfile may use for the blob.
            if hasattr(public_key, "value"):
                raw = bytes(public_key.value)
            elif hasattr(public_key, "__bytes__") or isinstance(public_key, (bytes, bytearray)):
                raw = bytes(public_key)
            else:
                raw = b""
        except (TypeError, AttributeError):
            # Some dnfile versions can't convert HeapItemBinary directly
            return None
        if not raw:
            return None
        # SHA-1 is used here as an identifier, not for security.
        digest = hashlib.sha1(raw, usedforsecurity=False).digest()
        return digest[-8:][::-1].hex()

    def _get_row_index(self, reference: Any) -> int:
        """Safely extract row_index from a metadata reference.

        dnfile references can be either objects with .row_index attribute
        or raw integers. This helper handles both cases and returns 0 when
        no index can be determined.
        """
        if reference is None:
            return 0
        if hasattr(reference, "row_index"):
            return reference.row_index
        if isinstance(reference, int):
            return reference
        # Some dnfile versions return the index directly as an attribute
        if hasattr(reference, "value"):
            return reference.value
        return 0

    def _owners_by_list_column(
        self, type_defs: list[Any], list_attr: str, member_count: int
    ) -> dict[int, tuple[str, str | None]]:
        """Map 1-based member row indices to their declaring ``(type, namespace)``.

        ``list_attr`` names the TypeDef column ("MethodList" or "FieldList").
        Two column shapes are supported:

        * a single reference to the first member; the type owns every row up
          to the next type's first member (or the end of the table);
        * a list/tuple of row references, one per owned member
          (NOTE(review): believed to be how current dnfile exposes these
          columns; confirm against the installed version).

        When ranges overlap, the first type in table order wins.
        """
        owners: dict[int, tuple[str, str | None]] = {}
        past_end = member_count + 1
        for pos, td in enumerate(type_defs):
            identity = self._type_identity(td)
            column = getattr(td, list_attr)
            if isinstance(column, (list, tuple)):
                indices: Any = (self._get_row_index(item) for item in column)
            else:
                start = self._get_row_index(column)
                if pos + 1 < len(type_defs):
                    next_column = getattr(type_defs[pos + 1], list_attr)
                    end = self._get_row_index(next_column) or past_end
                else:
                    end = past_end
                # Clamp to valid row indices so corrupt values cannot cause
                # iteration far beyond the table.
                indices = range(max(start, 1), min(end, past_end))
            for idx in indices:
                if idx >= 1:
                    owners.setdefault(idx, identity)
        return owners

    def _owners_from_map_table(
        self,
        map_rows: list[Any],
        list_attr: str,
        type_defs: list[Any],
        member_count: int,
    ) -> dict[int, tuple[str, str | None]]:
        """Map 1-based member row indices to owners via PropertyMap/EventMap rows.

        Each map row has a ``Parent`` TypeDef reference and a ``list_attr``
        column ("PropertyList" or "EventList"). The column may be a single
        first-member reference (range ends at the next map row's first member)
        or a list/tuple of row references. Rows with a missing parent, an empty
        list, or an out-of-range parent index are skipped. On overlap, the
        later map row wins.
        """
        owners: dict[int, tuple[str, str | None]] = {}
        past_end = member_count + 1
        for pos, row in enumerate(map_rows):
            column = getattr(row, list_attr)
            if not (row.Parent and column):
                continue
            parent_idx = self._get_row_index(row.Parent)
            if not 0 < parent_idx <= len(type_defs):
                continue
            identity = self._type_identity(type_defs[parent_idx - 1])
            if isinstance(column, (list, tuple)):
                indices: Any = (self._get_row_index(item) for item in column)
            else:
                start = self._get_row_index(column)
                if pos + 1 < len(map_rows):
                    end = self._get_row_index(getattr(map_rows[pos + 1], list_attr))
                else:
                    end = past_end
                # Clamp to valid row indices; others are never looked up.
                indices = range(max(start, 1), min(end, past_end))
            for idx in indices:
                if idx >= 1:
                    owners[idx] = identity
        return owners

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_assembly_metadata(self) -> AssemblyMetadata:
        """Get comprehensive assembly metadata.

        Returns name, version, culture, public key token, target framework
        (best effort), per-table row counts and the referenced assemblies.
        Values missing from the metadata fall back to the file stem,
        "0.0.0.0", or None.
        """
        pe = self._ensure_loaded()

        name = self.assembly_path.stem
        version = "0.0.0.0"
        culture = None
        public_key_token = None
        target_framework = None
        referenced_assemblies: list[str] = []

        # Assembly info
        for asm in self._table(pe, "Assembly") or ():
            name = str(asm.Name) if asm.Name else name
            version = (
                f"{asm.MajorVersion}.{asm.MinorVersion}.{asm.BuildNumber}.{asm.RevisionNumber}"
            )
            culture = str(asm.Culture) if asm.Culture else None
            if asm.PublicKey:
                token = self._public_key_token(asm.PublicKey)
                if token is not None:
                    public_key_token = token

        # Assembly references
        for ref in self._table(pe, "AssemblyRef") or ():
            ref_name = str(ref.Name) if ref.Name else "Unknown"
            ref_version = (
                f"{ref.MajorVersion}.{ref.MinorVersion}.{ref.BuildNumber}.{ref.RevisionNumber}"
            )
            referenced_assemblies.append(f"{ref_name}, Version={ref_version}")

        # Try to find TargetFramework from custom attributes (best effort:
        # a row that cannot be read is logged and skipped).
        for ca in self._table(pe, "CustomAttribute") or ():
            try:
                ca_type = getattr(ca, "Type", None)
                if not ca_type:
                    continue
                if "TargetFramework" in str(ca_type) and getattr(ca, "Value", None):
                    target_framework = str(ca.Value)
            except Exception:
                logger.debug(
                    "Skipping unreadable custom attribute while looking for TargetFramework",
                    exc_info=True,
                )

        return AssemblyMetadata(
            name=name,
            version=version,
            culture=culture,
            public_key_token=public_key_token,
            target_framework=target_framework,
            type_count=self._row_count(pe, "TypeDef"),
            method_count=self._row_count(pe, "MethodDef"),
            field_count=self._row_count(pe, "Field"),
            property_count=self._row_count(pe, "Property"),
            event_count=self._row_count(pe, "Event"),
            resource_count=self._row_count(pe, "ManifestResource"),
            referenced_assemblies=referenced_assemblies,
        )

    def list_methods(
        self,
        type_filter: str | None = None,
        namespace_filter: str | None = None,
        public_only: bool = False,
    ) -> list[MethodInfo]:
        """List all methods in the assembly.

        Args:
            type_filter: Only return methods from types containing this string
            namespace_filter: Only return methods from types in namespaces containing this string
            public_only: Only return public methods

        Returns:
            Methods in MethodDef table order; empty when the assembly has no
            TypeDef or MethodDef rows.
        """
        pe = self._ensure_loaded()
        methods: list[MethodInfo] = []

        type_table = self._table(pe, "TypeDef")
        method_table = self._table(pe, "MethodDef")
        if not (type_table and method_table):
            return methods

        # Build method-to-type mapping once instead of scanning per method.
        owners = self._owners_by_list_column(list(type_table), "MethodList", len(method_table))

        for md_idx, md in enumerate(method_table, start=1):
            declaring_type, namespace = owners.get(md_idx, ("Unknown", None))
            if not self._passes_filters(declaring_type, namespace, type_filter, namespace_filter):
                continue

            # Parse attributes using dnfile's boolean properties on ClrMethodAttr
            flags = md.Flags if hasattr(md, "Flags") else None
            is_public = self._flag(flags, "mdPublic")
            if public_only and not is_public:
                continue

            method_name = str(md.Name) if md.Name else "Unknown"
            methods.append(
                MethodInfo(
                    name=method_name,
                    full_name=self._qualified_name(namespace, declaring_type, method_name),
                    declaring_type=declaring_type,
                    namespace=namespace,
                    is_public=is_public,
                    is_static=self._flag(flags, "mdStatic"),
                    is_virtual=self._flag(flags, "mdVirtual"),
                    is_abstract=self._flag(flags, "mdAbstract"),
                )
            )
        return methods

    def list_fields(
        self,
        type_filter: str | None = None,
        namespace_filter: str | None = None,
        public_only: bool = False,
        constants_only: bool = False,
    ) -> list[FieldInfo]:
        """List all fields in the assembly.

        Args:
            type_filter: Only return fields from types containing this string
            namespace_filter: Only return fields from types in namespaces containing this string
            public_only: Only return public fields
            constants_only: Only return constant (literal) fields

        Returns:
            Fields in Field table order; empty when the assembly has no
            TypeDef or Field rows.
        """
        pe = self._ensure_loaded()
        fields: list[FieldInfo] = []

        type_table = self._table(pe, "TypeDef")
        field_table = self._table(pe, "Field")
        if not (type_table and field_table):
            return fields

        # Build field-to-type mapping once instead of scanning per field.
        owners = self._owners_by_list_column(list(type_table), "FieldList", len(field_table))

        for f_idx, fld in enumerate(field_table, start=1):
            declaring_type, namespace = owners.get(f_idx, ("Unknown", None))
            if not self._passes_filters(declaring_type, namespace, type_filter, namespace_filter):
                continue

            # Parse attributes using dnfile's boolean properties on ClrFieldAttr
            flags = fld.Flags if hasattr(fld, "Flags") else None
            is_public = self._flag(flags, "fdPublic")
            is_literal = self._flag(flags, "fdLiteral")
            if public_only and not is_public:
                continue
            if constants_only and not is_literal:
                continue

            field_name = str(fld.Name) if fld.Name else "Unknown"
            fields.append(
                FieldInfo(
                    name=field_name,
                    full_name=self._qualified_name(namespace, declaring_type, field_name),
                    declaring_type=declaring_type,
                    namespace=namespace,
                    is_public=is_public,
                    is_static=self._flag(flags, "fdStatic"),
                    is_literal=is_literal,
                )
            )
        return fields

    def list_properties(
        self,
        type_filter: str | None = None,
        namespace_filter: str | None = None,
    ) -> list[PropertyInfo]:
        """List all properties in the assembly.

        Args:
            type_filter: Only return properties from types containing this string
            namespace_filter: Only return properties from types in namespaces
                containing this string

        Returns:
            Properties in Property table order. Properties with no PropertyMap
            entry are reported with declaring type "Unknown".
        """
        pe = self._ensure_loaded()
        properties: list[PropertyInfo] = []

        property_table = self._table(pe, "Property")
        if not property_table:
            return properties

        # PropertyMap links types to properties
        owners: dict[int, tuple[str, str | None]] = {}
        map_table = self._table(pe, "PropertyMap")
        type_table = self._table(pe, "TypeDef")
        if map_table and type_table:
            owners = self._owners_from_map_table(
                list(map_table), "PropertyList", list(type_table), len(property_table)
            )

        for p_idx, prop in enumerate(property_table, start=1):
            declaring_type, namespace = owners.get(p_idx, ("Unknown", None))
            if not self._passes_filters(declaring_type, namespace, type_filter, namespace_filter):
                continue
            prop_name = str(prop.Name) if prop.Name else "Unknown"
            properties.append(
                PropertyInfo(
                    name=prop_name,
                    full_name=self._qualified_name(namespace, declaring_type, prop_name),
                    declaring_type=declaring_type,
                    namespace=namespace,
                )
            )
        return properties

    def list_events(
        self,
        type_filter: str | None = None,
        namespace_filter: str | None = None,
    ) -> list[EventInfo]:
        """List all events in the assembly.

        Args:
            type_filter: Only return events from types containing this string
            namespace_filter: Only return events from types in namespaces
                containing this string

        Returns:
            Events in Event table order. Events with no EventMap entry are
            reported with declaring type "Unknown".
        """
        pe = self._ensure_loaded()
        events: list[EventInfo] = []

        event_table = self._table(pe, "Event")
        if not event_table:
            return events

        # EventMap links types to events
        owners: dict[int, tuple[str, str | None]] = {}
        map_table = self._table(pe, "EventMap")
        type_table = self._table(pe, "TypeDef")
        if map_table and type_table:
            owners = self._owners_from_map_table(
                list(map_table), "EventList", list(type_table), len(event_table)
            )

        for e_idx, evt in enumerate(event_table, start=1):
            declaring_type, namespace = owners.get(e_idx, ("Unknown", None))
            if not self._passes_filters(declaring_type, namespace, type_filter, namespace_filter):
                continue
            evt_name = str(evt.Name) if evt.Name else "Unknown"
            events.append(
                EventInfo(
                    name=evt_name,
                    full_name=self._qualified_name(namespace, declaring_type, evt_name),
                    declaring_type=declaring_type,
                    namespace=namespace,
                )
            )
        return events

    def list_resources(self) -> list[ResourceInfo]:
        """List all embedded resources in the assembly.

        ``size`` is always reported as 0 because it would require reading the
        resource data itself. ``is_public`` defaults to True when the flags
        object does not expose ``mrPublic``.
        """
        pe = self._ensure_loaded()
        resources: list[ResourceInfo] = []

        for res in self._table(pe, "ManifestResource") or ():
            name = str(res.Name) if res.Name else "Unknown"
            # dnfile exposes flags as boolean properties (mrPublic, mrPrivate)
            is_public = (
                res.Flags.mrPublic
                if hasattr(res, "Flags") and hasattr(res.Flags, "mrPublic")
                else True
            )
            resources.append(
                ResourceInfo(
                    name=name,
                    size=0,  # Size requires reading the actual resource data
                    is_public=is_public,
                )
            )
        return resources

    def close(self) -> None:
        """Close the PE file and drop cached rows; safe to call repeatedly."""
        if self._pe is not None:
            self._pe.close()
            self._pe = None
        # Cached TypeDef rows belong to the closed PE; a later load rebuilds them.
        self._type_cache.clear()

    def __enter__(self) -> "MetadataReader":
        """Return the reader itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close the reader on leaving the ``with`` block; exceptions propagate."""
        self.close()