diff --git a/pyproject.toml b/pyproject.toml index c2cd13f..5893759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ dev = [ "ruff>=0.1.0", "pytest>=7.0.0", + "pytest-asyncio>=0.23.0", ] plot = [ "matplotlib>=3.7.0", @@ -46,6 +47,10 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src/mcp_ltspice"] +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" + [tool.ruff] line-length = 100 target-version = "py311" diff --git a/src/mcp_ltspice/log_parser.py b/src/mcp_ltspice/log_parser.py index e37d751..7ca883d 100644 --- a/src/mcp_ltspice/log_parser.py +++ b/src/mcp_ltspice/log_parser.py @@ -25,6 +25,8 @@ class SimulationLog: n_equations: int | None = None n_steps: int | None = None raw_text: str = "" + operating_point: dict[str, float] = field(default_factory=dict) + transfer_function: dict[str, float] = field(default_factory=dict) def get_measurement(self, name: str) -> Measurement | None: """Get a measurement by name (case-insensitive).""" @@ -72,6 +74,17 @@ _N_STEPS_RE = re.compile( # Lines starting with ".meas" are directive echoes, not results -- skip them. _MEAS_DIRECTIVE_RE = re.compile(r"^\s*\.meas\s", re.IGNORECASE) +# Operating point / transfer function lines: "V(out):\t 2.5\t voltage" +# These have a name, colon, value, then optional trailing text (units/type). +_OP_TF_VALUE_RE = re.compile( + r"^(?P\S+?):\s+(?P[+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)\s*", + re.IGNORECASE, +) + +# Section headers in LTspice log files +_OP_SECTION_RE = re.compile(r"---\s*Operating\s+Point\s*---", re.IGNORECASE) +_TF_SECTION_RE = re.compile(r"---\s*Transfer\s+Function\s*---", re.IGNORECASE) + def _is_error_line(line: str) -> bool: """Return True if the line reports an error.""" @@ -106,11 +119,39 @@ def parse_log(path: Path | str) -> SimulationLog: log = SimulationLog(raw_text=raw_text) + # Track which section we're currently parsing + current_section: str | None = None # "op", "tf", or None + for line in raw_text.splitlines(): stripped = line.strip() if not stripped: continue + # Detect section headers + if _OP_SECTION_RE.search(stripped): + current_section = "op" + continue + if _TF_SECTION_RE.search(stripped): + current_section = "tf" + continue + + # A new section header (any "---...---" line) ends the current section + if stripped.startswith("---") and stripped.endswith("---"): + current_section = None + continue + + # Parse .op / .tf section values + if current_section in ("op", "tf"): + m = _OP_TF_VALUE_RE.match(stripped) + if m: + try: + val = float(m.group("value")) + except ValueError: + continue + target = log.operating_point if current_section == "op" else log.transfer_function + target[m.group("name")] = val + continue + # Skip echoed .meas directives -- they are not results. if _MEAS_DIRECTIVE_RE.match(stripped): continue diff --git a/src/mcp_ltspice/noise_analysis.py b/src/mcp_ltspice/noise_analysis.py new file mode 100644 index 0000000..7c8ba71 --- /dev/null +++ b/src/mcp_ltspice/noise_analysis.py @@ -0,0 +1,365 @@ +"""Noise analysis for LTspice .noise simulation results. + +LTspice .noise analysis produces output with variables like 'onoise' +(output-referred noise spectral density in V/sqrt(Hz)) and 'inoise' +(input-referred noise spectral density). The data is complex-valued +in the .raw file; magnitude gives the spectral density. +""" + +import numpy as np + +# np.trapz was renamed to np.trapezoid in numpy 2.0 +_trapz = getattr(np, "trapezoid", getattr(np, "trapz", None)) + +# Boltzmann constant (J/K) +_K_BOLTZMANN = 1.380649e-23 + + +def compute_noise_spectral_density(frequency: np.ndarray, noise_signal: np.ndarray) -> dict: + """Compute noise spectral density from raw noise simulation data. + + Takes the frequency array and complex noise signal directly from the + .raw file and returns the noise spectral density in V/sqrt(Hz) and dB. + + Args: + frequency: Frequency array in Hz (may be complex; real part is used) + noise_signal: Complex noise signal from .raw file (magnitude = V/sqrt(Hz)) + + Returns: + Dict with frequency_hz, noise_density_v_per_sqrt_hz, noise_density_db + """ + if len(frequency) == 0 or len(noise_signal) == 0: + return { + "frequency_hz": [], + "noise_density_v_per_sqrt_hz": [], + "noise_density_db": [], + } + + freq = np.real(frequency).astype(np.float64) + density = np.abs(noise_signal) + + # Single-point case: still return valid data + density_db = 20.0 * np.log10(np.maximum(density, 1e-30)) + + return { + "frequency_hz": freq.tolist(), + "noise_density_v_per_sqrt_hz": density.tolist(), + "noise_density_db": density_db.tolist(), + } + + +def compute_total_noise( + frequency: np.ndarray, + noise_signal: np.ndarray, + f_low: float | None = None, + f_high: float | None = None, +) -> dict: + """Integrate noise spectral density over frequency to get total RMS noise. + + Computes total_rms = sqrt(integral(|noise|^2 * df)) using trapezoidal + integration over the specified frequency range. + + Args: + frequency: Frequency array in Hz (may be complex; real part is used) + noise_signal: Complex noise signal from .raw file + f_low: Lower integration bound in Hz (default: min frequency in data) + f_high: Upper integration bound in Hz (default: max frequency in data) + + Returns: + Dict with total_rms_v, integration_range_hz, equivalent_noise_bandwidth_hz + """ + if len(frequency) < 2 or len(noise_signal) < 2: + return { + "total_rms_v": 0.0, + "integration_range_hz": [0.0, 0.0], + "equivalent_noise_bandwidth_hz": 0.0, + } + + freq = np.real(frequency).astype(np.float64) + density = np.abs(noise_signal) + + # Sort by frequency to ensure correct integration order + sort_idx = np.argsort(freq) + freq = freq[sort_idx] + density = density[sort_idx] + + # Apply frequency bounds + if f_low is None: + f_low = float(freq[0]) + if f_high is None: + f_high = float(freq[-1]) + + mask = (freq >= f_low) & (freq <= f_high) + freq_band = freq[mask] + density_band = density[mask] + + if len(freq_band) < 2: + return { + "total_rms_v": 0.0, + "integration_range_hz": [f_low, f_high], + "equivalent_noise_bandwidth_hz": 0.0, + } + + # Integrate |noise|^2 over frequency, then take sqrt for RMS + noise_power = density_band**2 + integrated = float(_trapz(noise_power, freq_band)) + total_rms = float(np.sqrt(max(integrated, 0.0))) + + # Equivalent noise bandwidth: bandwidth of a brick-wall filter with the + # same peak density that would pass the same total noise power + peak_density = float(np.max(density_band)) + if peak_density > 1e-30: + enbw = integrated / (peak_density**2) + else: + enbw = 0.0 + + return { + "total_rms_v": total_rms, + "integration_range_hz": [f_low, f_high], + "equivalent_noise_bandwidth_hz": float(enbw), + } + + +def compute_spot_noise(frequency: np.ndarray, noise_signal: np.ndarray, target_freq: float) -> dict: + """Interpolate noise spectral density at a specific frequency. + + Uses linear interpolation between adjacent data points to estimate + the noise density at the requested frequency. + + Args: + frequency: Frequency array in Hz (may be complex; real part is used) + noise_signal: Complex noise signal from .raw file + target_freq: Desired frequency in Hz + + Returns: + Dict with spot_noise_v_per_sqrt_hz, spot_noise_db, actual_freq_hz + """ + if len(frequency) == 0 or len(noise_signal) == 0: + return { + "spot_noise_v_per_sqrt_hz": 0.0, + "spot_noise_db": float("-inf"), + "actual_freq_hz": target_freq, + } + + freq = np.real(frequency).astype(np.float64) + density = np.abs(noise_signal) + + # Sort by frequency + sort_idx = np.argsort(freq) + freq = freq[sort_idx] + density = density[sort_idx] + + # Clamp to data range + if target_freq <= freq[0]: + spot = float(density[0]) + actual = float(freq[0]) + elif target_freq >= freq[-1]: + spot = float(density[-1]) + actual = float(freq[-1]) + else: + # Linear interpolation + spot = float(np.interp(target_freq, freq, density)) + actual = target_freq + + spot_db = 20.0 * np.log10(max(spot, 1e-30)) + + return { + "spot_noise_v_per_sqrt_hz": spot, + "spot_noise_db": spot_db, + "actual_freq_hz": actual, + } + + +def compute_noise_figure( + frequency: np.ndarray, + noise_signal: np.ndarray, + source_resistance: float = 50.0, + temperature: float = 290.0, +) -> dict: + """Compute noise figure from noise spectral density. + + Noise figure is the ratio of the measured output noise power to + the thermal noise of the source resistance at the given temperature. + NF(f) = 10*log10(|noise(f)|^2 / (4*k*T*R)) + + Args: + frequency: Frequency array in Hz (may be complex; real part is used) + noise_signal: Complex noise signal from .raw file (output-referred) + source_resistance: Source impedance in ohms (default 50) + temperature: Temperature in Kelvin (default 290 K, IEEE standard) + + Returns: + Dict with noise_figure_db (array), frequency_hz (array), + min_nf_db, nf_at_1khz + """ + if len(frequency) == 0 or len(noise_signal) == 0: + return { + "noise_figure_db": [], + "frequency_hz": [], + "min_nf_db": None, + "nf_at_1khz": None, + } + + freq = np.real(frequency).astype(np.float64) + density = np.abs(noise_signal) + + # Thermal noise power spectral density of the source: 4*k*T*R (V^2/Hz) + thermal_psd = 4.0 * _K_BOLTZMANN * temperature * source_resistance + + if thermal_psd < 1e-50: + return { + "noise_figure_db": [], + "frequency_hz": freq.tolist(), + "min_nf_db": None, + "nf_at_1khz": None, + } + + # NF = 10*log10(measured_noise_power / thermal_noise_power) + # where noise_power = density^2 per Hz + noise_power = density**2 + nf_ratio = noise_power / thermal_psd + nf_db = 10.0 * np.log10(np.maximum(nf_ratio, 1e-30)) + + min_nf_db = float(np.min(nf_db)) + + # Noise figure at 1 kHz (interpolated) + nf_at_1khz = None + if freq[0] <= 1000.0 <= freq[-1]: + sort_idx = np.argsort(freq) + nf_at_1khz = float(np.interp(1000.0, freq[sort_idx], nf_db[sort_idx])) + elif len(freq) == 1: + nf_at_1khz = float(nf_db[0]) + + return { + "noise_figure_db": nf_db.tolist(), + "frequency_hz": freq.tolist(), + "min_nf_db": min_nf_db, + "nf_at_1khz": nf_at_1khz, + } + + +def _estimate_flicker_corner(frequency: np.ndarray, density: np.ndarray) -> float | None: + """Estimate the 1/f noise corner frequency. + + The 1/f corner is where the noise transitions from 1/f (flicker) behavior + to flat (white) noise. We find where the slope of log(density) vs log(freq) + crosses -0.25 (midpoint between 0 for white and -0.5 for 1/f in V/sqrt(Hz)). + + Args: + frequency: Sorted frequency array in Hz (positive, ascending) + density: Noise spectral density magnitude (same order as frequency) + + Returns: + Corner frequency in Hz, or None if not detectable + """ + if len(frequency) < 4: + return None + + # Work in log-log space + pos_mask = (frequency > 0) & (density > 0) + freq_pos = frequency[pos_mask] + dens_pos = density[pos_mask] + + if len(freq_pos) < 4: + return None + + log_f = np.log10(freq_pos) + log_d = np.log10(dens_pos) + + # Compute local slope using central differences (smoothed) + # Use a window of ~5 points for robustness + n = len(log_f) + slopes = np.zeros(n) + half_win = min(2, (n - 1) // 2) + + for i in range(half_win, n - half_win): + df = log_f[i + half_win] - log_f[i - half_win] + dd = log_d[i + half_win] - log_d[i - half_win] + if abs(df) > 1e-15: + slopes[i] = dd / df + + # Fill edges with nearest valid slope + slopes[:half_win] = slopes[half_win] + slopes[n - half_win :] = slopes[n - half_win - 1] + + # Find where slope crosses the threshold (-0.25) + # 1/f noise has slope ~ -0.5 in V/sqrt(Hz), white has slope ~ 0 + threshold = -0.25 + + for i in range(len(slopes) - 1): + if slopes[i] < threshold <= slopes[i + 1]: + # Interpolate + ds = slopes[i + 1] - slopes[i] + if abs(ds) < 1e-15: + return float(freq_pos[i]) + frac = (threshold - slopes[i]) / ds + log_corner = log_f[i] + frac * (log_f[i + 1] - log_f[i]) + return float(10.0**log_corner) + + return None + + +def compute_noise_metrics( + frequency: np.ndarray, + noise_signal: np.ndarray, + source_resistance: float = 50.0, +) -> dict: + """Comprehensive noise analysis report. + + Combines spectral density, spot noise at standard frequencies, total + integrated noise, noise figure, and 1/f corner estimation. + + Args: + frequency: Frequency array in Hz (may be complex; real part is used) + noise_signal: Complex noise signal from .raw file + source_resistance: Source impedance in ohms for noise figure (default 50) + + Returns: + Dict with spectral_density, spot_noise (at standard frequencies), + total_noise, noise_figure, flicker_corner_hz + """ + if len(frequency) < 2 or len(noise_signal) < 2: + return { + "spectral_density": compute_noise_spectral_density(frequency, noise_signal), + "spot_noise": {}, + "total_noise": compute_total_noise(frequency, noise_signal), + "noise_figure": compute_noise_figure(frequency, noise_signal, source_resistance), + "flicker_corner_hz": None, + } + + freq = np.real(frequency).astype(np.float64) + + # Spectral density + spectral = compute_noise_spectral_density(frequency, noise_signal) + + # Spot noise at standard frequencies + spot_freqs = [10.0, 100.0, 1000.0, 10000.0, 100000.0] + spot_labels = ["10Hz", "100Hz", "1kHz", "10kHz", "100kHz"] + spot_noise = {} + + f_min = float(np.min(freq)) + f_max = float(np.max(freq)) + + for label, sf in zip(spot_labels, spot_freqs): + if f_min <= sf <= f_max: + spot_noise[label] = compute_spot_noise(frequency, noise_signal, sf) + + # Total noise over full bandwidth + total = compute_total_noise(frequency, noise_signal) + + # Noise figure + nf = compute_noise_figure(frequency, noise_signal, source_resistance) + + # 1/f corner frequency estimation + sort_idx = np.argsort(freq) + sorted_freq = freq[sort_idx] + sorted_density = np.abs(noise_signal)[sort_idx] + flicker_corner = _estimate_flicker_corner(sorted_freq, sorted_density) + + return { + "spectral_density": spectral, + "spot_noise": spot_noise, + "total_noise": total, + "noise_figure": nf, + "flicker_corner_hz": float(flicker_corner) if flicker_corner is not None else None, + } diff --git a/src/mcp_ltspice/server.py b/src/mcp_ltspice/server.py index 377c64d..c81f78d 100644 --- a/src/mcp_ltspice/server.py +++ b/src/mcp_ltspice/server.py @@ -48,7 +48,19 @@ from .models import ( from .models import ( search_subcircuits as _search_subcircuits, ) -from .netlist import Netlist +from .netlist import ( + Netlist, + buck_converter, + colpitts_oscillator, + common_emitter_amplifier, + differential_amplifier, + h_bridge, + inverting_amplifier, + ldo_regulator, + non_inverting_amplifier, + rc_lowpass, + voltage_divider, +) from .optimizer import ( ComponentRange, OptimizationTarget, @@ -82,6 +94,7 @@ mcp = FastMCP( - Extract waveform data (voltages, currents) from simulation results - Analyze signals: FFT, THD, RMS, bandwidth, settling time - Create circuits from scratch using the netlist builder + - Create circuits from 10 pre-built templates (list_templates) - Modify component values in schematics programmatically - Browse LTspice's component library (6500+ symbols) - Search 2800+ SPICE models and subcircuits @@ -89,14 +102,15 @@ mcp = FastMCP( - Run design rule checks before simulation - Compare schematics to see what changed - Export waveform data to CSV + - Extract DC operating point (.op) and transfer function (.tf) data - Measure stability (gain/phase margins from AC loop gain) - Compute power and efficiency from voltage/current waveforms - Evaluate waveform math expressions (V*I, gain, dB, etc.) - Optimize component values to hit target specs automatically - Generate .asc schematic files (graphical format) - Run parameter sweeps, temperature sweeps, and Monte Carlo analysis + - Handle stepped simulations: list runs, extract per-run data - Parse Touchstone (.s2p) S-parameter files - - Use circuit templates: buck converter, LDO, diff amp, oscillator, H-bridge LTspice runs via Wine on Linux. Simulations execute in batch mode and results are parsed from binary .raw files. @@ -204,19 +218,32 @@ def get_waveform( raw_file_path: str, signal_names: list[str], max_points: int = 1000, + run: int | None = None, ) -> dict: """Extract waveform data from a .raw simulation results file. For transient analysis, returns time + voltage/current values. For AC analysis, returns frequency + magnitude(dB)/phase(degrees). + For stepped simulations (.step, .mc, .temp), specify `run` (1-based) + to extract a single run's data. Omit `run` to get all data combined. + Args: raw_file_path: Path to .raw file from simulation signal_names: Signal names to extract, e.g. ["V(out)", "I(R1)"] max_points: Maximum data points (downsampled if needed) + run: Run number (1-based) for stepped simulations (None = all data) """ raw = parse_raw_file(raw_file_path) + # Extract specific run if requested + if run is not None: + if not raw.is_stepped: + return {"error": "Not a stepped simulation - no multiple runs available"} + if run < 1 or run > raw.n_runs: + return {"error": f"Run {run} out of range (1..{raw.n_runs})"} + raw = raw.get_run_data(run) + x_axis = raw.get_time() x_name = "time" if x_axis is None: @@ -232,6 +259,8 @@ def get_waveform( "signals": {}, "total_points": total_points, "returned_points": 0, + "is_stepped": raw.is_stepped, + "n_runs": raw.n_runs, } if x_axis is not None: @@ -259,6 +288,42 @@ def get_waveform( return result +@mcp.tool() +def list_simulation_runs(raw_file_path: str) -> dict: + """List runs in a stepped simulation (.step, .mc, .temp). + + Returns run count and boundary information for multi-run .raw files. + + Args: + raw_file_path: Path to .raw file from simulation + """ + raw = parse_raw_file(raw_file_path) + + result = { + "is_stepped": raw.is_stepped, + "n_runs": raw.n_runs, + "total_points": raw.points, + "plotname": raw.plotname, + "variables": [{"name": v.name, "type": v.type} for v in raw.variables], + } + + if raw.is_stepped and raw.run_boundaries: + runs = [] + for i in range(raw.n_runs): + start, end = raw._run_slice(i + 1) + runs.append( + { + "run": i + 1, + "start_index": start, + "end_index": end, + "points": end - start, + } + ) + result["runs"] = runs + + return result + + @mcp.tool() def analyze_waveform( raw_file_path: str, @@ -510,6 +575,93 @@ def analyze_stability( return compute_stability_metrics(freq.real, signal) +# ============================================================================ +# DC OPERATING POINT & TRANSFER FUNCTION TOOLS +# ============================================================================ + + +@mcp.tool() +def get_operating_point(log_file_path: str) -> dict: + """Extract DC operating point results from a simulation log. + + The .op analysis computes all node voltages and branch currents + at the DC bias point. Results include device operating points + (transistor Gm, Id, Vgs, etc.) when available. + + Run a simulation with .op directive first, then pass the log file. + + Args: + log_file_path: Path to .log file from simulation + """ + log = parse_log(log_file_path) + + if not log.operating_point: + return { + "error": "No operating point data found in log. " + "Ensure the simulation uses a .op directive.", + "log_errors": log.errors, + } + + # Separate node voltages from branch currents/device params + voltages = {} + currents = {} + other = {} + for name, value in log.operating_point.items(): + if name.startswith("V(") or name.startswith("v("): + voltages[name] = value + elif name.startswith("I(") or name.startswith("i(") or name.startswith("Ix("): + currents[name] = value + else: + other[name] = value + + return { + "voltages": voltages, + "currents": currents, + "device_params": other, + "total_entries": len(log.operating_point), + } + + +@mcp.tool() +def get_transfer_function(log_file_path: str) -> dict: + """Extract .tf (transfer function) results from a simulation log. + + The .tf analysis computes: + - Transfer function (gain or transresistance) + - Input impedance at the source + - Output impedance at the output node + + Run a simulation with .tf directive first (e.g., ".tf V(out) V1"), + then pass the log file. + + Args: + log_file_path: Path to .log file from simulation + """ + log = parse_log(log_file_path) + + if not log.transfer_function: + return { + "error": "No transfer function data found in log. " + "Ensure the simulation uses a .tf directive, " + "e.g., '.tf V(out) V1'.", + "log_errors": log.errors, + } + + # Identify the specific components + result: dict = {"raw_data": log.transfer_function} + + for name, value in log.transfer_function.items(): + name_lower = name.lower() + if "transfer_function" in name_lower: + result["transfer_function"] = value + elif "output_impedance" in name_lower: + result["output_impedance_ohms"] = value + elif "input_impedance" in name_lower: + result["input_impedance_ohms"] = value + + return result + + # ============================================================================ # POWER ANALYSIS TOOLS # ============================================================================ @@ -1121,6 +1273,200 @@ def create_netlist( } +# ============================================================================ +# CIRCUIT TEMPLATE TOOLS +# ============================================================================ + +# Registry of netlist templates with parameter metadata +_TEMPLATES: dict[str, dict] = { + "voltage_divider": { + "func": voltage_divider, + "description": "Resistive voltage divider with .op or custom analysis", + "params": {"v_in": "5", "r1": "10k", "r2": "10k", "sim_type": "op"}, + }, + "rc_lowpass": { + "func": rc_lowpass, + "description": "RC lowpass filter with AC sweep", + "params": {"r": "1k", "c": "100n", "f_start": "1", "f_stop": "1meg"}, + }, + "inverting_amplifier": { + "func": inverting_amplifier, + "description": "Inverting op-amp (gain = -Rf/Rin), +/-15V supply", + "params": {"r_in": "10k", "r_f": "100k", "opamp_model": "LT1001"}, + }, + "non_inverting_amplifier": { + "func": non_inverting_amplifier, + "description": "Non-inverting op-amp (gain = 1 + Rf/Rin), +/-15V supply", + "params": {"r_in": "10k", "r_f": "100k", "opamp_model": "LT1001"}, + }, + "differential_amplifier": { + "func": differential_amplifier, + "description": "Diff amp: Vout = (R2/R1)*(V2-V1), +/-15V supply", + "params": { + "r1": "10k", + "r2": "10k", + "r3": "10k", + "r4": "10k", + "opamp_model": "LT1001", + }, + }, + "common_emitter_amplifier": { + "func": common_emitter_amplifier, + "description": "BJT common-emitter with voltage divider bias", + "params": { + "rc": "2.2k", + "rb1": "56k", + "rb2": "12k", + "re": "1k", + "cc1": "10u", + "cc2": "10u", + "ce": "47u", + "vcc": "12", + "bjt_model": "2N2222", + }, + }, + "buck_converter": { + "func": buck_converter, + "description": "Step-down DC-DC converter with MOSFET switch", + "params": { + "ind": "10u", + "c_out": "100u", + "r_load": "10", + "v_in": "12", + "duty_cycle": "0.5", + "freq": "100k", + "mosfet_model": "IRF540N", + "diode_model": "1N5819", + }, + }, + "ldo_regulator": { + "func": ldo_regulator, + "description": "LDO regulator: Vout = Vref * (1 + R1/R2)", + "params": { + "opamp_model": "LT1001", + "r1": "10k", + "r2": "10k", + "pass_transistor": "IRF9540N", + "v_in": "8", + "v_ref": "2.5", + }, + }, + "colpitts_oscillator": { + "func": colpitts_oscillator, + "description": "LC oscillator: f ~ 1/(2pi*sqrt(L*Cseries))", + "params": { + "ind": "1u", + "c1": "100p", + "c2": "100p", + "rb": "47k", + "rc": "1k", + "re": "470", + "vcc": "12", + "bjt_model": "2N2222", + }, + }, + "h_bridge": { + "func": h_bridge, + "description": "4-MOSFET H-bridge motor driver with dead time", + "params": { + "v_supply": "12", + "r_load": "10", + "mosfet_model": "IRF540N", + }, + }, +} + + +@mcp.tool() +def create_from_template( + template_name: str, + params: dict[str, str] | None = None, + output_path: str | None = None, +) -> dict: + """Create a circuit netlist from a pre-built template. + + Available templates: + - voltage_divider: params {v_in, r1, r2, sim_type} + - rc_lowpass: params {r, c, f_start, f_stop} + - inverting_amplifier: params {r_in, r_f, opamp_model} + - non_inverting_amplifier: params {r_in, r_f, opamp_model} + - differential_amplifier: params {r1, r2, r3, r4, opamp_model} + - common_emitter_amplifier: params {rc, rb1, rb2, re, cc1, cc2, ce, vcc, bjt_model} + - buck_converter: params {ind, c_out, r_load, v_in, duty_cycle, freq, mosfet_model, diode_model} + - ldo_regulator: params {opamp_model, r1, r2, pass_transistor, v_in, v_ref} + - colpitts_oscillator: params {ind, c1, c2, rb, rc, re, vcc, bjt_model} + - h_bridge: params {v_supply, r_load, mosfet_model} + + All parameter values are optional -- defaults are used if omitted. + + Args: + template_name: Template name from the list above + params: Optional dict of parameter overrides (all values as strings) + output_path: Where to save .cir file (None = auto in /tmp) + """ + template = _TEMPLATES.get(template_name) + if template is None: + return { + "error": f"Unknown template '{template_name}'", + "available_templates": [ + {"name": k, "description": v["description"], "params": v["params"]} + for k, v in _TEMPLATES.items() + ], + } + + # Build kwargs from params, converting duty_cycle to float for buck_converter + kwargs: dict = {} + if params: + for k, v in params.items(): + if k not in template["params"]: + return { + "error": f"Unknown parameter '{k}' for template '{template_name}'", + "valid_params": template["params"], + } + # duty_cycle needs to be float, not string + if k == "duty_cycle": + kwargs[k] = float(v) + else: + kwargs[k] = v + + nl = template["func"](**kwargs) + + if output_path is None: + output_path = str(Path(tempfile.gettempdir()) / f"{template_name}.cir") + + saved = nl.save(output_path) + + return { + "success": True, + "template": template_name, + "description": template["description"], + "output_path": str(saved), + "netlist_preview": nl.render(), + "component_count": len(nl.components), + "params_used": {**template["params"], **(params or {})}, + } + + +@mcp.tool() +def list_templates() -> dict: + """List all available circuit templates with their parameters and defaults. + + Returns template names, descriptions, and the parameters each accepts + with their default values. + """ + return { + "templates": [ + { + "name": name, + "description": info["description"], + "params": info["params"], + } + for name, info in _TEMPLATES.items() + ], + "total_count": len(_TEMPLATES), + } + + # ============================================================================ # LIBRARY & MODEL TOOLS # ============================================================================ @@ -1478,6 +1824,121 @@ Common issues: """ +@mcp.prompt() +def optimize_design( + circuit_type: str = "filter", + target_spec: str = "1kHz bandwidth", +) -> str: + """Guide through optimizing a circuit to meet target specifications. + + Args: + circuit_type: Type of circuit (filter, amplifier, regulator, oscillator) + target_spec: Target specification to achieve + """ + return f"""Optimize a {circuit_type} circuit to achieve: {target_spec} + +Workflow: +1. Start with a template: use list_templates to see available circuits +2. Create the initial circuit with create_from_template +3. Simulate and measure the current performance +4. Use optimize_circuit to automatically tune component values: + - Define target metrics (bandwidth, gain, settling time, etc.) + - Specify component ranges with preferred E-series values + - Let the optimizer iterate (typically 10-20 simulations) +5. Verify the optimized design with a full simulation +6. Run Monte Carlo (monte_carlo tool) to check yield with tolerances + +Tips: +- Start with reasonable initial values from the template +- Use E24 or E96 series for resistors/capacitors +- For filters: target bandwidth_hz metric +- For amplifiers: target gain_db and phase_margin_deg +- For regulators: target settling_time and peak_to_peak (ripple) +""" + + +@mcp.prompt() +def monte_carlo_analysis( + circuit_description: str = "RC filter", + n_runs: str = "100", +) -> str: + """Guide through Monte Carlo tolerance analysis. + + Args: + circuit_description: What circuit to analyze + n_runs: Number of Monte Carlo iterations + """ + return f"""Run Monte Carlo tolerance analysis on: {circuit_description} +Number of runs: {n_runs} + +Workflow: +1. Create or identify the netlist for your circuit +2. Use monte_carlo tool with component tolerances: + - Resistors: typically 1% (0.01) or 5% (0.05) + - Capacitors: typically 10% (0.1) or 20% (0.2) + - Inductors: typically 10% (0.1) +3. For each completed run, extract key metrics: + - Use get_waveform on each raw file + - Use analyze_waveform for RMS, peak-to-peak, etc. + - Use measure_bandwidth for filter circuits +4. Compute statistics across all runs: + - Mean and standard deviation of each metric + - Min/max (worst case) + - Yield: what percentage meet spec? + +Tips: +- Use list_simulation_runs to understand stepped data +- For stepped simulations, use get_waveform with run parameter +- Start with fewer runs (10-20) to verify setup, then scale up +- Set seed for reproducible results during development +- Typical component tolerances: + - Metal film resistors: 1% + - Ceramic capacitors: 10-20% + - Electrolytic capacitors: 20% + - Inductors: 10-20% +""" + + +@mcp.prompt() +def circuit_from_scratch( + description: str = "audio amplifier", +) -> str: + """Guide through creating a complete circuit from scratch. + + Args: + description: What circuit to build + """ + return f"""Build a complete circuit from scratch: {description} + +Approach 1 - Use a template (recommended for common circuits): +1. Use list_templates to see available circuit templates +2. Use create_from_template with custom parameters +3. Simulate with simulate_netlist +4. Analyze results with get_waveform and analyze_waveform + +Approach 2 - Build from components: +1. Use create_netlist to define components and connections +2. Use search_spice_models to find transistor/diode models +3. Use search_spice_subcircuits to find op-amp/IC models +4. Add simulation directives (.tran, .ac, .dc, .op, .tf) +5. Simulate and analyze + +Approach 3 - Graphical schematic: +1. Use generate_schematic for supported topologies (rc_lowpass, + voltage_divider, inverting_amp) +2. The .asc file can be opened in LTspice GUI for editing +3. Simulate with the simulate tool + +Verification workflow: +1. Run run_drc to check for design issues before simulating +2. Start with .op analysis to verify DC bias point +3. Run .tf analysis for gain and impedance +4. Run .ac analysis for frequency response +5. Run .tran analysis for time-domain behavior +6. Use diff_schematics to compare design iterations +""" + + # ============================================================================ # ENTRY POINT # ============================================================================ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c52d71d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,328 @@ +"""Shared fixtures for mcp-ltspice test suite. + +All fixtures produce synthetic data -- no LTspice or Wine required. +""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from mcp_ltspice.raw_parser import RawFile, Variable +from mcp_ltspice.schematic import Component, Flag, Schematic, Text, Wire + + +# --------------------------------------------------------------------------- +# Time-domain fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_rate() -> float: + """Default sample rate: 100 kHz.""" + return 100_000.0 + + +@pytest.fixture +def duration() -> float: + """Default signal duration: 10 ms (enough for 1 kHz signals).""" + return 0.01 + + +@pytest.fixture +def time_array(sample_rate, duration) -> np.ndarray: + """Uniformly spaced time array.""" + n = int(sample_rate * duration) + return np.linspace(0, duration, n, endpoint=False) + + +@pytest.fixture +def sine_1khz(time_array) -> np.ndarray: + """1 kHz sine wave, 1 V peak.""" + return np.sin(2 * np.pi * 1000 * time_array) + + +@pytest.fixture +def dc_signal() -> np.ndarray: + """Constant 3.3 V DC signal (1000 samples).""" + return np.full(1000, 3.3) + + +@pytest.fixture +def step_signal(time_array) -> np.ndarray: + """Unit step at t = duration/2 with exponential rise (tau = duration/10).""" + t = time_array + mid = t[-1] / 2 + tau = t[-1] / 10 + sig = np.where(t >= mid, 1.0 - np.exp(-(t - mid) / tau), 0.0) + return sig + + +# --------------------------------------------------------------------------- +# Frequency-domain / AC fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def ac_frequency() -> np.ndarray: + """Log-spaced frequency array: 1 Hz to 10 MHz, 500 points.""" + return np.logspace(0, 7, 500) + + +@pytest.fixture +def lowpass_response(ac_frequency) -> np.ndarray: + """First-order lowpass magnitude in dB (fc ~ 1 kHz).""" + fc = 1000.0 + mag = 1.0 / np.sqrt(1.0 + (ac_frequency / fc) ** 2) + return 20.0 * np.log10(mag) + + +@pytest.fixture +def lowpass_complex(ac_frequency) -> np.ndarray: + """First-order lowpass as complex transfer function (fc ~ 1 kHz).""" + fc = 1000.0 + s = 1j * ac_frequency / fc + return 1.0 / (1.0 + s) + + +# --------------------------------------------------------------------------- +# Stepped / multi-run fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def stepped_time() -> np.ndarray: + """Time axis for 3 runs, each 0..1 ms with 100 points per run.""" + runs = [] + for _ in range(3): + runs.append(np.linspace(0, 1e-3, 100, endpoint=False)) + return np.concatenate(runs) + + +@pytest.fixture +def stepped_data(stepped_time) -> np.ndarray: + """Two variables (time + V(out)) across 3 runs.""" + n = len(stepped_time) + data = np.zeros((2, n)) + data[0] = stepped_time + # Each run has a different amplitude sine wave + for run_idx in range(3): + start = run_idx * 100 + end = start + 100 + t_run = data[0, start:end] + data[1, start:end] = (run_idx + 1) * np.sin(2 * np.pi * 1000 * t_run) + return data + + +# --------------------------------------------------------------------------- +# Mock RawFile fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_rawfile(time_array, sine_1khz) -> RawFile: + """A simple transient-analysis RawFile with time and V(out).""" + n = len(time_array) + data = np.zeros((2, n)) + data[0] = time_array + data[1] = sine_1khz + return RawFile( + title="Test Circuit", + date="2026-01-01", + plotname="Transient Analysis", + flags=["real"], + variables=[ + Variable(0, "time", "time"), + Variable(1, "V(out)", "voltage"), + ], + points=n, + data=data, + ) + + +@pytest.fixture +def mock_rawfile_stepped(stepped_data) -> RawFile: + """A stepped RawFile with 3 runs.""" + n = stepped_data.shape[1] + return RawFile( + title="Stepped Sim", + date="2026-01-01", + plotname="Transient Analysis", + flags=["real", "stepped"], + variables=[ + Variable(0, "time", "time"), + Variable(1, "V(out)", "voltage"), + ], + points=n, + data=stepped_data, + n_runs=3, + run_boundaries=[0, 100, 200], + ) + + +@pytest.fixture +def mock_rawfile_ac(ac_frequency, lowpass_complex) -> RawFile: + """An AC-analysis RawFile with complex frequency-domain data.""" + n = len(ac_frequency) + data = np.zeros((2, n), dtype=np.complex128) + data[0] = ac_frequency.astype(np.complex128) + data[1] = lowpass_complex + return RawFile( + title="AC Sim", + date="2026-01-01", + plotname="AC Analysis", + flags=["complex"], + variables=[ + Variable(0, "frequency", "frequency"), + Variable(1, "V(out)", "voltage"), + ], + points=n, + data=data, + ) + + +# --------------------------------------------------------------------------- +# Netlist / Schematic string fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_netlist_str() -> str: + """A basic SPICE netlist string for an RC lowpass.""" + return ( + "* RC Lowpass Filter\n" + "V1 in 0 AC 1\n" + "R1 in out 1k\n" + "C1 out 0 100n\n" + ".ac dec 100 1 1meg\n" + ".backanno\n" + ".end\n" + ) + + +@pytest.fixture +def sample_asc_str() -> str: + """A minimal .asc schematic string for an RC lowpass.""" + return ( + "Version 4\n" + "SHEET 1 880 680\n" + "WIRE 80 96 176 96\n" + "WIRE 176 176 272 176\n" + "FLAG 80 176 0\n" + "FLAG 272 240 0\n" + "FLAG 176 176 out\n" + "SYMBOL voltage 80 80 R0\n" + "SYMATTR InstName V1\n" + "SYMATTR Value AC 1\n" + "SYMBOL res 160 80 R0\n" + "SYMATTR InstName R1\n" + "SYMATTR Value 1k\n" + "SYMBOL cap 256 176 R0\n" + "SYMATTR InstName C1\n" + "SYMATTR Value 100n\n" + "TEXT 80 296 Left 2 !.ac dec 100 1 1meg\n" + ) + + +# --------------------------------------------------------------------------- +# Touchstone fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_s2p_content() -> str: + """Synthetic .s2p file content (MA format, GHz).""" + lines = [ + "! Two-port S-parameter data", + "! Freq S11(mag) S11(ang) S21(mag) S21(ang) S12(mag) S12(ang) S22(mag) S22(ang)", + "# GHZ S MA R 50", + "1.0 0.5 -30 0.9 -10 0.1 170 0.4 -40", + "2.0 0.6 -50 0.8 -20 0.12 160 0.5 -60", + "3.0 0.7 -70 0.7 -30 0.15 150 0.55 -80", + ] + return "\n".join(lines) + "\n" + + +@pytest.fixture +def tmp_s2p_file(sample_s2p_content, tmp_path) -> Path: + """Write synthetic .s2p content to a temp file and return path.""" + p = tmp_path / "test.s2p" + p.write_text(sample_s2p_content) + return p + + +# --------------------------------------------------------------------------- +# Schematic object fixtures (for DRC and diff tests) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def valid_schematic() -> Schematic: + """A schematic with ground, components, wires, and sim directive.""" + sch = Schematic() + sch.flags = [ + Flag(80, 176, "0"), + Flag(272, 240, "0"), + Flag(176, 176, "out"), + ] + sch.wires = [ + Wire(80, 96, 176, 96), + Wire(176, 176, 272, 176), + ] + sch.components = [ + Component(name="V1", symbol="voltage", x=80, y=80, rotation=0, mirror=False, + attributes={"Value": "AC 1"}), + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + Component(name="C1", symbol="cap", x=256, y=176, rotation=0, mirror=False, + attributes={"Value": "100n"}), + ] + sch.texts = [ + Text(80, 296, ".ac dec 100 1 1meg", type="spice"), + ] + return sch + + +@pytest.fixture +def schematic_no_ground() -> Schematic: + """A schematic missing a ground node.""" + sch = Schematic() + sch.flags = [Flag(176, 176, "out")] + sch.wires = [Wire(80, 96, 176, 96)] + sch.components = [ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ] + sch.texts = [Text(80, 296, ".tran 10m", type="spice")] + return sch + + +@pytest.fixture +def schematic_no_sim() -> Schematic: + """A schematic missing a simulation directive.""" + sch = Schematic() + sch.flags = [Flag(80, 176, "0")] + sch.wires = [Wire(80, 96, 176, 96)] + sch.components = [ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ] + sch.texts = [] + return sch + + +@pytest.fixture +def schematic_duplicate_names() -> Schematic: + """A schematic with duplicate component names.""" + sch = Schematic() + sch.flags = [Flag(80, 176, "0")] + sch.wires = [Wire(80, 96, 176, 96)] + sch.components = [ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + Component(name="R1", symbol="res", x=320, y=80, rotation=0, mirror=False, + attributes={"Value": "2.2k"}), + ] + sch.texts = [Text(80, 296, ".tran 10m", type="spice")] + return sch diff --git a/tests/test_asc_generator.py b/tests/test_asc_generator.py new file mode 100644 index 0000000..264f474 --- /dev/null +++ b/tests/test_asc_generator.py @@ -0,0 +1,176 @@ +"""Tests for asc_generator module: pin positioning, schematic rendering, templates.""" + +import pytest + +from mcp_ltspice.asc_generator import ( + AscSchematic, + GRID, + _PIN_OFFSETS, + _rotate, + generate_inverting_amp, + generate_rc_lowpass, + generate_voltage_divider, + pin_position, +) + + +class TestPinPosition: + @pytest.mark.parametrize("symbol", ["voltage", "res", "cap", "ind"]) + def test_r0_returns_offset_plus_origin(self, symbol): + """At R0, pin position = origin + raw offset.""" + cx, cy = 160, 80 + for pin_idx in range(2): + px, py = pin_position(symbol, pin_idx, cx, cy, rotation=0) + offsets = _PIN_OFFSETS[symbol] + ox, oy = offsets[pin_idx] + assert px == cx + ox + assert py == cy + oy + + @pytest.mark.parametrize("symbol", ["voltage", "res", "cap", "ind"]) + def test_r90(self, symbol): + """R90 applies (px, py) -> (-py, px).""" + cx, cy = 160, 80 + for pin_idx in range(2): + px, py = pin_position(symbol, pin_idx, cx, cy, rotation=90) + offsets = _PIN_OFFSETS[symbol] + ox, oy = offsets[pin_idx] + assert px == cx + (-oy) + assert py == cy + ox + + @pytest.mark.parametrize("symbol", ["voltage", "res", "cap", "ind"]) + def test_r180(self, symbol): + """R180 applies (px, py) -> (-px, -py).""" + cx, cy = 160, 80 + for pin_idx in range(2): + px, py = pin_position(symbol, pin_idx, cx, cy, rotation=180) + offsets = _PIN_OFFSETS[symbol] + ox, oy = offsets[pin_idx] + assert px == cx + (-ox) + assert py == cy + (-oy) + + @pytest.mark.parametrize("symbol", ["voltage", "res", "cap", "ind"]) + def test_r270(self, symbol): + """R270 applies (px, py) -> (py, -px).""" + cx, cy = 160, 80 + for pin_idx in range(2): + px, py = pin_position(symbol, pin_idx, cx, cy, rotation=270) + offsets = _PIN_OFFSETS[symbol] + ox, oy = offsets[pin_idx] + assert px == cx + oy + assert py == cy + (-ox) + + def test_unknown_symbol_defaults(self): + """Unknown symbol uses default pin offsets.""" + px, py = pin_position("unknown", 0, 0, 0, rotation=0) + # Default is [(0, 0), (0, 80)] + assert (px, py) == (0, 0) + px2, py2 = pin_position("unknown", 1, 0, 0, rotation=0) + assert (px2, py2) == (0, 80) + + +class TestRotate: + def test_identity(self): + assert _rotate(10, 20, 0) == (10, 20) + + def test_90(self): + assert _rotate(10, 20, 90) == (-20, 10) + + def test_180(self): + assert _rotate(10, 20, 180) == (-10, -20) + + def test_270(self): + assert _rotate(10, 20, 270) == (20, -10) + + def test_invalid_rotation(self): + """Invalid rotation falls through to identity.""" + assert _rotate(10, 20, 45) == (10, 20) + + +class TestAscSchematicRender: + def test_version_header(self): + sch = AscSchematic() + text = sch.render() + assert text.startswith("Version 4\n") + + def test_sheet_dimensions(self): + sch = AscSchematic(sheet_w=1200, sheet_h=900) + text = sch.render() + assert "SHEET 1 1200 900" in text + + def test_wire_rendering(self): + sch = AscSchematic() + sch.add_wire(80, 96, 176, 96) + text = sch.render() + assert "WIRE 80 96 176 96" in text + + def test_component_rendering(self): + sch = AscSchematic() + sch.add_component("res", "R1", "1k", 160, 80) + text = sch.render() + assert "SYMBOL res 160 80 R0" in text + assert "SYMATTR InstName R1" in text + assert "SYMATTR Value 1k" in text + + def test_rotated_component(self): + sch = AscSchematic() + sch.add_component("res", "R1", "1k", 160, 80, rotation=90) + text = sch.render() + assert "SYMBOL res 160 80 R90" in text + + def test_ground_flag(self): + sch = AscSchematic() + sch.add_ground(80, 176) + text = sch.render() + assert "FLAG 80 176 0" in text + + def test_net_label(self): + sch = AscSchematic() + sch.add_net_label("out", 176, 176) + text = sch.render() + assert "FLAG 176 176 out" in text + + def test_directive_rendering(self): + sch = AscSchematic() + sch.add_directive(".tran 10m", 80, 300) + text = sch.render() + assert "TEXT 80 300 Left 2 !.tran 10m" in text + + def test_chaining(self): + sch = ( + AscSchematic() + .add_component("res", "R1", "1k", 160, 80) + .add_wire(80, 96, 176, 96) + .add_ground(80, 176) + ) + text = sch.render() + assert "SYMBOL" in text + assert "WIRE" in text + assert "FLAG" in text + + +class TestAscTemplates: + @pytest.mark.parametrize( + "factory", + [generate_rc_lowpass, generate_voltage_divider, generate_inverting_amp], + ) + def test_template_returns_schematic(self, factory): + sch = factory() + assert isinstance(sch, AscSchematic) + + @pytest.mark.parametrize( + "factory", + [generate_rc_lowpass, generate_voltage_divider, generate_inverting_amp], + ) + def test_template_nonempty(self, factory): + text = factory().render() + assert len(text) > 50 + assert "SYMBOL" in text + + @pytest.mark.parametrize( + "factory", + [generate_rc_lowpass, generate_voltage_divider, generate_inverting_amp], + ) + def test_template_has_expected_components(self, factory): + text = factory().render() + # All templates should have at least a res and a voltage source + assert "res" in text or "voltage" in text diff --git a/tests/test_diff.py b/tests/test_diff.py new file mode 100644 index 0000000..d8cbe2f --- /dev/null +++ b/tests/test_diff.py @@ -0,0 +1,234 @@ +"""Tests for diff module: schematic comparison.""" + +from pathlib import Path + +import pytest + +from mcp_ltspice.diff import ( + ComponentChange, + DirectiveChange, + SchematicDiff, + _diff_components, + _diff_directives, + _diff_nets, + _diff_wires, + diff_schematics, +) +from mcp_ltspice.schematic import Component, Flag, Schematic, Text, Wire, write_schematic + + +def _make_schematic(**kwargs) -> Schematic: + """Helper to build a Schematic with overrides.""" + sch = Schematic() + sch.components = kwargs.get("components", []) + sch.wires = kwargs.get("wires", []) + sch.flags = kwargs.get("flags", []) + sch.texts = kwargs.get("texts", []) + return sch + + +class TestDiffComponents: + def test_added_component(self): + sch_a = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ]) + sch_b = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + Component(name="C1", symbol="cap", x=256, y=176, rotation=0, mirror=False, + attributes={"Value": "100n"}), + ]) + changes = _diff_components(sch_a, sch_b) + added = [c for c in changes if c.change_type == "added"] + assert len(added) == 1 + assert added[0].name == "C1" + + def test_removed_component(self): + sch_a = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + Component(name="R2", symbol="res", x=320, y=80, rotation=0, mirror=False, + attributes={"Value": "2.2k"}), + ]) + sch_b = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ]) + changes = _diff_components(sch_a, sch_b) + removed = [c for c in changes if c.change_type == "removed"] + assert len(removed) == 1 + assert removed[0].name == "R2" + + def test_modified_value(self): + sch_a = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ]) + sch_b = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "2.2k"}), + ]) + changes = _diff_components(sch_a, sch_b) + modified = [c for c in changes if c.change_type == "modified"] + assert len(modified) == 1 + assert modified[0].old_value == "1k" + assert modified[0].new_value == "2.2k" + + def test_moved_component(self): + sch_a = _make_schematic(components=[ + Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ]) + sch_b = _make_schematic(components=[ + Component(name="R1", symbol="res", x=320, y=160, rotation=0, mirror=False, + attributes={"Value": "1k"}), + ]) + changes = _diff_components(sch_a, sch_b) + assert len(changes) == 1 + assert changes[0].moved is True + + def test_no_changes(self): + comp = Component(name="R1", symbol="res", x=160, y=80, rotation=0, mirror=False, + attributes={"Value": "1k"}) + sch = _make_schematic(components=[comp]) + changes = _diff_components(sch, sch) + assert len(changes) == 0 + + +class TestDiffDirectives: + def test_added_directive(self): + sch_a = _make_schematic(texts=[ + Text(80, 296, ".tran 10m", type="spice"), + ]) + sch_b = _make_schematic(texts=[ + Text(80, 296, ".tran 10m", type="spice"), + Text(80, 320, ".meas tran vmax MAX V(out)", type="spice"), + ]) + changes = _diff_directives(sch_a, sch_b) + added = [c for c in changes if c.change_type == "added"] + assert len(added) == 1 + + def test_removed_directive(self): + sch_a = _make_schematic(texts=[ + Text(80, 296, ".tran 10m", type="spice"), + Text(80, 320, ".op", type="spice"), + ]) + sch_b = _make_schematic(texts=[ + Text(80, 296, ".tran 10m", type="spice"), + ]) + changes = _diff_directives(sch_a, sch_b) + removed = [c for c in changes if c.change_type == "removed"] + assert len(removed) == 1 + + def test_modified_directive(self): + sch_a = _make_schematic(texts=[ + Text(80, 296, ".tran 10m", type="spice"), + ]) + sch_b = _make_schematic(texts=[ + Text(80, 296, ".tran 50m", type="spice"), + ]) + changes = _diff_directives(sch_a, sch_b) + modified = [c for c in changes if c.change_type == "modified"] + assert len(modified) == 1 + assert modified[0].old_text == ".tran 10m" + assert modified[0].new_text == ".tran 50m" + + +class TestDiffNets: + def test_added_nets(self): + sch_a = _make_schematic(flags=[Flag(80, 176, "0")]) + sch_b = _make_schematic(flags=[Flag(80, 176, "0"), Flag(176, 176, "out")]) + added, removed = _diff_nets(sch_a, sch_b) + assert "out" in added + assert len(removed) == 0 + + def test_removed_nets(self): + sch_a = _make_schematic(flags=[Flag(80, 176, "0"), Flag(176, 176, "out")]) + sch_b = _make_schematic(flags=[Flag(80, 176, "0")]) + added, removed = _diff_nets(sch_a, sch_b) + assert "out" in removed + assert len(added) == 0 + + +class TestDiffWires: + def test_added_wires(self): + sch_a = _make_schematic(wires=[Wire(80, 96, 176, 96)]) + sch_b = _make_schematic(wires=[ + Wire(80, 96, 176, 96), + Wire(176, 176, 272, 176), + ]) + added, removed = _diff_wires(sch_a, sch_b) + assert added == 1 + assert removed == 0 + + def test_removed_wires(self): + sch_a = _make_schematic(wires=[ + Wire(80, 96, 176, 96), + Wire(176, 176, 272, 176), + ]) + sch_b = _make_schematic(wires=[Wire(80, 96, 176, 96)]) + added, removed = _diff_wires(sch_a, sch_b) + assert added == 0 + assert removed == 1 + + +class TestSchematicDiff: + def test_has_changes_false(self): + diff = SchematicDiff() + assert diff.has_changes is False + + def test_has_changes_true(self): + diff = SchematicDiff(wires_added=1) + assert diff.has_changes is True + + def test_summary_no_changes(self): + diff = SchematicDiff() + assert "No changes" in diff.summary() + + def test_to_dict(self): + diff = SchematicDiff( + component_changes=[ + ComponentChange(name="R1", change_type="modified", + old_value="1k", new_value="2.2k") + ] + ) + d = diff.to_dict() + assert d["has_changes"] is True + assert len(d["component_changes"]) == 1 + + +class TestDiffSchematicsIntegration: + """Write two schematics to disk and compare them end-to-end.""" + + def test_full_diff(self, valid_schematic, tmp_path): + # Create "before" schematic + path_a = tmp_path / "before.asc" + write_schematic(valid_schematic, path_a) + + # Create "after" schematic with a modified R1 value + modified = Schematic() + modified.flags = list(valid_schematic.flags) + modified.wires = list(valid_schematic.wires) + modified.texts = list(valid_schematic.texts) + modified.components = [] + for comp in valid_schematic.components: + if comp.name == "R1": + new_comp = Component( + name=comp.name, symbol=comp.symbol, + x=comp.x, y=comp.y, rotation=comp.rotation, mirror=comp.mirror, + attributes={"Value": "4.7k"}, + ) + modified.components.append(new_comp) + else: + modified.components.append(comp) + + path_b = tmp_path / "after.asc" + write_schematic(modified, path_b) + + diff = diff_schematics(path_a, path_b) + assert diff.has_changes + r1_changes = [c for c in diff.component_changes if c.name == "R1"] + assert len(r1_changes) == 1 + assert r1_changes[0].old_value == "1k" + assert r1_changes[0].new_value == "4.7k" diff --git a/tests/test_drc.py b/tests/test_drc.py new file mode 100644 index 0000000..4fc3638 --- /dev/null +++ b/tests/test_drc.py @@ -0,0 +1,130 @@ +"""Tests for drc module: design rule checks on schematic objects.""" + +import tempfile +from pathlib import Path + +import pytest + +from mcp_ltspice.drc import ( + DRCResult, + DRCViolation, + Severity, + _check_duplicate_names, + _check_ground, + _check_simulation_directive, +) +from mcp_ltspice.schematic import Component, Flag, Schematic, Text, Wire, write_schematic + + +def _run_single_check(check_fn, schematic: Schematic) -> DRCResult: + """Run a single DRC check function and return results.""" + result = DRCResult() + check_fn(schematic, result) + return result + + +class TestGroundCheck: + def test_missing_ground_detected(self, schematic_no_ground): + result = _run_single_check(_check_ground, schematic_no_ground) + assert not result.passed + assert any(v.rule == "NO_GROUND" for v in result.violations) + + def test_ground_present(self, valid_schematic): + result = _run_single_check(_check_ground, valid_schematic) + assert result.passed + assert len(result.violations) == 0 + + +class TestSimDirectiveCheck: + def test_missing_sim_directive_detected(self, schematic_no_sim): + result = _run_single_check(_check_simulation_directive, schematic_no_sim) + assert not result.passed + assert any(v.rule == "NO_SIM_DIRECTIVE" for v in result.violations) + + def test_sim_directive_present(self, valid_schematic): + result = _run_single_check(_check_simulation_directive, valid_schematic) + assert result.passed + + +class TestDuplicateNameCheck: + def test_duplicate_names_detected(self, schematic_duplicate_names): + result = _run_single_check(_check_duplicate_names, schematic_duplicate_names) + assert not result.passed + assert any(v.rule == "DUPLICATE_NAME" for v in result.violations) + + def test_unique_names_pass(self, valid_schematic): + result = _run_single_check(_check_duplicate_names, valid_schematic) + assert result.passed + + +class TestDRCResult: + def test_passed_when_no_errors(self): + result = DRCResult() + result.violations.append( + DRCViolation(rule="TEST", severity=Severity.WARNING, message="warning only") + ) + assert result.passed # Warnings don't cause failure + + def test_failed_when_errors(self): + result = DRCResult() + result.violations.append( + DRCViolation(rule="TEST", severity=Severity.ERROR, message="error") + ) + assert not result.passed + + def test_summary_no_violations(self): + result = DRCResult(checks_run=5) + assert "passed" in result.summary().lower() + + def test_summary_with_errors(self): + result = DRCResult(checks_run=5) + result.violations.append( + DRCViolation(rule="TEST", severity=Severity.ERROR, message="error") + ) + assert "FAILED" in result.summary() + + def test_to_dict(self): + result = DRCResult(checks_run=3) + result.violations.append( + DRCViolation(rule="NO_GROUND", severity=Severity.ERROR, message="No ground") + ) + d = result.to_dict() + assert d["passed"] is False + assert d["error_count"] == 1 + assert len(d["violations"]) == 1 + + def test_errors_and_warnings_properties(self): + result = DRCResult() + result.violations.append( + DRCViolation(rule="E1", severity=Severity.ERROR, message="err") + ) + result.violations.append( + DRCViolation(rule="W1", severity=Severity.WARNING, message="warn") + ) + result.violations.append( + DRCViolation(rule="I1", severity=Severity.INFO, message="info") + ) + assert len(result.errors) == 1 + assert len(result.warnings) == 1 + + +class TestFullDRC: + """Integration test: write a schematic to disk and run the full DRC pipeline.""" + + def test_valid_schematic_passes(self, valid_schematic, tmp_path): + """A valid schematic should pass DRC with no errors.""" + from mcp_ltspice.drc import run_drc + + path = tmp_path / "valid.asc" + write_schematic(valid_schematic, path) + result = run_drc(path) + # May have warnings (floating nodes etc) but no errors + assert len(result.errors) == 0 + + def test_no_ground_fails(self, schematic_no_ground, tmp_path): + from mcp_ltspice.drc import run_drc + + path = tmp_path / "no_ground.asc" + write_schematic(schematic_no_ground, path) + result = run_drc(path) + assert any(v.rule == "NO_GROUND" for v in result.errors) diff --git a/tests/test_netlist.py b/tests/test_netlist.py new file mode 100644 index 0000000..22211c5 --- /dev/null +++ b/tests/test_netlist.py @@ -0,0 +1,197 @@ +"""Tests for netlist module: builder pattern, rendering, template functions.""" + +import pytest + +from mcp_ltspice.netlist import ( + Netlist, + buck_converter, + colpitts_oscillator, + common_emitter_amplifier, + differential_amplifier, + h_bridge, + inverting_amplifier, + ldo_regulator, + non_inverting_amplifier, + rc_lowpass, + voltage_divider, +) + + +class TestNetlistBuilder: + def test_add_resistor(self): + n = Netlist().add_resistor("R1", "in", "out", "10k") + assert len(n.components) == 1 + assert n.components[0].name == "R1" + assert n.components[0].value == "10k" + + def test_add_capacitor(self): + n = Netlist().add_capacitor("C1", "out", "0", "100n") + assert len(n.components) == 1 + assert n.components[0].value == "100n" + + def test_add_inductor(self): + n = Netlist().add_inductor("L1", "a", "b", "10u", series_resistance="0.1") + assert "Rser=0.1" in n.components[0].params + + def test_chaining(self): + """Builder methods return self for chaining.""" + n = ( + Netlist("Test") + .add_resistor("R1", "a", "b", "1k") + .add_capacitor("C1", "b", "0", "1n") + ) + assert len(n.components) == 2 + + def test_add_voltage_source_dc(self): + n = Netlist().add_voltage_source("V1", "in", "0", dc="5") + assert "5" in n.components[0].value + + def test_add_voltage_source_ac(self): + n = Netlist().add_voltage_source("V1", "in", "0", ac="1") + assert "AC 1" in n.components[0].value + + def test_add_voltage_source_pulse(self): + n = Netlist().add_voltage_source( + "V1", "g", "0", pulse=("0", "5", "0", "1n", "1n", "5u", "10u") + ) + rendered = n.render() + assert "PULSE(" in rendered + + def test_add_voltage_source_sin(self): + n = Netlist().add_voltage_source( + "V1", "in", "0", sin=("0", "1", "1k") + ) + rendered = n.render() + assert "SIN(" in rendered + + def test_add_directive(self): + n = Netlist().add_directive(".tran 10m") + assert ".tran 10m" in n.directives + + def test_add_meas(self): + n = Netlist().add_meas("tran", "vmax", "MAX V(out)") + assert any("vmax" in d for d in n.directives) + + +class TestNetlistRender: + def test_render_contains_title(self): + n = Netlist("My Circuit") + text = n.render() + assert "* My Circuit" in text + + def test_render_contains_components(self): + n = ( + Netlist() + .add_resistor("R1", "in", "out", "10k") + .add_capacitor("C1", "out", "0", "100n") + ) + text = n.render() + assert "R1 in out 10k" in text + assert "C1 out 0 100n" in text + + def test_render_contains_backanno_and_end(self): + n = Netlist() + text = n.render() + assert ".backanno" in text + assert ".end" in text + + def test_render_includes_directive(self): + n = Netlist().add_directive(".ac dec 100 1 1meg") + text = n.render() + assert ".ac dec 100 1 1meg" in text + + def test_render_includes_comment(self): + n = Netlist().add_comment("Test comment") + text = n.render() + assert "* Test comment" in text + + def test_render_includes_lib(self): + n = Netlist().add_lib("LT1001") + text = n.render() + assert ".lib LT1001" in text + + +class TestTemplateNetlists: + """All template functions should return valid Netlist objects.""" + + @pytest.mark.parametrize( + "factory", + [ + voltage_divider, + rc_lowpass, + inverting_amplifier, + non_inverting_amplifier, + differential_amplifier, + common_emitter_amplifier, + buck_converter, + ldo_regulator, + colpitts_oscillator, + h_bridge, + ], + ) + def test_template_returns_netlist(self, factory): + n = factory() + assert isinstance(n, Netlist) + + @pytest.mark.parametrize( + "factory", + [ + voltage_divider, + rc_lowpass, + inverting_amplifier, + non_inverting_amplifier, + differential_amplifier, + common_emitter_amplifier, + buck_converter, + ldo_regulator, + colpitts_oscillator, + h_bridge, + ], + ) + def test_template_has_backanno_and_end(self, factory): + text = factory().render() + assert ".backanno" in text + assert ".end" in text + + @pytest.mark.parametrize( + "factory", + [ + voltage_divider, + rc_lowpass, + inverting_amplifier, + non_inverting_amplifier, + differential_amplifier, + common_emitter_amplifier, + buck_converter, + ldo_regulator, + colpitts_oscillator, + h_bridge, + ], + ) + def test_template_has_components(self, factory): + n = factory() + assert len(n.components) > 0 + + @pytest.mark.parametrize( + "factory", + [ + voltage_divider, + rc_lowpass, + inverting_amplifier, + non_inverting_amplifier, + differential_amplifier, + common_emitter_amplifier, + buck_converter, + ldo_regulator, + colpitts_oscillator, + h_bridge, + ], + ) + def test_template_has_sim_directive(self, factory): + n = factory() + # Should have at least one directive starting with a sim type + sim_types = [".tran", ".ac", ".dc", ".op", ".noise", ".tf"] + text = n.render() + assert any(sim in text.lower() for sim in sim_types), ( + f"No simulation directive found in {factory.__name__}" + ) diff --git a/tests/test_optimizer_helpers.py b/tests/test_optimizer_helpers.py new file mode 100644 index 0000000..5b7c093 --- /dev/null +++ b/tests/test_optimizer_helpers.py @@ -0,0 +1,87 @@ +"""Tests for optimizer module helpers: snap_to_preferred, format_engineering.""" + +import pytest + +from mcp_ltspice.optimizer import format_engineering, snap_to_preferred + + +class TestSnapToPreferred: + def test_e12_exact_match(self): + """A value that is already an E12 value should snap to itself.""" + assert snap_to_preferred(4700.0, "E12") == pytest.approx(4700.0, rel=0.01) + + def test_e12_near_value(self): + """4800 should snap to 4700 (E12).""" + result = snap_to_preferred(4800.0, "E12") + assert result == pytest.approx(4700.0, rel=0.05) + + def test_e24_finer_resolution(self): + """E24 has 5.1, so 5050 should snap to 5100.""" + result = snap_to_preferred(5050.0, "E24") + assert result == pytest.approx(5100.0, rel=0.05) + + def test_e96_precision(self): + """E96 should snap to a value very close to the input.""" + result = snap_to_preferred(4750.0, "E96") + assert result == pytest.approx(4750.0, rel=0.03) + + def test_zero_value(self): + """Zero should snap to the smallest E-series value.""" + result = snap_to_preferred(0.0, "E12") + assert result > 0 + + def test_negative_value(self): + """Negative value should snap to the smallest E-series value.""" + result = snap_to_preferred(-100.0, "E12") + assert result > 0 + + def test_sub_ohm(self): + """Small values (e.g., 0.47 ohms) should snap correctly.""" + result = snap_to_preferred(0.5, "E12") + assert result == pytest.approx(0.47, rel=0.1) + + def test_megohm_range(self): + """Large values should snap correctly across decades.""" + result = snap_to_preferred(2_200_000.0, "E12") + assert result == pytest.approx(2_200_000.0, rel=0.05) + + def test_unknown_series_defaults_to_e12(self): + """Unknown series name should fall back to E12.""" + result = snap_to_preferred(4800.0, "E6") + assert result == pytest.approx(4700.0, rel=0.05) + + +class TestFormatEngineering: + def test_10k(self): + assert format_engineering(10_000) == "10k" + + def test_1u(self): + assert format_engineering(0.000001) == "1u" + + def test_4_7k(self): + assert format_engineering(4700) == "4.7k" + + def test_zero(self): + assert format_engineering(0) == "0" + + def test_1_5(self): + """Values in the unity range should have no suffix.""" + result = format_engineering(1.5) + assert result == "1.5" + + def test_negative(self): + result = format_engineering(-4700) + assert result.startswith("-") + assert "4.7k" in result + + def test_picofarad(self): + result = format_engineering(100e-12) + assert "100p" in result + + def test_milliamp(self): + result = format_engineering(0.010) + assert "10m" in result + + def test_large_value(self): + result = format_engineering(1e9) + assert "G" in result or "1e" in result diff --git a/tests/test_power_analysis.py b/tests/test_power_analysis.py new file mode 100644 index 0000000..aaad800 --- /dev/null +++ b/tests/test_power_analysis.py @@ -0,0 +1,123 @@ +"""Tests for power_analysis module: average power, efficiency, power factor.""" + +import numpy as np +import pytest + +from mcp_ltspice.power_analysis import ( + compute_average_power, + compute_efficiency, + compute_instantaneous_power, + compute_power_metrics, +) + + +class TestComputeAveragePower: + def test_dc_power(self): + """DC: P = V * I exactly.""" + n = 1000 + t = np.linspace(0, 1.0, n) + v = np.full(n, 5.0) + i = np.full(n, 2.0) + p = compute_average_power(t, v, i) + assert p == pytest.approx(10.0, rel=1e-6) + + def test_ac_in_phase(self): + """In-phase AC: P_avg = Vpk*Ipk/2.""" + n = 10000 + t = np.linspace(0, 0.01, n, endpoint=False) # 1 full period at 100 Hz + freq = 100.0 + Vpk = 10.0 + Ipk = 2.0 + v = Vpk * np.sin(2 * np.pi * freq * t) + i = Ipk * np.sin(2 * np.pi * freq * t) + p = compute_average_power(t, v, i) + expected = Vpk * Ipk / 2.0 # = Vrms * Irms + assert p == pytest.approx(expected, rel=0.02) + + def test_ac_quadrature(self): + """90-degree phase shift: P_avg ~ 0 (reactive power only).""" + n = 10000 + t = np.linspace(0, 0.01, n, endpoint=False) + freq = 100.0 + v = np.sin(2 * np.pi * freq * t) + i = np.cos(2 * np.pi * freq * t) # 90 deg shifted + p = compute_average_power(t, v, i) + assert p == pytest.approx(0.0, abs=0.01) + + def test_short_signal(self): + assert compute_average_power(np.array([0.0]), np.array([5.0]), np.array([2.0])) == 0.0 + + +class TestComputeEfficiency: + def test_known_efficiency(self): + """Input 10W, output 8W -> 80% efficiency.""" + n = 1000 + t = np.linspace(0, 1.0, n) + vin = np.full(n, 10.0) + iin = np.full(n, 1.0) # 10W input + vout = np.full(n, 8.0) + iout = np.full(n, 1.0) # 8W output + + result = compute_efficiency(t, vin, iin, vout, iout) + assert result["efficiency_percent"] == pytest.approx(80.0, rel=0.01) + assert result["input_power_watts"] == pytest.approx(10.0, rel=0.01) + assert result["output_power_watts"] == pytest.approx(8.0, rel=0.01) + assert result["power_dissipated_watts"] == pytest.approx(2.0, rel=0.01) + + def test_zero_input_power(self): + """Zero input -> 0% efficiency (avoid division by zero).""" + n = 100 + t = np.linspace(0, 1.0, n) + zeros = np.zeros(n) + result = compute_efficiency(t, zeros, zeros, zeros, zeros) + assert result["efficiency_percent"] == 0.0 + + +class TestPowerFactor: + def test_dc_power_factor(self): + """DC signals (in phase) should have PF = 1.0.""" + n = 1000 + t = np.linspace(0, 1.0, n) + v = np.full(n, 5.0) + i = np.full(n, 2.0) + result = compute_power_metrics(t, v, i) + assert result["power_factor"] == pytest.approx(1.0, rel=0.01) + + def test_ac_in_phase_power_factor(self): + """In-phase AC should have PF ~ 1.0.""" + n = 10000 + t = np.linspace(0, 0.01, n, endpoint=False) + freq = 100.0 + v = np.sin(2 * np.pi * freq * t) + i = np.sin(2 * np.pi * freq * t) + result = compute_power_metrics(t, v, i) + assert result["power_factor"] == pytest.approx(1.0, rel=0.05) + + def test_ac_quadrature_power_factor(self): + """90-degree phase shift -> PF ~ 0.""" + n = 10000 + t = np.linspace(0, 0.01, n, endpoint=False) + freq = 100.0 + v = np.sin(2 * np.pi * freq * t) + i = np.cos(2 * np.pi * freq * t) + result = compute_power_metrics(t, v, i) + assert result["power_factor"] == pytest.approx(0.0, abs=0.05) + + def test_empty_signals(self): + result = compute_power_metrics(np.array([]), np.array([]), np.array([])) + assert result["power_factor"] == 0.0 + + +class TestInstantaneousPower: + def test_element_wise(self): + v = np.array([1.0, 2.0, 3.0]) + i = np.array([0.5, 1.0, 1.5]) + p = compute_instantaneous_power(v, i) + np.testing.assert_array_almost_equal(p, [0.5, 2.0, 4.5]) + + def test_complex_uses_real(self): + """Should use real parts only.""" + v = np.array([3.0 + 4j]) + i = np.array([2.0 + 1j]) + p = compute_instantaneous_power(v, i) + assert p[0] == pytest.approx(6.0) # 3 * 2 diff --git a/tests/test_raw_parser.py b/tests/test_raw_parser.py new file mode 100644 index 0000000..ffdd47d --- /dev/null +++ b/tests/test_raw_parser.py @@ -0,0 +1,101 @@ +"""Tests for raw_parser module: run boundaries, variable lookup, run slicing.""" + +import numpy as np +import pytest + +from mcp_ltspice.raw_parser import RawFile, Variable, _detect_run_boundaries + + +class TestDetectRunBoundaries: + def test_single_run(self): + """Monotonically increasing time -> single run starting at 0.""" + x = np.linspace(0, 1e-3, 500) + boundaries = _detect_run_boundaries(x) + assert boundaries == [0] + + def test_multi_run(self): + """Three runs: time resets to near-zero at each boundary.""" + run1 = np.linspace(0, 1e-3, 100) + run2 = np.linspace(0, 1e-3, 100) + run3 = np.linspace(0, 1e-3, 100) + x = np.concatenate([run1, run2, run3]) + boundaries = _detect_run_boundaries(x) + assert len(boundaries) == 3 + assert boundaries[0] == 0 + assert boundaries[1] == 100 + assert boundaries[2] == 200 + + def test_complex_ac(self): + """AC analysis with complex frequency axis that resets.""" + run1 = np.logspace(0, 6, 50).astype(np.complex128) + run2 = np.logspace(0, 6, 50).astype(np.complex128) + x = np.concatenate([run1, run2]) + boundaries = _detect_run_boundaries(x) + assert len(boundaries) == 2 + assert boundaries[0] == 0 + assert boundaries[1] == 50 + + def test_single_point(self): + """Single data point -> one run.""" + boundaries = _detect_run_boundaries(np.array([0.0])) + assert boundaries == [0] + + +class TestRawFileGetVariable: + def test_exact_match(self, mock_rawfile): + """Exact name match returns correct data.""" + result = mock_rawfile.get_variable("V(out)") + assert result is not None + assert len(result) == mock_rawfile.points + + def test_case_insensitive(self, mock_rawfile): + """Variable lookup is case-insensitive (partial match).""" + result = mock_rawfile.get_variable("v(out)") + assert result is not None + + def test_partial_match(self, mock_rawfile): + """Substring match should work: 'out' matches 'V(out)'.""" + result = mock_rawfile.get_variable("out") + assert result is not None + + def test_missing_variable(self, mock_rawfile): + """Non-existent variable returns None.""" + result = mock_rawfile.get_variable("V(nonexistent)") + assert result is None + + def test_get_time(self, mock_rawfile): + result = mock_rawfile.get_time() + assert result is not None + assert len(result) == mock_rawfile.points + + +class TestRawFileRunData: + def test_get_run_data_slicing(self, mock_rawfile_stepped): + """Extracting a single run produces correct point count.""" + run0 = mock_rawfile_stepped.get_run_data(0) + assert run0.points == 100 + assert run0.n_runs == 1 + assert run0.is_stepped is False + + def test_get_run_data_values(self, mock_rawfile_stepped): + """Each run has the expected amplitude scaling.""" + for i in range(3): + run = mock_rawfile_stepped.get_run_data(i) + sig = run.get_variable("V(out)") + # Peak amplitude should be approximately (i+1) + assert float(np.max(np.abs(sig))) == pytest.approx(i + 1, rel=0.1) + + def test_is_stepped(self, mock_rawfile_stepped, mock_rawfile): + assert mock_rawfile_stepped.is_stepped is True + assert mock_rawfile.is_stepped is False + + def test_get_variable_with_run(self, mock_rawfile_stepped): + """get_variable with run= parameter slices correctly.""" + v_run1 = mock_rawfile_stepped.get_variable("V(out)", run=1) + assert v_run1 is not None + assert len(v_run1) == 100 + + def test_non_stepped_get_run_data(self, mock_rawfile): + """Getting run data from non-stepped file returns self.""" + run = mock_rawfile.get_run_data(0) + assert run.points == mock_rawfile.points diff --git a/tests/test_stability.py b/tests/test_stability.py new file mode 100644 index 0000000..1d9679e --- /dev/null +++ b/tests/test_stability.py @@ -0,0 +1,108 @@ +"""Tests for stability module: gain margin, phase margin from loop gain data.""" + +import numpy as np +import pytest + +from mcp_ltspice.stability import ( + compute_gain_margin, + compute_phase_margin, + compute_stability_metrics, +) + + +def _second_order_system(freq, wn=1000.0, zeta=0.3): + """Create a 2nd-order underdamped system: H(s) = wn^2 / (s^2 + 2*zeta*wn*s + wn^2). + + Returns complex loop gain at the given frequencies. + """ + s = 1j * 2 * np.pi * freq + return wn**2 / (s**2 + 2 * zeta * wn * s + wn**2) + + +class TestGainMargin: + def test_third_order_system(self): + """A 3rd-order system crosses -180 phase and has finite gain margin.""" + freq = np.logspace(0, 6, 10000) + # Three-pole system: K / ((s/w1 + 1) * (s/w2 + 1) * (s/w3 + 1)) + # Phase goes from 0 to -270, so it definitely crosses -180 + w1 = 2 * np.pi * 100 + w2 = 2 * np.pi * 1000 + w3 = 2 * np.pi * 10000 + K = 100.0 # enough gain to have a gain crossover + s = 1j * 2 * np.pi * freq + loop_gain = K / ((s / w1 + 1) * (s / w2 + 1) * (s / w3 + 1)) + + result = compute_gain_margin(freq, loop_gain) + assert result["gain_margin_db"] is not None + assert result["is_stable"] is True + assert result["gain_margin_db"] > 0 + assert result["phase_crossover_freq_hz"] is not None + + def test_no_phase_crossover(self): + """A simple first-order system never reaches -180 phase, so GM is infinite.""" + freq = np.logspace(0, 6, 1000) + s = 1j * 2 * np.pi * freq + # First-order: 1/(1+s/wn) -- phase goes from 0 to -90 + loop_gain = 1.0 / (1 + s / (2 * np.pi * 1000)) + result = compute_gain_margin(freq, loop_gain) + assert result["gain_margin_db"] == float("inf") + assert result["is_stable"] is True + + def test_short_input(self): + result = compute_gain_margin(np.array([1.0]), np.array([1.0 + 0j])) + assert result["gain_margin_db"] is None + + +class TestPhaseMargin: + def test_third_order_system(self): + """A 3rd-order system with sufficient gain should have measurable phase margin.""" + freq = np.logspace(0, 6, 10000) + w1 = 2 * np.pi * 100 + w2 = 2 * np.pi * 1000 + w3 = 2 * np.pi * 10000 + K = 100.0 + s = 1j * 2 * np.pi * freq + loop_gain = K / ((s / w1 + 1) * (s / w2 + 1) * (s / w3 + 1)) + + result = compute_phase_margin(freq, loop_gain) + assert result["phase_margin_deg"] is not None + assert result["is_stable"] is True + assert result["phase_margin_deg"] > 0 + + def test_all_gain_below_0db(self): + """If gain is always below 0 dB, phase margin is infinite (system is stable).""" + freq = np.logspace(0, 6, 1000) + s = 1j * 2 * np.pi * freq + # Very low gain system + loop_gain = 0.001 / (1 + s / (2 * np.pi * 1000)) + result = compute_phase_margin(freq, loop_gain) + assert result["phase_margin_deg"] == float("inf") + assert result["is_stable"] is True + assert result["gain_crossover_freq_hz"] is None + + def test_short_input(self): + result = compute_phase_margin(np.array([1.0]), np.array([1.0 + 0j])) + assert result["phase_margin_deg"] is None + + +class TestStabilityMetrics: + def test_comprehensive_output(self): + """compute_stability_metrics returns all expected fields.""" + freq = np.logspace(0, 6, 5000) + w1 = 2 * np.pi * 100 + w2 = 2 * np.pi * 1000 + w3 = 2 * np.pi * 10000 + K = 100.0 + s = 1j * 2 * np.pi * freq + loop_gain = K / ((s / w1 + 1) * (s / w2 + 1) * (s / w3 + 1)) + + result = compute_stability_metrics(freq, loop_gain) + assert "gain_margin" in result + assert "phase_margin" in result + assert "bode" in result + assert "is_stable" in result + assert len(result["bode"]["frequency_hz"]) == len(freq) + + def test_short_input_structure(self): + result = compute_stability_metrics(np.array([]), np.array([])) + assert result["is_stable"] is None diff --git a/tests/test_touchstone.py b/tests/test_touchstone.py new file mode 100644 index 0000000..186c860 --- /dev/null +++ b/tests/test_touchstone.py @@ -0,0 +1,179 @@ +"""Tests for touchstone module: format conversion, parsing, S-parameter extraction.""" + +import re +from pathlib import Path + +import numpy as np +import pytest + +from mcp_ltspice.touchstone import ( + TouchstoneData, + _detect_ports, + _to_complex, + get_s_parameter, + parse_touchstone, + s_param_to_db, +) + + +class TestToComplex: + def test_ri_format(self): + """RI: (real, imag) -> complex.""" + c = _to_complex(3.0, 4.0, "RI") + assert c == complex(3.0, 4.0) + + def test_ma_format(self): + """MA: (magnitude, angle_deg) -> complex.""" + c = _to_complex(1.0, 0.0, "MA") + assert c == pytest.approx(complex(1.0, 0.0), abs=1e-10) + + c90 = _to_complex(1.0, 90.0, "MA") + assert c90.real == pytest.approx(0.0, abs=1e-10) + assert c90.imag == pytest.approx(1.0, abs=1e-10) + + def test_db_format(self): + """DB: (mag_db, angle_deg) -> complex.""" + # 0 dB = magnitude 1.0 + c = _to_complex(0.0, 0.0, "DB") + assert abs(c) == pytest.approx(1.0) + + # 20 dB = magnitude 10.0 + c20 = _to_complex(20.0, 0.0, "DB") + assert abs(c20) == pytest.approx(10.0, rel=0.01) + + def test_unknown_format_raises(self): + with pytest.raises(ValueError, match="Unknown format"): + _to_complex(1.0, 0.0, "XY") + + +class TestDetectPorts: + @pytest.mark.parametrize( + "suffix, expected", + [ + (".s1p", 1), + (".s2p", 2), + (".s3p", 3), + (".s4p", 4), + (".S2P", 2), # case insensitive + ], + ) + def test_valid_extensions(self, suffix, expected): + p = Path(f"test{suffix}") + assert _detect_ports(p) == expected + + def test_invalid_extension(self): + with pytest.raises(ValueError, match="Cannot determine port count"): + _detect_ports(Path("test.txt")) + + +class TestParseTouchstone: + def test_parse_s2p(self, tmp_s2p_file): + """Parse a synthetic .s2p file and verify structure.""" + data = parse_touchstone(tmp_s2p_file) + assert data.n_ports == 2 + assert data.parameter_type == "S" + assert data.format_type == "MA" + assert data.reference_impedance == 50.0 + assert len(data.frequencies) == 3 + assert data.data.shape == (3, 2, 2) + + def test_frequencies_in_hz(self, tmp_s2p_file): + """Frequencies should be converted to Hz (from GHz).""" + data = parse_touchstone(tmp_s2p_file) + # First freq is 1.0 GHz = 1e9 Hz + assert data.frequencies[0] == pytest.approx(1e9) + assert data.frequencies[1] == pytest.approx(2e9) + assert data.frequencies[2] == pytest.approx(3e9) + + def test_s11_values(self, tmp_s2p_file): + """S11 at first frequency should match input: mag=0.5, angle=-30.""" + data = parse_touchstone(tmp_s2p_file) + s11 = data.data[0, 0, 0] + assert abs(s11) == pytest.approx(0.5, rel=0.01) + assert np.degrees(np.angle(s11)) == pytest.approx(-30.0, abs=1.0) + + def test_comments_parsed(self, tmp_s2p_file): + data = parse_touchstone(tmp_s2p_file) + assert len(data.comments) > 0 + + def test_s1p_file(self, tmp_path): + """Parse a minimal .s1p file.""" + content = ( + "# MHZ S RI R 50\n" + "100 0.5 0.3\n" + "200 0.4 0.2\n" + ) + p = tmp_path / "test.s1p" + p.write_text(content) + data = parse_touchstone(p) + assert data.n_ports == 1 + assert data.data.shape == (2, 1, 1) + # 100 MHz = 100e6 Hz + assert data.frequencies[0] == pytest.approx(100e6) + + def test_db_format_file(self, tmp_path): + """Parse a .s1p file in DB format.""" + content = ( + "# GHZ S DB R 50\n" + "1.0 -3.0 -45\n" + "2.0 -6.0 -90\n" + ) + p = tmp_path / "dbtest.s1p" + p.write_text(content) + data = parse_touchstone(p) + assert data.format_type == "DB" + # -3 dB -> magnitude ~ 0.707 + assert abs(data.data[0, 0, 0]) == pytest.approx(10 ** (-3.0 / 20.0), rel=0.01) + + +class TestSParamToDb: + def test_unity_magnitude(self): + """Magnitude 1.0 -> 0 dB.""" + vals = np.array([1.0 + 0j]) + db = s_param_to_db(vals) + assert db[0] == pytest.approx(0.0, abs=0.01) + + def test_known_magnitude(self): + """Magnitude 0.1 -> -20 dB.""" + vals = np.array([0.1 + 0j]) + db = s_param_to_db(vals) + assert db[0] == pytest.approx(-20.0, abs=0.1) + + def test_zero_magnitude(self): + """Zero magnitude should not produce -inf (floored).""" + vals = np.array([0.0 + 0j]) + db = s_param_to_db(vals) + assert np.isfinite(db[0]) + assert db[0] < -200 + + +class TestGetSParameter: + def test_1_based_indexing(self, tmp_s2p_file): + data = parse_touchstone(tmp_s2p_file) + freqs, vals = get_s_parameter(data, 1, 1) + assert len(freqs) == 3 + assert len(vals) == 3 + + def test_s21(self, tmp_s2p_file): + data = parse_touchstone(tmp_s2p_file) + # S21 is stored at (row=0, col=1) -> get_s_parameter(data, 1, 2) + # because the parser iterates row then col, and Touchstone 2-port + # order is S11, S21, S12, S22 -> (0,0), (0,1), (1,0), (1,1) + freqs, vals = get_s_parameter(data, 1, 2) + # S21 at first freq: mag=0.9, angle=-10 + assert abs(vals[0]) == pytest.approx(0.9, rel=0.01) + + def test_out_of_range_row(self, tmp_s2p_file): + data = parse_touchstone(tmp_s2p_file) + with pytest.raises(IndexError, match="Row index"): + get_s_parameter(data, 3, 1) + + def test_out_of_range_col(self, tmp_s2p_file): + data = parse_touchstone(tmp_s2p_file) + with pytest.raises(IndexError, match="Column index"): + get_s_parameter(data, 1, 3) + + def test_zero_index_raises(self, tmp_s2p_file): + data = parse_touchstone(tmp_s2p_file) + with pytest.raises(IndexError): + get_s_parameter(data, 0, 1) diff --git a/tests/test_waveform_expr.py b/tests/test_waveform_expr.py new file mode 100644 index 0000000..db4fe1f --- /dev/null +++ b/tests/test_waveform_expr.py @@ -0,0 +1,167 @@ +"""Tests for waveform_expr module: tokenizer, parser, expression evaluator.""" + +import numpy as np +import pytest + +from mcp_ltspice.waveform_expr import ( + WaveformCalculator, + _Token, + _tokenize, + _TokenType, + evaluate_expression, +) + + +# --------------------------------------------------------------------------- +# Tokenizer tests +# --------------------------------------------------------------------------- + + +class TestTokenizer: + def test_number_tokens(self): + tokens = _tokenize("42 3.14 1e-3") + nums = [t for t in tokens if t.type == _TokenType.NUMBER] + assert len(nums) == 3 + assert nums[0].value == "42" + assert nums[1].value == "3.14" + assert nums[2].value == "1e-3" + + def test_signal_tokens(self): + tokens = _tokenize("V(out) + I(R1)") + signals = [t for t in tokens if t.type == _TokenType.SIGNAL] + assert len(signals) == 2 + assert signals[0].value == "V(out)" + assert signals[1].value == "I(R1)" + + def test_operator_tokens(self): + tokens = _tokenize("1 + 2 - 3 * 4 / 5") + ops = [t for t in tokens if t.type not in (_TokenType.NUMBER, _TokenType.EOF)] + types = [t.type for t in ops] + assert types == [_TokenType.PLUS, _TokenType.MINUS, _TokenType.STAR, _TokenType.SLASH] + + def test_function_tokens(self): + tokens = _tokenize("abs(V(out))") + funcs = [t for t in tokens if t.type == _TokenType.FUNC] + assert len(funcs) == 1 + assert funcs[0].value == "abs" + + def test_case_insensitive_functions(self): + tokens = _tokenize("dB(V(out))") + funcs = [t for t in tokens if t.type == _TokenType.FUNC] + assert funcs[0].value == "db" + + def test_bare_identifier(self): + tokens = _tokenize("time + 1") + signals = [t for t in tokens if t.type == _TokenType.SIGNAL] + assert len(signals) == 1 + assert signals[0].value == "time" + + def test_invalid_character_raises(self): + with pytest.raises(ValueError, match="Unexpected character"): + _tokenize("V(out) @ 2") + + def test_eof_token(self): + tokens = _tokenize("1") + assert tokens[-1].type == _TokenType.EOF + + +# --------------------------------------------------------------------------- +# Expression evaluator tests (scalar via numpy scalars) +# --------------------------------------------------------------------------- + + +class TestEvaluateExpression: + def test_addition(self): + result = evaluate_expression("2 + 3", {}) + assert float(result) == pytest.approx(5.0) + + def test_multiplication(self): + result = evaluate_expression("4 * 5", {}) + assert float(result) == pytest.approx(20.0) + + def test_precedence(self): + """Multiplication binds tighter than addition: 2+3*4=14.""" + result = evaluate_expression("2 + 3 * 4", {}) + assert float(result) == pytest.approx(14.0) + + def test_unary_minus(self): + result = evaluate_expression("-5 + 3", {}) + assert float(result) == pytest.approx(-2.0) + + def test_nested_parens(self): + result = evaluate_expression("(2 + 3) * (4 - 1)", {}) + assert float(result) == pytest.approx(15.0) + + def test_division_by_near_zero(self): + """Division by near-zero uses a safe floor to avoid inf.""" + result = evaluate_expression("1 / 0", {}) + # Should return a very large number, not inf + assert np.isfinite(result) + + def test_db_function(self): + """dB(x) = 20 * log10(|x|).""" + result = evaluate_expression("db(10)", {}) + assert float(result) == pytest.approx(20.0, rel=0.01) + + def test_abs_function(self): + result = evaluate_expression("abs(-7)", {}) + assert float(result) == pytest.approx(7.0) + + def test_sqrt_function(self): + result = evaluate_expression("sqrt(16)", {}) + assert float(result) == pytest.approx(4.0) + + def test_log10_function(self): + result = evaluate_expression("log10(1000)", {}) + assert float(result) == pytest.approx(3.0) + + def test_signal_lookup(self): + """Expression referencing a variable by name.""" + variables = {"V(out)": np.array([1.0, 2.0, 3.0])} + result = evaluate_expression("V(out) * 2", variables) + np.testing.assert_array_almost_equal(result, [2.0, 4.0, 6.0]) + + def test_unknown_signal_raises(self): + with pytest.raises(ValueError, match="Unknown signal"): + evaluate_expression("V(missing)", {"V(out)": np.array([1.0])}) + + def test_unknown_function_raises(self): + # 'sin' is not in the supported function set -- the tokenizer treats + # it as a signal name "sin(1)", so the error is "Unknown signal" + with pytest.raises(ValueError, match="Unknown signal"): + evaluate_expression("sin(1)", {}) + + def test_malformed_expression(self): + with pytest.raises(ValueError): + evaluate_expression("2 +", {}) + + def test_case_insensitive_signal(self): + """Signal lookup is case-insensitive.""" + variables = {"V(OUT)": np.array([10.0])} + result = evaluate_expression("V(out)", variables) + np.testing.assert_array_almost_equal(result, [10.0]) + + +# --------------------------------------------------------------------------- +# WaveformCalculator tests +# --------------------------------------------------------------------------- + + +class TestWaveformCalculator: + def test_calc_available_signals(self, mock_rawfile): + calc = WaveformCalculator(mock_rawfile) + signals = calc.available_signals() + assert "time" in signals + assert "V(out)" in signals + + def test_calc_expression(self, mock_rawfile): + calc = WaveformCalculator(mock_rawfile) + result = calc.calc("V(out) * 2") + expected = np.real(mock_rawfile.data[1]) * 2 + np.testing.assert_array_almost_equal(result, expected) + + def test_calc_db(self, mock_rawfile): + calc = WaveformCalculator(mock_rawfile) + result = calc.calc("db(V(out))") + # db should produce real values + assert np.all(np.isfinite(result)) diff --git a/tests/test_waveform_math.py b/tests/test_waveform_math.py new file mode 100644 index 0000000..c021999 --- /dev/null +++ b/tests/test_waveform_math.py @@ -0,0 +1,184 @@ +"""Tests for waveform_math module: RMS, peak-to-peak, FFT, THD, bandwidth, settling, rise time.""" + +import numpy as np +import pytest + +from mcp_ltspice.waveform_math import ( + compute_bandwidth, + compute_fft, + compute_peak_to_peak, + compute_rise_time, + compute_rms, + compute_settling_time, + compute_thd, +) + + +class TestComputeRms: + def test_dc_signal_exact(self, dc_signal): + """RMS of a DC signal equals its DC value.""" + rms = compute_rms(dc_signal) + assert rms == pytest.approx(3.3, abs=1e-10) + + def test_sine_vpk_over_sqrt2(self, sine_1khz): + """RMS of a pure sine (1 V peak) should be 1/sqrt(2).""" + rms = compute_rms(sine_1khz) + assert rms == pytest.approx(1.0 / np.sqrt(2), rel=0.01) + + def test_empty_signal(self): + """RMS of an empty array is 0.""" + assert compute_rms(np.array([])) == 0.0 + + def test_single_sample(self): + """RMS of a single sample equals abs(sample).""" + assert compute_rms(np.array([5.0])) == pytest.approx(5.0) + + def test_complex_signal(self): + """RMS uses only the real part of complex data.""" + sig = np.array([3.0 + 4j, 3.0 + 4j]) + rms = compute_rms(sig) + assert rms == pytest.approx(3.0) + + +class TestComputePeakToPeak: + def test_sine_wave(self, sine_1khz): + """Peak-to-peak of a 1 V peak sine should be ~2 V.""" + result = compute_peak_to_peak(sine_1khz) + assert result["peak_to_peak"] == pytest.approx(2.0, rel=0.01) + assert result["max"] == pytest.approx(1.0, rel=0.01) + assert result["min"] == pytest.approx(-1.0, rel=0.01) + assert result["mean"] == pytest.approx(0.0, abs=0.01) + + def test_dc_signal(self, dc_signal): + """Peak-to-peak of a DC signal is 0.""" + result = compute_peak_to_peak(dc_signal) + assert result["peak_to_peak"] == pytest.approx(0.0) + + def test_empty_signal(self): + result = compute_peak_to_peak(np.array([])) + assert result["peak_to_peak"] == 0.0 + + +class TestComputeFft: + def test_known_sine_peak_at_correct_freq(self, time_array, sine_1khz): + """A 1 kHz sine should produce a dominant peak at 1 kHz.""" + result = compute_fft(time_array, sine_1khz) + assert result["fundamental_freq"] == pytest.approx(1000, rel=0.05) + assert result["dc_offset"] == pytest.approx(0.0, abs=0.01) + + def test_dc_offset_detection(self, time_array): + """A signal with DC offset should report correct dc_offset.""" + offset = 2.5 + sig = offset + np.sin(2 * np.pi * 1000 * time_array) + result = compute_fft(time_array, sig) + assert result["dc_offset"] == pytest.approx(offset, rel=0.05) + + def test_short_signal(self): + """Very short signals return empty results.""" + result = compute_fft(np.array([0.0]), np.array([1.0])) + assert result["frequencies"] == [] + assert result["fundamental_freq"] == 0.0 + + def test_zero_dt(self): + """Time array with zero duration returns gracefully.""" + result = compute_fft(np.array([1.0, 1.0]), np.array([1.0, 2.0])) + assert result["frequencies"] == [] + + +class TestComputeThd: + def test_pure_sine_low_thd(self, time_array, sine_1khz): + """A pure sine wave should have very low THD.""" + result = compute_thd(time_array, sine_1khz) + assert result["thd_percent"] < 1.0 + assert result["fundamental_freq"] == pytest.approx(1000, rel=0.05) + + def test_clipped_sine_high_thd(self, time_array, sine_1khz): + """A hard-clipped sine should have significantly higher THD.""" + clipped = np.clip(sine_1khz, -0.5, 0.5) + result = compute_thd(time_array, clipped) + # Clipping at 50% introduces substantial harmonics + assert result["thd_percent"] > 10.0 + + def test_short_signal(self): + result = compute_thd(np.array([0.0]), np.array([1.0])) + assert result["thd_percent"] == 0.0 + + +class TestComputeBandwidth: + def test_lowpass_cutoff(self, ac_frequency, lowpass_response): + """Lowpass with fc=1kHz should report bandwidth near 1 kHz.""" + result = compute_bandwidth(ac_frequency, lowpass_response) + assert result["bandwidth_hz"] == pytest.approx(1000, rel=0.1) + assert result["type"] == "lowpass" + + def test_all_above_cutoff(self, ac_frequency): + """If all magnitudes are above -3dB level, bandwidth spans entire range.""" + flat = np.zeros_like(ac_frequency) + result = compute_bandwidth(ac_frequency, flat) + assert result["bandwidth_hz"] > 0 + + def test_short_input(self): + result = compute_bandwidth(np.array([1.0]), np.array([0.0])) + assert result["bandwidth_hz"] == 0.0 + + def test_bandpass_shape(self): + """A peaked response should be detected as bandpass.""" + fc = 10_000.0 + Q = 5.0 # Q factor => BW = fc/Q = 2000 Hz + bw_expected = fc / Q + freq = np.logspace(2, 6, 2000) + # Second-order bandpass: H(s) = (s/wn/Q) / (s^2/wn^2 + s/wn/Q + 1) + wn = 2 * np.pi * fc + s = 1j * 2 * np.pi * freq + H = (s / wn / Q) / (s**2 / wn**2 + s / wn / Q + 1) + mag_db = 20.0 * np.log10(np.abs(H)) + result = compute_bandwidth(freq, mag_db) + assert result["type"] == "bandpass" + assert result["bandwidth_hz"] == pytest.approx(bw_expected, rel=0.15) + + +class TestComputeSettlingTime: + def test_already_settled(self, time_array, dc_signal): + """A constant signal is already settled at t=0.""" + t = np.linspace(0, 0.01, len(dc_signal)) + result = compute_settling_time(t, dc_signal, final_value=3.3) + assert result["settled"] is True + assert result["settling_time"] == 0.0 + + def test_step_response(self, time_array, step_signal): + """Step response should settle after the transient.""" + result = compute_settling_time(time_array, step_signal, final_value=1.0) + assert result["settled"] is True + assert result["settling_time"] > 0 + + def test_never_settles(self, time_array, sine_1khz): + """An oscillating signal never settles to a DC value.""" + result = compute_settling_time(time_array, sine_1khz, final_value=0.5) + assert result["settled"] is False + + def test_short_signal(self): + result = compute_settling_time(np.array([0.0]), np.array([1.0])) + assert result["settled"] is False + + +class TestComputeRiseTime: + def test_fast_step(self): + """A fast rising step should have a short rise time.""" + t = np.linspace(0, 1e-3, 10000) + # Step with very fast exponential rise + sig = np.where(t > 0.1e-3, 1.0 - np.exp(-(t - 0.1e-3) / 20e-6), 0.0) + result = compute_rise_time(t, sig) + assert result["rise_time"] > 0 + # 10-90% rise time of RC = ~2.2 * tau + assert result["rise_time"] == pytest.approx(2.2 * 20e-6, rel=0.2) + + def test_no_swing(self): + """Flat signal has zero rise time.""" + t = np.linspace(0, 1, 100) + sig = np.ones(100) * 5.0 + result = compute_rise_time(t, sig) + assert result["rise_time"] == 0.0 + + def test_short_signal(self): + result = compute_rise_time(np.array([0.0]), np.array([0.0])) + assert result["rise_time"] == 0.0 diff --git a/uv.lock b/uv.lock index aa7d71e..e3070ef 100644 --- a/uv.lock +++ b/uv.lock @@ -1063,6 +1063,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, ] plot = [ @@ -1075,6 +1076,7 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'plot'", specifier = ">=3.7.0" }, { name = "numpy", specifier = ">=1.24.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, ] provides-extras = ["dev", "plot"] @@ -1602,6 +1604,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"