# mcltspice/tests/test_waveform_expr.py
# 2026-02-10 23:35:53 -07:00
#
# 168 lines
# 6.0 KiB
# Python

"""Tests for waveform_expr module: tokenizer, parser, expression evaluator."""
import numpy as np
import pytest
from mcp_ltspice.waveform_expr import (
WaveformCalculator,
_Token,
_tokenize,
_TokenType,
evaluate_expression,
)
# ---------------------------------------------------------------------------
# Tokenizer tests
# ---------------------------------------------------------------------------
class TestTokenizer:
    """Exercise ``_tokenize``: the token types and values it emits."""

    def test_number_tokens(self):
        """Integer, decimal and exponent literals yield three NUMBER tokens."""
        number_values = [
            tok.value
            for tok in _tokenize("42 3.14 1e-3")
            if tok.type == _TokenType.NUMBER
        ]
        assert number_values == ["42", "3.14", "1e-3"]

    def test_signal_tokens(self):
        """``V(...)`` and ``I(...)`` references are each one SIGNAL token."""
        signal_values = [
            tok.value
            for tok in _tokenize("V(out) + I(R1)")
            if tok.type == _TokenType.SIGNAL
        ]
        assert signal_values == ["V(out)", "I(R1)"]

    def test_operator_tokens(self):
        """The four arithmetic operators map to their own token types, in order."""
        # Everything that is neither a number nor the trailing EOF marker
        # must be one of the operators under test.
        not_operators = (_TokenType.NUMBER, _TokenType.EOF)
        operator_types = [
            tok.type
            for tok in _tokenize("1 + 2 - 3 * 4 / 5")
            if tok.type not in not_operators
        ]
        assert operator_types == [
            _TokenType.PLUS,
            _TokenType.MINUS,
            _TokenType.STAR,
            _TokenType.SLASH,
        ]

    def test_function_tokens(self):
        """A supported function name produces exactly one FUNC token."""
        func_names = [
            tok.value
            for tok in _tokenize("abs(V(out))")
            if tok.type == _TokenType.FUNC
        ]
        assert func_names == ["abs"]

    def test_case_insensitive_functions(self):
        """A mixed-case function name is reported in lower case."""
        func_names = [
            tok.value
            for tok in _tokenize("dB(V(out))")
            if tok.type == _TokenType.FUNC
        ]
        assert func_names[0] == "db"

    def test_bare_identifier(self):
        """An identifier without parentheses is tokenized as a SIGNAL."""
        signal_values = [
            tok.value
            for tok in _tokenize("time + 1")
            if tok.type == _TokenType.SIGNAL
        ]
        assert signal_values == ["time"]

    def test_invalid_character_raises(self):
        """An unsupported character is rejected with a ValueError."""
        expected_failure = pytest.raises(ValueError, match="Unexpected character")
        with expected_failure:
            _tokenize("V(out) @ 2")

    def test_eof_token(self):
        """The token stream always ends with an EOF token."""
        last_token = _tokenize("1")[-1]
        assert last_token.type == _TokenType.EOF
# ---------------------------------------------------------------------------
# Expression evaluator tests (scalar via numpy scalars)
# ---------------------------------------------------------------------------
class TestEvaluateExpression:
    """Exercise ``evaluate_expression`` on constants and bound signals."""

    @staticmethod
    def _scalar(expression):
        """Evaluate *expression* with no signals bound and coerce it to float."""
        return float(evaluate_expression(expression, {}))

    def test_addition(self):
        """2 + 3 evaluates to 5."""
        assert self._scalar("2 + 3") == pytest.approx(5.0)

    def test_multiplication(self):
        """4 * 5 evaluates to 20."""
        assert self._scalar("4 * 5") == pytest.approx(20.0)

    def test_precedence(self):
        """Multiplication binds tighter than addition: 2+3*4=14."""
        assert self._scalar("2 + 3 * 4") == pytest.approx(14.0)

    def test_unary_minus(self):
        """A leading minus negates its operand: -5 + 3 = -2."""
        assert self._scalar("-5 + 3") == pytest.approx(-2.0)

    def test_nested_parens(self):
        """Parenthesised groups are evaluated first: (2+3)*(4-1)=15."""
        assert self._scalar("(2 + 3) * (4 - 1)") == pytest.approx(15.0)

    def test_division_by_near_zero(self):
        """Dividing by zero must yield a finite value rather than inf."""
        quotient = evaluate_expression("1 / 0", {})
        # Only finiteness is checked; the exact magnitude is left to the
        # implementation.
        assert np.isfinite(quotient)

    def test_db_function(self):
        """db(10) is 20, i.e. 20 * log10(|10|), within 1% relative tolerance."""
        assert self._scalar("db(10)") == pytest.approx(20.0, rel=0.01)

    def test_abs_function(self):
        """abs(-7) evaluates to 7."""
        assert self._scalar("abs(-7)") == pytest.approx(7.0)

    def test_sqrt_function(self):
        """sqrt(16) evaluates to 4."""
        assert self._scalar("sqrt(16)") == pytest.approx(4.0)

    def test_log10_function(self):
        """log10(1000) evaluates to 3."""
        assert self._scalar("log10(1000)") == pytest.approx(3.0)

    def test_signal_lookup(self):
        """A signal named in the expression is read from the variables dict."""
        bound_signals = {"V(out)": np.array([1.0, 2.0, 3.0])}
        doubled = evaluate_expression("V(out) * 2", bound_signals)
        np.testing.assert_array_almost_equal(doubled, [2.0, 4.0, 6.0])

    def test_unknown_signal_raises(self):
        """Referencing a signal that is not bound raises ValueError."""
        bound_signals = {"V(out)": np.array([1.0])}
        with pytest.raises(ValueError, match="Unknown signal"):
            evaluate_expression("V(missing)", bound_signals)

    def test_unknown_function_raises(self):
        """An unsupported function name surfaces as an unknown-signal error."""
        # 'sin' is outside the supported function set, so the tokenizer reads
        # the whole "sin(1)" as a signal name; the resulting error is therefore
        # "Unknown signal" rather than a function-specific message.
        with pytest.raises(ValueError, match="Unknown signal"):
            evaluate_expression("sin(1)", {})

    def test_malformed_expression(self):
        """An expression ending in a dangling operator raises ValueError."""
        with pytest.raises(ValueError):
            evaluate_expression("2 +", {})

    def test_case_insensitive_signal(self):
        """Signal lookup is case-insensitive: "V(out)" matches key "V(OUT)"."""
        bound_signals = {"V(OUT)": np.array([10.0])}
        looked_up = evaluate_expression("V(out)", bound_signals)
        np.testing.assert_array_almost_equal(looked_up, [10.0])
# ---------------------------------------------------------------------------
# WaveformCalculator tests
# ---------------------------------------------------------------------------
class TestWaveformCalculator:
    """Exercise ``WaveformCalculator`` against the ``mock_rawfile`` fixture.

    The fixture is defined outside this file; these tests only rely on it
    exposing "time" and "V(out)" signals and a ``data`` sequence.
    """

    def test_calc_available_signals(self, mock_rawfile):
        """Both "time" and "V(out)" are listed as available signals."""
        listed = WaveformCalculator(mock_rawfile).available_signals()
        for expected_name in ("time", "V(out)"):
            assert expected_name in listed

    def test_calc_expression(self, mock_rawfile):
        """``V(out) * 2`` equals twice the real part of ``data[1]``."""
        doubled = WaveformCalculator(mock_rawfile).calc("V(out) * 2")
        # NOTE(review): data[1] is presumably the V(out) trace in the
        # fixture -- confirm against the fixture definition.
        reference = np.real(mock_rawfile.data[1]) * 2
        np.testing.assert_array_almost_equal(doubled, reference)

    def test_calc_db(self, mock_rawfile):
        """``db(V(out))`` yields only finite values for the fixture data."""
        in_decibels = WaveformCalculator(mock_rawfile).calc("db(V(out))")
        # Finiteness is the only property asserted here.
        assert np.all(np.isfinite(in_decibels))