"""Tests for waveform_expr module: tokenizer, parser, expression evaluator.""" import numpy as np import pytest from mcp_ltspice.waveform_expr import ( WaveformCalculator, _tokenize, _TokenType, evaluate_expression, ) # --------------------------------------------------------------------------- # Tokenizer tests # --------------------------------------------------------------------------- class TestTokenizer: def test_number_tokens(self): tokens = _tokenize("42 3.14 1e-3") nums = [t for t in tokens if t.type == _TokenType.NUMBER] assert len(nums) == 3 assert nums[0].value == "42" assert nums[1].value == "3.14" assert nums[2].value == "1e-3" def test_signal_tokens(self): tokens = _tokenize("V(out) + I(R1)") signals = [t for t in tokens if t.type == _TokenType.SIGNAL] assert len(signals) == 2 assert signals[0].value == "V(out)" assert signals[1].value == "I(R1)" def test_operator_tokens(self): tokens = _tokenize("1 + 2 - 3 * 4 / 5") ops = [t for t in tokens if t.type not in (_TokenType.NUMBER, _TokenType.EOF)] types = [t.type for t in ops] assert types == [_TokenType.PLUS, _TokenType.MINUS, _TokenType.STAR, _TokenType.SLASH] def test_function_tokens(self): tokens = _tokenize("abs(V(out))") funcs = [t for t in tokens if t.type == _TokenType.FUNC] assert len(funcs) == 1 assert funcs[0].value == "abs" def test_case_insensitive_functions(self): tokens = _tokenize("dB(V(out))") funcs = [t for t in tokens if t.type == _TokenType.FUNC] assert funcs[0].value == "db" def test_bare_identifier(self): tokens = _tokenize("time + 1") signals = [t for t in tokens if t.type == _TokenType.SIGNAL] assert len(signals) == 1 assert signals[0].value == "time" def test_invalid_character_raises(self): with pytest.raises(ValueError, match="Unexpected character"): _tokenize("V(out) @ 2") def test_eof_token(self): tokens = _tokenize("1") assert tokens[-1].type == _TokenType.EOF # --------------------------------------------------------------------------- # Expression evaluator tests (scalar 
# via numpy scalars)
# ---------------------------------------------------------------------------


class TestEvaluateExpression:
    """Exercise ``evaluate_expression`` on literals, functions and signals."""

    def test_addition(self):
        """``2 + 3`` evaluates to 5."""
        total = evaluate_expression("2 + 3", {})
        assert float(total) == pytest.approx(5.0)

    def test_multiplication(self):
        """``4 * 5`` evaluates to 20."""
        product = evaluate_expression("4 * 5", {})
        assert float(product) == pytest.approx(20.0)

    def test_precedence(self):
        """``*`` is applied before ``+``, so ``2 + 3 * 4`` is 14."""
        value = evaluate_expression("2 + 3 * 4", {})
        assert float(value) == pytest.approx(14.0)

    def test_unary_minus(self):
        """A leading minus negates its operand: ``-5 + 3`` is -2."""
        value = evaluate_expression("-5 + 3", {})
        assert float(value) == pytest.approx(-2.0)

    def test_nested_parens(self):
        """Parenthesised groups are evaluated before the outer operator."""
        value = evaluate_expression("(2 + 3) * (4 - 1)", {})
        assert float(value) == pytest.approx(15.0)

    def test_division_by_near_zero(self):
        """Dividing by zero is expected to yield a finite value, not inf."""
        quotient = evaluate_expression("1 / 0", {})
        # A very large number is acceptable here; inf or nan is not.
        assert np.isfinite(quotient)

    def test_db_function(self):
        """``db(10)`` is 20, i.e. 20 * log10(|10|), within 1 %."""
        decibels = evaluate_expression("db(10)", {})
        assert float(decibels) == pytest.approx(20.0, rel=0.01)

    def test_abs_function(self):
        """``abs`` returns the magnitude of a negative literal."""
        magnitude = evaluate_expression("abs(-7)", {})
        assert float(magnitude) == pytest.approx(7.0)

    def test_sqrt_function(self):
        """``sqrt(16)`` evaluates to 4."""
        root = evaluate_expression("sqrt(16)", {})
        assert float(root) == pytest.approx(4.0)

    def test_log10_function(self):
        """``log10(1000)`` evaluates to 3."""
        exponent = evaluate_expression("log10(1000)", {})
        assert float(exponent) == pytest.approx(3.0)

    def test_signal_lookup(self):
        """A signal name is resolved from the variables mapping, element-wise."""
        traces = {"V(out)": np.array([1.0, 2.0, 3.0])}
        doubled = evaluate_expression("V(out) * 2", traces)
        np.testing.assert_array_almost_equal(doubled, [2.0, 4.0, 6.0])

    def test_unknown_signal_raises(self):
        """Referencing a signal absent from the mapping raises ``ValueError``."""
        traces = {"V(out)": np.array([1.0])}
        with pytest.raises(ValueError, match="Unknown signal"):
            evaluate_expression("V(missing)", traces)

    def test_unknown_function_raises(self):
        """An unsupported function name is reported as an unknown signal."""
        # 'sin' is not in the supported function set -- the tokenizer treats
        # it as a signal name "sin(1)", so the error is "Unknown signal"
        with pytest.raises(ValueError, match="Unknown signal"):
            evaluate_expression("sin(1)", {})

    def test_malformed_expression(self):
        """An expression ending in a dangling operator raises ``ValueError``."""
        with pytest.raises(ValueError):
            evaluate_expression("2 +", {})

    def test_case_insensitive_signal(self):
        """Signal names are matched without regard to letter case."""
        traces = {"V(OUT)": np.array([10.0])}
        looked_up = evaluate_expression("V(out)", traces)
        np.testing.assert_array_almost_equal(looked_up, [10.0])


# ---------------------------------------------------------------------------
# WaveformCalculator tests
# ---------------------------------------------------------------------------


class TestWaveformCalculator:
    """Exercise ``WaveformCalculator`` against the ``mock_rawfile`` fixture."""

    def test_calc_available_signals(self, mock_rawfile):
        """Both ``time`` and ``V(out)`` are listed as available signals."""
        names = WaveformCalculator(mock_rawfile).available_signals()
        assert "time" in names
        assert "V(out)" in names

    def test_calc_expression(self, mock_rawfile):
        """``V(out) * 2`` matches twice the real part of the fixture's data row 1."""
        calculator = WaveformCalculator(mock_rawfile)
        computed = calculator.calc("V(out) * 2")
        # presumably data[1] is the V(out) trace -- defined by the fixture
        reference = np.real(mock_rawfile.data[1]) * 2
        np.testing.assert_array_almost_equal(computed, reference)

    def test_calc_db(self, mock_rawfile):
        """``db(V(out))`` produces only finite values."""
        calculator = WaveformCalculator(mock_rawfile)
        in_decibels = calculator.calc("db(V(out))")
        # db should produce real values
        assert np.all(np.isfinite(in_decibels))