Skip to content

Commit 89a2274

Browse files
Add expression component (#143)
* expression component and tests * Add check of incorrect function * 100% test coverage?
1 parent e072df7 commit 89a2274

File tree

7 files changed

+1068
-512
lines changed

7 files changed

+1068
-512
lines changed

docs/docs/tutorials/components.ipynb

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"from easydynamics.sample_model import DampedHarmonicOscillator\n",
2727
"from easydynamics.sample_model import DeltaFunction\n",
2828
"from easydynamics.sample_model import Exponential\n",
29+
"from easydynamics.sample_model import ExpressionComponent\n",
2930
"from easydynamics.sample_model import Gaussian\n",
3031
"from easydynamics.sample_model import Lorentzian\n",
3132
"from easydynamics.sample_model import Polynomial\n",
@@ -123,11 +124,37 @@
123124
"plt.legend()\n",
124125
"plt.show()"
125126
]
127+
},
128+
{
129+
"cell_type": "code",
130+
"execution_count": null,
131+
"id": "9a113170",
132+
"metadata": {},
133+
"outputs": [],
134+
"source": [
135+
"expr = ExpressionComponent(\n",
136+
" 'A * exp(-(x - x0)**2 / (2*sigma**2)) +B*sin(2*pi*x/period)',\n",
137+
" parameters={'A': 10, 'x0': 0, 'sigma': 1},\n",
138+
")\n",
139+
"\n",
140+
"expr.A = 5\n",
141+
"expr.sigma = 0.5\n",
142+
"\n",
143+
"expr.period = 2.0\n",
144+
"\n",
145+
"x = np.linspace(-5, 5, 100)\n",
146+
"y = expr.evaluate(x)\n",
147+
"\n",
148+
"plt.figure()\n",
149+
"plt.plot(x, y, label='Expression Component')\n",
150+
"plt.legend()\n",
151+
"plt.show()"
152+
]
126153
}
127154
],
128155
"metadata": {
129156
"kernelspec": {
130-
"display_name": "easydynamics_newbase",
157+
"display_name": "Python 3",
131158
"language": "python",
132159
"name": "python3"
133160
},
@@ -141,7 +168,7 @@
141168
"name": "python",
142169
"nbconvert_exporter": "python",
143170
"pygments_lexer": "ipython3",
144-
"version": "3.12.12"
171+
"version": "3.12.13"
145172
}
146173
},
147174
"nbformat": 4,

pixi.lock

Lines changed: 556 additions & 510 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
'ipywidgets', # Widgets (needed for interactive matplotlib backends)
3333
'ipympl', # Matplotlib Jupyter widget backend (%matplotlib widget)
3434
'IPython', # Interactive Python shell
35+
'sympy', # Symbolic mathematics (used for expression components)
3536
]
3637

3738
[project.optional-dependencies]

src/easydynamics/sample_model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .components import DampedHarmonicOscillator
77
from .components import DeltaFunction
88
from .components import Exponential
9+
from .components import ExpressionComponent
910
from .components import Gaussian
1011
from .components import Lorentzian
1112
from .components import Polynomial
@@ -29,4 +30,5 @@
2930
'ResolutionModel',
3031
'BackgroundModel',
3132
'InstrumentModel',
33+
'ExpressionComponent',
3234
]

src/easydynamics/sample_model/components/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .damped_harmonic_oscillator import DampedHarmonicOscillator
55
from .delta_function import DeltaFunction
66
from .exponential import Exponential
7+
from .expression_component import ExpressionComponent
78
from .gaussian import Gaussian
89
from .lorentzian import Lorentzian
910
from .polynomial import Polynomial
@@ -17,4 +18,5 @@
1718
'DampedHarmonicOscillator',
1819
'Polynomial',
1920
'Exponential',
21+
'ExpressionComponent',
2022
]
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# SPDX-FileCopyrightText: 2026 EasyScience contributors
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from __future__ import annotations
5+
6+
import numpy as np
7+
import scipp as sc
8+
import sympy as sp
9+
from easyscience.variable import Parameter
10+
11+
from easydynamics.utils.utils import Numeric
12+
13+
from .model_component import ModelComponent
14+
15+
16+
class ExpressionComponent(ModelComponent):
17+
"""Model component defined by a symbolic expression.
18+
19+
Example:
20+
expr = ExpressionComponent(
21+
"A * exp(-(x - x0)**2 / (2*sigma**2))",
22+
parameters={"A": 10, "x0": 0, "sigma": 1},
23+
)
24+
25+
expr.A = 5
26+
y = expr.evaluate(x)
27+
"""
28+
29+
# -------------------------
30+
# Allowed symbolic functions
31+
# -------------------------
32+
_ALLOWED_FUNCS = {
33+
# Exponentials & logs
34+
'exp': sp.exp,
35+
'log': sp.log,
36+
'ln': sp.log,
37+
'sqrt': sp.sqrt,
38+
# Trigonometric
39+
'sin': sp.sin,
40+
'cos': sp.cos,
41+
'tan': sp.tan,
42+
'sinc': sp.sinc,
43+
'cot': sp.cot,
44+
'sec': sp.sec,
45+
'csc': sp.csc,
46+
'asin': sp.asin,
47+
'acos': sp.acos,
48+
'atan': sp.atan,
49+
# Hyperbolic
50+
'sinh': sp.sinh,
51+
'cosh': sp.cosh,
52+
'tanh': sp.tanh,
53+
# Misc
54+
'abs': sp.Abs,
55+
'sign': sp.sign,
56+
'floor': sp.floor,
57+
'ceil': sp.ceiling,
58+
# Special functions
59+
'erf': sp.erf,
60+
}
61+
62+
# -------------------------
63+
# Allowed constants
64+
# -------------------------
65+
_ALLOWED_CONSTANTS = {
66+
'pi': sp.pi,
67+
'E': sp.E,
68+
}
69+
70+
_RESERVED_NAMES = {'x'}
71+
72+
def __init__(
73+
self,
74+
expression: str,
75+
parameters: dict[str, Numeric] | None = None,
76+
unit: str | sc.Unit = 'meV',
77+
display_name: str | None = 'Expression',
78+
unique_name: str | None = None,
79+
) -> None:
80+
"""Initialize the ExpressionComponent.
81+
82+
Args:
83+
expression (str): The symbolic expression as a string.
84+
Must contain 'x' as the independent variable.
85+
parameters (dict[str, Numeric] | None, default=None):
86+
Dictionary of parameter names and their initial values.
87+
Defaults to None (no parameters).
88+
unit (str | sc.Unit, default="meV"): Unit of the output.
89+
display_name (str | None, default="Expression"): Display name for the component.
90+
unique_name (str | None, default=None): Unique name for the component.
91+
92+
Raises:
93+
ValueError: If the expression is invalid or does not contain 'x'.
94+
TypeError: If any parameter value is not numeric.
95+
"""
96+
super().__init__(unit=unit, display_name=display_name, unique_name=unique_name)
97+
98+
if 'np.' in expression:
99+
raise ValueError(
100+
'NumPy syntax (np.*) is not supported. '
101+
"Use functions like 'exp', 'sin', etc. directly."
102+
)
103+
104+
self._expression_str = expression
105+
106+
locals_dict = {}
107+
locals_dict.update(self._ALLOWED_FUNCS)
108+
locals_dict.update(self._ALLOWED_CONSTANTS)
109+
110+
try:
111+
self._expr = sp.sympify(expression, locals=locals_dict)
112+
except Exception as e:
113+
raise ValueError(f'Invalid expression: {expression}') from e
114+
115+
# Extract symbols from the expression
116+
symbols = self._expr.free_symbols
117+
symbol_names = sorted(str(s) for s in symbols)
118+
119+
if 'x' not in symbol_names:
120+
raise ValueError("Expression must contain 'x' as independent variable")
121+
122+
# Reject unknown functions early so invalid expressions fail at init,
123+
# not later during numerical evaluation.
124+
allowed_function_names = set(self._ALLOWED_FUNCS) | {
125+
func.__name__ for func in self._ALLOWED_FUNCS.values()
126+
}
127+
128+
# Walk all function-call nodes in the parsed expression (e.g. sin(x), foo(x)).
129+
# Keep only function names that are not in our allowlist.
130+
unknown_function_names: set[str] = set()
131+
function_atoms = self._expr.atoms(sp.Function)
132+
for function_atom in function_atoms:
133+
function_name = function_atom.func.__name__
134+
if function_name not in allowed_function_names:
135+
unknown_function_names.add(function_name)
136+
137+
unknown_functions = sorted(unknown_function_names)
138+
139+
if unknown_functions:
140+
raise ValueError(
141+
f'Unsupported function(s) in expression: {", ".join(unknown_functions)}'
142+
)
143+
144+
# Create parameters
145+
if parameters is not None and not isinstance(parameters, dict):
146+
raise TypeError(
147+
f'Parameters must be None or a dictionary, got {type(parameters).__name__}'
148+
)
149+
150+
if parameters is not None:
151+
for name, value in parameters.items():
152+
if not isinstance(value, Numeric):
153+
raise TypeError(f"Parameter '{name}' must be numeric")
154+
parameters = parameters or {}
155+
self._parameters: dict[str, Parameter] = {}
156+
157+
self._symbol_names = symbol_names
158+
for name in self._symbol_names:
159+
if name in self._RESERVED_NAMES:
160+
continue
161+
162+
value = parameters.get(name, 1.0)
163+
164+
self._parameters[name] = Parameter(
165+
name=name,
166+
value=value,
167+
unit=self._unit,
168+
)
169+
170+
# Create numerical function
171+
ordered_symbols = [sp.Symbol(name) for name in self._symbol_names]
172+
173+
self._func = sp.lambdify(
174+
ordered_symbols,
175+
self._expr,
176+
modules=['numpy'],
177+
)
178+
179+
# -------------------------
180+
# Properties
181+
# -------------------------
182+
183+
@property
184+
def expression(self) -> str:
185+
"""Return the original expression string."""
186+
return self._expression_str
187+
188+
@expression.setter
189+
def expression(self, _new_expr: str) -> None:
190+
"""Prevent changing the expression after initialization.
191+
192+
Args:
193+
_new_expr (str): New expression string (ignored).
194+
195+
Raises:
196+
AttributeError: Always raised to prevent changing the expression.
197+
"""
198+
raise AttributeError('Expression cannot be changed after initialization')
199+
200+
def evaluate(
201+
self,
202+
x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
203+
) -> np.ndarray:
204+
"""Evaluate the expression for given x values.
205+
206+
Args:
207+
x (Numeric | list | np.ndarray | sc.Variable | sc.DataArray):
208+
Input values for the independent variable.
209+
210+
Returns:
211+
np.ndarray: Evaluated results.
212+
"""
213+
x = self._prepare_x_for_evaluate(x)
214+
215+
args = []
216+
for name in self._symbol_names:
217+
if name == 'x':
218+
args.append(x)
219+
else:
220+
args.append(self._parameters[name].value)
221+
222+
return self._func(*args)
223+
224+
def get_all_variables(self) -> list[Parameter]:
225+
"""Return all parameters.
226+
227+
Returns:
228+
list[Parameter]: List of all parameters in the expression.
229+
"""
230+
return list(self._parameters.values())
231+
232+
def convert_unit(self, _new_unit: str | sc.Unit) -> None:
233+
"""Convert the unit of the expression.
234+
235+
Unit conversion is not implemented for ExpressionComponent.
236+
237+
Args:
238+
_new_unit (str | sc.Unit): The new unit to convert to (ignored).
239+
240+
Raises:
241+
NotImplementedError: Always raised to indicate unit conversion is not supported.
242+
"""
243+
244+
raise NotImplementedError('Unit conversion is not implemented for ExpressionComponent')
245+
246+
# -------------------------
247+
# dunder methods
248+
# -------------------------
249+
250+
def __getattr__(self, name: str) -> Parameter:
251+
"""Allow access to parameters as attributes.
252+
253+
Args:
254+
name (str): Name of the parameter to access.
255+
256+
Returns:
257+
Parameter: The parameter with the given name.
258+
259+
Raises:
260+
AttributeError: If the parameter does not exist.
261+
"""
262+
if '_parameters' in self.__dict__ and name in self._parameters:
263+
return self._parameters[name]
264+
raise AttributeError(f"{self.__class__.__name__} has no attribute '{name}'")
265+
266+
def __setattr__(self, name: str, value: Numeric) -> None:
267+
"""Allow setting parameter values as attributes.
268+
269+
Args:
270+
name (str): Name of the parameter to set.
271+
value (Numeric): New value for the parameter.
272+
273+
Raises:
274+
TypeError: If the value is not numeric.
275+
"""
276+
if '_parameters' in self.__dict__ and name in self._parameters:
277+
param = self._parameters[name]
278+
279+
if not isinstance(value, Numeric):
280+
raise TypeError(f'{name} must be numeric')
281+
282+
param.value = value
283+
else:
284+
# For other attributes, use default behavior
285+
super().__setattr__(name, value)
286+
287+
def __dir__(self) -> list[str]:
288+
"""Include parameter names in dir() output for better IDE
289+
support.
290+
291+
Returns:
292+
list[str]: List of attribute names, including parameters.
293+
"""
294+
return super().__dir__() + list(self._parameters.keys())
295+
296+
def __repr__(self) -> str:
297+
param_str = ', '.join(f'{k}={v.value}' for k, v in self._parameters.items())
298+
return (
299+
f'{self.__class__.__name__}(\n'
300+
f" expr='{self._expression_str}',\n"
301+
f' unit={self._unit},\n'
302+
f' parameters={{ {param_str} }}\n'
303+
f')'
304+
)

0 commit comments

Comments
 (0)