Skip to content

Commit e45db65

Browse files
committed
Adjust test_disutility for message_ix#451
- Closes #397. - Expand type hints.
1 parent 5c5a50f commit e45db65

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

message_ix_models/tests/model/test_disutility.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Tests of :mod:`.model.disutility`."""
22

3+
from collections.abc import Iterator
34
from itertools import product
5+
from typing import TYPE_CHECKING
46

57
import pandas as pd
68
import pandas.testing as pdt
79
import pytest
8-
from message_ix import make_df
10+
from message_ix import Scenario, make_df
911
from sdmx.model.common import Code
1012
from sdmx.model.v21 import Annotation
1113

@@ -18,6 +20,9 @@
1820
merge_data,
1921
)
2022

23+
if TYPE_CHECKING:
24+
from message_ix_models.types import MutableParameterData, ParameterData
25+
2126
# Common data and fixtures for test_minimal() and other tests
2227

2328
COMMON = dict(
@@ -35,13 +40,13 @@
3540

3641

3742
@pytest.fixture
38-
def groups():
43+
def groups() -> Iterator[list[Code]]:
3944
"""Fixture: list of 2 consumer groups."""
4045
yield [Code(id="g0"), Code(id="g1")]
4146

4247

4348
@pytest.fixture
44-
def techs():
49+
def techs() -> Iterator[list[Code]]:
4550
"""Fixture: list of 2 technologies for which groups can have disutility."""
4651
yield [Code(id="t0"), Code(id="t1")]
4752

@@ -95,22 +100,28 @@ def test_add(scenario, groups, techs, template):
95100
assert (scenario.var("ACT")["lvl"] == 0).all()
96101

97102

98-
def minimal_test_data(scenario):
99-
"""Generate data for :func:`test_minimal`."""
103+
def minimal_test_data(scenario: Scenario) -> tuple["ParameterData", int, int]:
104+
"""Generate data for :func:`test_minimal`.
105+
106+
- Two technologies: t0 and t1.
107+
- ``growth_activity_{lo,up}`` on both technologies.
108+
- t1 has slightly higher ``var_cost``, such that the model will prefer to maximize
109+
output of t0 within the constraint.
110+
"""
100111
common = COMMON.copy()
101112
common.pop("node_loc")
102113
common.update(dict(mode="all"))
103114

104-
data = dict()
115+
data: "MutableParameterData" = dict()
105116

106117
info = ScenarioInfo(scenario)
107118
y0 = info.Y[0]
108119
y1 = info.Y[1]
109120

110121
# Output from t0 and t1
111-
for t in ("t0", "t1"):
122+
for t, vc in ("t0", 1.0), ("t1", 1.01):
112123
common.update(dict(technology=t, commodity=f"output of {t}"))
113-
merge_data(data, make_source_tech(info, common, output=1.0, var_cost=1.0))
124+
merge_data(data, make_source_tech(info, common, output=1.0, var_cost=vc))
114125

115126
# Disutility input for each combination of (tech) × (group) × (2 years)
116127
input_data = pd.DataFrame(
@@ -135,33 +146,34 @@ def minimal_test_data(scenario):
135146
data["demand"] = make_df("demand", commodity=c, year=y, value=1.0, **COMMON)
136147

137148
# Constraint on activity in the first period
138-
t = sorted(input_data["technology"].unique())
149+
techs = sorted(input_data["technology"].unique())
139150
for bound in ("lo", "up"):
140151
par = f"bound_activity_{bound}"
141-
data[par] = make_df(par, value=0.5, technology=t, year_act=y0, **COMMON)
152+
data[par] = make_df(par, value=0.5, technology=techs, year_act=y0, **COMMON)
142153

143154
# Constraint on activity growth
144155
annual = (1.1 ** (1.0 / 5.0)) - 1.0
145156
for bound, factor in (("lo", -1.0), ("up", 1.0)):
146157
par = f"growth_activity_{bound}"
147158
data[par] = make_df(
148-
par, value=factor * annual, technology=t, year_act=y1, **COMMON
159+
par, value=factor * annual, technology=techs, year_act=y1, **COMMON
149160
)
150161

151162
return data, y0, y1
152163

153164

154-
def test_minimal(scenario, groups, techs, template):
165+
def test_minimal(
166+
scenario: Scenario, groups: list[Code], techs: list[Code], template: Code
167+
) -> None:
155168
"""Expected results are generated from a minimal test case."""
156-
# Set up structure
169+
# Set up structure on `scenario`
157170
disutility.add(scenario, groups, techs, template)
158171

159172
# Add test-specific data
160173
data, y0, y1 = minimal_test_data(scenario)
161174

162-
scenario.check_out()
163-
add_par_data(scenario, data)
164-
scenario.commit("Disutility test 1")
175+
with scenario.transact("test_disutility.test_minimal case 1"):
176+
add_par_data(scenario, data)
165177

166178
# commented: pre-solve debugging output
167179
# for par in ("input", "output", "technical_lifetime", "var_cost"):
@@ -196,9 +208,8 @@ def get_act(s):
196208

197209
# Re-solve
198210
scenario.remove_solution()
199-
scenario.check_out()
200-
scenario.add_par("input", data["input"])
201-
scenario.commit("Disutility test 2")
211+
with scenario.transact("test_disutility.test_minimal case 2"):
212+
scenario.add_par("input", data["input"])
202213
scenario.solve(quiet=True)
203214

204215
# Compare activity
@@ -209,9 +220,6 @@ def get_act(s):
209220

210221
merged_delta = ACT1_delta.merge(ACT2_delta, left_index=True, right_index=True)
211222

212-
# commented: for debugging
213-
# print(merged, merged_delta)
214-
215223
# Group g0 decreases usage of t0, and increases usage of t1, in period y1 vs. y0
216224
assert merged_delta.loc["usage of t0 by g0", "lvl_y"] < 0
217225
assert merged_delta.loc["usage of t1 by g0", "lvl_y"] > 0

0 commit comments

Comments
 (0)