Skip to content

Commit efd2e9f

Browse files
committed
Get rid of model_dict.
1 parent 5012a8b commit efd2e9f

29 files changed

Lines changed: 207 additions & 171 deletions

docs/how_to_guides/model_specs.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Model Specifications
22

3-
Models can be specified using Python dataclasses or dictionaries. The dataclass approach
4-
is recommended for type safety and IDE support.
3+
Models are specified using Python dataclasses. YAML files can be loaded and converted to
4+
`ModelSpec` via `from_dict()`.
55

66
## Using Dataclasses (Recommended)
77

@@ -41,19 +41,22 @@ model = ModelSpec(
4141
)
4242
```
4343

44-
## Using Dictionaries
44+
## Loading from YAML
4545

46-
For backwards compatibility and interoperability with YAML/JSON files, models can also
47-
be specified as dictionaries:
46+
Models can be loaded from YAML files by parsing the YAML and converting to a `ModelSpec`
47+
via `from_dict()`:
4848

4949
```python
5050
import yaml
51+
from skillmodels import ModelSpec
5152

5253
with open("model.yaml") as f:
53-
model = yaml.safe_load(f)
54+
raw = yaml.safe_load(f)
55+
56+
model = ModelSpec.from_dict(raw)
5457
```
5558

56-
The dictionary structure mirrors the dataclass structure:
59+
The YAML dictionary structure mirrors the dataclass structure:
5760

5861
```yaml
5962
factors:

pixi.lock

Lines changed: 2 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ classifiers = [
2020
]
2121
dependencies = [
2222
"dags",
23-
"frozendict",
2423
"jax>=0.8",
2524
"numpy",
2625
"pandas",

src/skillmodels/check_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Functions to validate model specifications."""
22

3+
from collections.abc import Mapping
4+
35
import numpy as np
46

57
from skillmodels.types import Anchoring, Dimensions, Labels
@@ -92,8 +94,8 @@ def _check_anchoring(anchoring: Anchoring) -> list[str]:
9294
report = []
9395
if not isinstance(anchoring.anchoring, bool):
9496
report.append("anchoring.anchoring must be a bool.")
95-
if not isinstance(anchoring.outcomes, dict):
96-
report.append("anchoring.outcomes must be a dict")
97+
if not isinstance(anchoring.outcomes, Mapping):
98+
report.append("anchoring.outcomes must be a Mapping")
9799
else:
98100
variables = list(anchoring.outcomes.values())
99101
for var in variables:

src/skillmodels/constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _get_anchoring_constraints(
389389
ind_tups = []
390390
for period in periods:
391391
for factor in anchoring_info.factors:
392-
outcome = anchoring_info.outcomes[factor] # ty: ignore[invalid-argument-type]
392+
outcome = anchoring_info.outcomes[factor]
393393
meas = f"{outcome}_{factor}"
394394
ind_tups.append(("loadings", period, meas, factor))
395395

@@ -431,7 +431,7 @@ def _get_constraints_for_augmented_periods(
431431
# look counterintuitive...
432432
aug_period_meas_type_to_constrain = (
433433
MeasurementType.STATES
434-
if endogenous_factors_info.factor_info[factor].is_state # ty: ignore[invalid-argument-type]
434+
if endogenous_factors_info.factor_info[factor].is_state
435435
else MeasurementType.ENDOGENOUS_FACTORS
436436
)
437437
aug_period_meas_types = (

src/skillmodels/correlation_heatmap.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def plot_correlation_heatmap(
134134

135135
def get_measurements_corr(
136136
data: pd.DataFrame,
137-
model: dict | ModelSpec,
137+
model: ModelSpec,
138138
factors: list[str] | tuple[str, ...] | str | None,
139139
periods: float | list[int] | None,
140140
) -> pd.DataFrame:
@@ -178,7 +178,7 @@ def get_measurements_corr(
178178

179179
def get_quasi_scores_corr(
180180
data: pd.DataFrame,
181-
model: dict | ModelSpec,
181+
model: ModelSpec,
182182
factors: list[str] | tuple[str, ...] | str | None,
183183
periods: float | list[int] | None,
184184
) -> pd.DataFrame:
@@ -226,7 +226,7 @@ def get_quasi_scores_corr(
226226
def get_scores_corr(
227227
data: pd.DataFrame,
228228
params: pd.DataFrame,
229-
model: dict | ModelSpec,
229+
model: ModelSpec,
230230
factors: list[str] | tuple[str, ...] | str | None,
231231
periods: float | list[int] | None,
232232
) -> pd.DataFrame:

src/skillmodels/filtered_states.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
def get_filtered_states(
18-
model: dict | ModelSpec,
18+
model: ModelSpec,
1919
data: pd.DataFrame,
2020
params: pd.DataFrame,
2121
) -> dict[str, dict[str, Any]]:
@@ -54,7 +54,7 @@ def get_filtered_states(
5454

5555
def anchor_states_df(
5656
states_df: pd.DataFrame,
57-
model: dict | ModelSpec,
57+
model: ModelSpec,
5858
params: pd.DataFrame,
5959
*,
6060
use_aug_period: bool,

src/skillmodels/maximization_inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@
3232

3333

3434
def get_maximization_inputs(
35-
model: dict | ModelSpec,
35+
model: ModelSpec,
3636
data: pd.DataFrame,
3737
split_dataset: int = 1,
3838
) -> dict[str, Any]:
3939
"""Create inputs for optimagic's maximize function.
4040
4141
Args:
42-
model: The model specification, either as a dict or ModelSpec instance.
42+
model: The model specification as a ModelSpec instance.
4343
See: :ref:`model_specs`
4444
data: dataset in long format.
4545
split_dataset(Int): Controls into how many slices to split the dataset

src/skillmodels/model_spec.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
33
This module provides frozen dataclasses for defining model specifications
44
in a type-safe, immutable manner. All collections use immutable types
5-
(tuples, frozendict) to ensure the specification cannot be accidentally modified.
5+
(tuples, MappingProxyType) to ensure the specification cannot be accidentally
6+
modified.
67
"""
78

89
from collections.abc import Callable
910
from dataclasses import dataclass, field
1011
from types import MappingProxyType
1112
from typing import Self
1213

13-
from frozendict import frozendict
14-
1514

1615
@dataclass(frozen=True)
1716
class Normalizations:
@@ -25,15 +24,15 @@ class Normalizations:
2524
2625
"""
2726

28-
loadings: tuple[frozendict[str, float], ...]
29-
intercepts: tuple[frozendict[str, float], ...]
27+
loadings: tuple[MappingProxyType[str, float], ...]
28+
intercepts: tuple[MappingProxyType[str, float], ...]
3029

3130
@classmethod
3231
def from_dict(cls, d: dict) -> Self:
3332
"""Create Normalizations from a dictionary specification."""
3433
return cls(
35-
loadings=tuple(frozendict(x) for x in d["loadings"]),
36-
intercepts=tuple(frozendict(x) for x in d["intercepts"]),
34+
loadings=tuple(MappingProxyType(x) for x in d["loadings"]),
35+
intercepts=tuple(MappingProxyType(x) for x in d["intercepts"]),
3736
)
3837

3938
def to_dict(self) -> dict:
@@ -169,8 +168,8 @@ def to_dict(self) -> dict:
169168
return result
170169

171170

172-
def _default_empty_frozendict() -> frozendict[str, str]:
173-
return frozendict({})
171+
def _default_empty_mapping_proxy() -> MappingProxyType[str, str]:
172+
return MappingProxyType({})
174173

175174

176175
@dataclass(frozen=True)
@@ -186,7 +185,9 @@ class AnchoringSpec:
186185
187186
"""
188187

189-
outcomes: frozendict[str, str] = field(default_factory=_default_empty_frozendict)
188+
outcomes: MappingProxyType[str, str] = field(
189+
default_factory=_default_empty_mapping_proxy,
190+
)
190191
free_controls: bool = False
191192
free_constant: bool = False
192193
free_loadings: bool = False
@@ -198,7 +199,7 @@ def from_dict(cls, d: dict) -> Self:
198199
outcomes = d.get("outcomes", {})
199200
ignore_constant = d.get("ignore_constant_when_anchoring", False)
200201
return cls(
201-
outcomes=frozendict(outcomes),
202+
outcomes=MappingProxyType(outcomes),
202203
free_controls=d.get("free_controls", False),
203204
free_constant=d.get("free_constant", False),
204205
free_loadings=d.get("free_loadings", False),

src/skillmodels/process_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def _add_copies_of_anchoring_outcome(
177177
) -> pd.DataFrame:
178178
df = df.copy()
179179
for factor in anchoring_info.factors:
180-
outcome = anchoring_info.outcomes[factor] # ty: ignore[invalid-argument-type]
180+
outcome = anchoring_info.outcomes[factor]
181181
df[f"{outcome}_{factor}"] = df[outcome]
182182
return df
183183

0 commit comments

Comments
 (0)