11"""Tests of :mod:`.model.disutility`."""
22
3+ from collections .abc import Iterator
34from itertools import product
5+ from typing import TYPE_CHECKING
46
57import pandas as pd
68import pandas .testing as pdt
79import pytest
8- from message_ix import make_df
10+ from message_ix import Scenario , make_df
911from sdmx .model .common import Code
1012from sdmx .model .v21 import Annotation
1113
1820 merge_data ,
1921)
2022
23+ if TYPE_CHECKING :
24+ from message_ix_models .types import MutableParameterData , ParameterData
25+
2126# Common data and fixtures for test_minimal() and other tests
2227
2328COMMON = dict (
3540
3641
3742@pytest .fixture
38- def groups ():
43+ def groups () -> Iterator [ list [ Code ]] :
3944 """Fixture: list of 2 consumer groups."""
4045 yield [Code (id = "g0" ), Code (id = "g1" )]
4146
4247
4348@pytest .fixture
44- def techs ():
49+ def techs () -> Iterator [ list [ Code ]] :
4550 """Fixture: list of 2 technologies for which groups can have disutility."""
4651 yield [Code (id = "t0" ), Code (id = "t1" )]
4752
@@ -95,22 +100,28 @@ def test_add(scenario, groups, techs, template):
95100 assert (scenario .var ("ACT" )["lvl" ] == 0 ).all ()
96101
97102
98- def minimal_test_data (scenario ):
99- """Generate data for :func:`test_minimal`."""
103+ def minimal_test_data (scenario : Scenario ) -> tuple ["ParameterData" , int , int ]:
104+ """Generate data for :func:`test_minimal`.
105+
106+ - Two technologies: t0 and t1.
107+ - ``growth_activity_{lo,up}`` on both technologies.
108+ - t1 has slightly higher ``var_cost``, such that the model will prefer to maximize
109+ output of t0 within the constraint.
110+ """
100111 common = COMMON .copy ()
101112 common .pop ("node_loc" )
102113 common .update (dict (mode = "all" ))
103114
104- data = dict ()
115+ data : "MutableParameterData" = dict ()
105116
106117 info = ScenarioInfo (scenario )
107118 y0 = info .Y [0 ]
108119 y1 = info .Y [1 ]
109120
110121 # Output from t0 and t1
111- for t in ("t0" , "t1" ):
122+ for t , vc in ("t0" , 1.0 ), ( "t1" , 1.01 ):
112123 common .update (dict (technology = t , commodity = f"output of { t } " ))
113- merge_data (data , make_source_tech (info , common , output = 1.0 , var_cost = 1.0 ))
124+ merge_data (data , make_source_tech (info , common , output = 1.0 , var_cost = vc ))
114125
115126 # Disutility input for each combination of (tech) × (group) × (2 years)
116127 input_data = pd .DataFrame (
@@ -135,33 +146,34 @@ def minimal_test_data(scenario):
135146 data ["demand" ] = make_df ("demand" , commodity = c , year = y , value = 1.0 , ** COMMON )
136147
137148 # Constraint on activity in the first period
138- t = sorted (input_data ["technology" ].unique ())
149+ techs = sorted (input_data ["technology" ].unique ())
139150 for bound in ("lo" , "up" ):
140151 par = f"bound_activity_{ bound } "
141- data [par ] = make_df (par , value = 0.5 , technology = t , year_act = y0 , ** COMMON )
152+ data [par ] = make_df (par , value = 0.5 , technology = techs , year_act = y0 , ** COMMON )
142153
143154 # Constraint on activity growth
144155 annual = (1.1 ** (1.0 / 5.0 )) - 1.0
145156 for bound , factor in (("lo" , - 1.0 ), ("up" , 1.0 )):
146157 par = f"growth_activity_{ bound } "
147158 data [par ] = make_df (
148- par , value = factor * annual , technology = t , year_act = y1 , ** COMMON
159+ par , value = factor * annual , technology = techs , year_act = y1 , ** COMMON
149160 )
150161
151162 return data , y0 , y1
152163
153164
154- def test_minimal (scenario , groups , techs , template ):
165+ def test_minimal (
166+ scenario : Scenario , groups : list [Code ], techs : list [Code ], template : Code
167+ ) -> None :
155168 """Expected results are generated from a minimal test case."""
156- # Set up structure
169+ # Set up structure on `scenario`
157170 disutility .add (scenario , groups , techs , template )
158171
159172 # Add test-specific data
160173 data , y0 , y1 = minimal_test_data (scenario )
161174
162- scenario .check_out ()
163- add_par_data (scenario , data )
164- scenario .commit ("Disutility test 1" )
175+ with scenario .transact ("test_disutility.test_minimal case 1" ):
176+ add_par_data (scenario , data )
165177
166178 # commented: pre-solve debugging output
167179 # for par in ("input", "output", "technical_lifetime", "var_cost"):
@@ -196,9 +208,8 @@ def get_act(s):
196208
197209 # Re-solve
198210 scenario .remove_solution ()
199- scenario .check_out ()
200- scenario .add_par ("input" , data ["input" ])
201- scenario .commit ("Disutility test 2" )
211+ with scenario .transact ("test_disutility.test_minimal case 2" ):
212+ scenario .add_par ("input" , data ["input" ])
202213 scenario .solve (quiet = True )
203214
204215 # Compare activity
@@ -209,9 +220,6 @@ def get_act(s):
209220
210221 merged_delta = ACT1_delta .merge (ACT2_delta , left_index = True , right_index = True )
211222
212- # commented: for debugging
213- # print(merged, merged_delta)
214-
215223 # Group g0 decreases usage of t0, and increases usage of t1, in period y1 vs. y0
216224 assert merged_delta .loc ["usage of t0 by g0" , "lvl_y" ] < 0
217225 assert merged_delta .loc ["usage of t1 by g0" , "lvl_y" ] > 0
0 commit comments