Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: check-useless-excludes
# - id: identity # Prints all files passed to pre-commits. Debugging.
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.37.1
rev: v1.38.0
hooks:
- id: yamllint
- repo: https://github.com/lyz-code/yamlfix
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/pycqa/isort
rev: 7.0.0
rev: 8.0.1
hooks:
- id: isort
name: isort
Expand All @@ -55,13 +55,13 @@ repos:
# args:
# - --py37-plus
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
rev: 26.1.0
hooks:
- id: black
language_version: python3.12
exclude: tests/utils/fast_upper_envelope_org.py
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.10
rev: v0.15.4
hooks:
- id: ruff
# exclude: |
Expand Down
15 changes: 8 additions & 7 deletions docs/time_period2_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import numpy as np
from numba import njit

import upper_envelope as upenv
import upper_envelope.jax as ue_jax
import upper_envelope.numba as ue_numba

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -60,7 +61,7 @@ def value_func_jax(consumption, choice, params):


def fues_jax_partial(endog, pol, val, exp_val_zero):
return upenv.fues_jax(
return ue_jax.fues_jax(
endog_grid=jnp.asarray(endog),
policy=jnp.asarray(pol),
value=jnp.asarray(val),
Expand Down Expand Up @@ -103,7 +104,7 @@ def fues_jax_partial(endog, pol, val, exp_val_zero):


def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
return upenv.drued_jorg_jax(
return ue_jax.drued_jorg_jax(
endog_grid=endog,
policy=pol,
value=val,
Expand Down Expand Up @@ -168,7 +169,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
# Numba FUES
start = time.time()
jax.block_until_ready(
upenv.fues_numba(
ue_numba.fues_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand All @@ -184,7 +185,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
for _ in range(n_runs):
start = time.time()
jax.block_until_ready(
upenv.fues_numba(
ue_numba.fues_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand All @@ -201,7 +202,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
# Numba DRUED-JORG
start = time.time()
jax.block_until_ready(
upenv.drued_jorg_numba(
ue_numba.drued_jorg_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand All @@ -218,7 +219,7 @@ def drued_jorg_jax_partial(endog, pol, val, m_grid, exp_val_zero):
for _ in range(n_runs):
start = time.time()
jax.block_until_ready(
upenv.drued_jorg_numba(
ue_numba.drued_jorg_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand Down
22 changes: 11 additions & 11 deletions docs/tutorials/ue_drued_jorg.ipynb

Large diffs are not rendered by default.

15 changes: 6 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
# Project metadata
# ======================================================================================
[project]
name = "upper_envelope"
description = "Upper envelope scan for dynamic discrete-continuous life cycle models."
version = "0.1.3"
requires-python = ">=3.10"
name = "upper_envelope"
description = "Upper envelope scan for dynamic discrete-continuous life cycle models."
dynamic = ["version"]
requires-python = ">=3.10"
dependencies = [
"numpy",
"pandas",
"scipy",
"jax"
]
keywords = [
"Dynamic programming",
Expand All @@ -32,7 +29,7 @@ classifiers = [
"Topic :: Scientific/Engineering",
]
authors = [
{ name="Max Blesch", email="maximilian.blesch@hu-berlin.de" },
{ name="Max Blesch", email="maxblesch@gmail.com" },
{ name="Sebastian Gsell", email="gsell.sebastian@gmail.com" },
]
maintainers = [
Expand All @@ -55,7 +52,7 @@ Github = "https://github.com/OpensourceEconomics/upper-envelope"
# ======================================================================================

[build-system]
requires = ["hatchling", "hatch_vcs"]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[tool.hatch.build.hooks.vcs]
Expand Down
4 changes: 0 additions & 4 deletions src/upper_envelope/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
from upper_envelope.drued_jorg_jax import drued_jorg_jax
from upper_envelope.drued_jorg_numba import drued_jorg_numba
from upper_envelope.fues_jax.fues_jax import fues_jax, fues_jax_unconstrained
from upper_envelope.fues_numba.fues_numba import fues_numba, fues_numba_unconstrained
2 changes: 2 additions & 0 deletions src/upper_envelope/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from upper_envelope.jax.drued_jorg_jax import drued_jorg_jax
from upper_envelope.jax.fues_jax.fues_jax import fues_jax, fues_jax_unconstrained
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import jax.numpy as jnp
from jax import vmap

from upper_envelope.fues_jax.check_and_scan_funcs import (
from upper_envelope.jax.fues_jax.check_and_scan_funcs import (
determine_cases_and_conduct_necessary_scans,
)
from upper_envelope.math_funcs import calc_intersection_and_extrapolate_policy
Expand Down
5 changes: 5 additions & 0 deletions src/upper_envelope/numba/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from upper_envelope.numba.drued_jorg_numba import drued_jorg_numba
from upper_envelope.numba.fues_numba.fues_numba import (
fues_numba,
fues_numba_unconstrained,
)
7 changes: 4 additions & 3 deletions tests/test_drued_jorg_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from numba import njit
from numpy.testing import assert_allclose

import upper_envelope as upenv
import upper_envelope.jax as upenv_jax
import upper_envelope.numba as upenv_numba
from tests.utils.comparison_interp import interpolate_on_safe_reference_segments

TEST_DIR = Path(__file__).parent
Expand Down Expand Up @@ -69,7 +70,7 @@ def value_func_jax(consumption, choice, params):
utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0]
)

ref_m, ref_c, ref_v = upenv.fues_jax(
ref_m, ref_c, ref_v = upenv_jax.fues_jax(
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
Expand All @@ -91,7 +92,7 @@ def value_func_jax(consumption, choice, params):
m_max = float(np.max(policy_egm[0, 1:]))
m_grid = np.linspace(m_min, m_max, 500)

endog_out, policy_out, value_out = upenv.drued_jorg_numba(
endog_out, policy_out, value_out = upenv_numba.drued_jorg_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand Down
15 changes: 9 additions & 6 deletions tests/test_fues_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
import pytest
from numpy.testing import assert_array_almost_equal as aaae

import upper_envelope as upenv
import upper_envelope.jax as upenv_jax
import upper_envelope.numba as upenv_numba
from tests.utils.interpolation import (
interpolate_policy_and_value_on_wealth_grid,
linear_interpolation_with_extrapolation,
)
from tests.utils.upper_envelope_fedor import upper_envelope
from upper_envelope.fues_jax.check_and_scan_funcs import back_and_forward_scan_wrapper
from upper_envelope.jax.fues_jax.check_and_scan_funcs import (
back_and_forward_scan_wrapper,
)

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -134,7 +137,7 @@ def value_func(consumption, choice, params):
endog_grid_refined,
policy_refined,
value_refined,
) = upenv.fues_jax(
) = upenv_jax.fues_jax(
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
Expand Down Expand Up @@ -180,7 +183,7 @@ def test_fast_upper_envelope_against_numba(setup_model):
)
_params, exog_savings_grid, state_choice_vars = setup_model

endog_grid_org, value_org, policy_org = upenv.fues_numba_unconstrained(
endog_grid_org, value_org, policy_org = upenv_numba.fues_numba_unconstrained(
endog_grid=policy_egm[0],
value=value_egm[1],
policy=policy_egm[1],
Expand All @@ -190,7 +193,7 @@ def test_fast_upper_envelope_against_numba(setup_model):
endog_grid_refined,
value_refined,
policy_refined,
) = jax.jit(upenv.fues_jax_unconstrained)(
) = jax.jit(upenv_jax.fues_jax_unconstrained)(
endog_grid=policy_egm[0, 1:],
value=value_egm[1, 1:],
policy=policy_egm[1, 1:],
Expand Down Expand Up @@ -262,7 +265,7 @@ def value_func(consumption, choice, params):
endog_grid_fues,
policy_fues,
value_fues,
) = upenv.fues_jax(
) = upenv_jax.fues_jax(
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
Expand Down
17 changes: 10 additions & 7 deletions tests/test_fues_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from numba import njit
from numpy.testing import assert_array_almost_equal as aaae

import upper_envelope as upenv
import upper_envelope.jax as upenv_jax
import upper_envelope.numba as upenv_numba
from tests.utils.fast_upper_envelope_org import fast_upper_envelope_wrapper_org
from tests.utils.interpolation import (
interpolate_single_policy_and_value_on_wealth_grid,
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_fast_upper_envelope_wrapper(period, setup_model):

params, state_choice_vec, _exog_savings_grid = setup_model

endog_grid_refined, policy_refined, value_refined = upenv.fues_numba(
endog_grid_refined, policy_refined, value_refined = upenv_numba.fues_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand Down Expand Up @@ -155,10 +156,12 @@ def test_fast_upper_envelope_against_org_fues(setup_model):

_params, state_choice_vec, exog_savings_grid = setup_model

endog_grid_refined, value_refined, policy_refined = upenv.fues_numba_unconstrained(
endog_grid=policy_egm[0],
value=value_egm[1],
policy=policy_egm[1],
endog_grid_refined, value_refined, policy_refined = (
upenv_numba.fues_numba_unconstrained(
endog_grid=policy_egm[0],
value=value_egm[1],
policy=policy_egm[1],
)
)
endog_grid_org, policy_org, value_org = fast_upper_envelope_wrapper_org(
endog_grid=policy_egm[0],
Expand Down Expand Up @@ -205,7 +208,7 @@ def test_fast_upper_envelope_against_fedor(period, setup_model):
~np.isnan(_value_fedor).any(axis=0),
]

endog_grid_fues, policy_fues, value_fues = upenv.fues_numba(
endog_grid_fues, policy_fues, value_fues = upenv_numba.fues_numba(
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
Expand Down
6 changes: 3 additions & 3 deletions tests/test_jorg_drued_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from comparison_interp import interpolate_on_safe_reference_segments
from numpy.testing import assert_allclose

import upper_envelope as upenv
import upper_envelope.jax as upenv_jax

TEST_DIR = Path(__file__).parent
TEST_RESOURCES_DIR = TEST_DIR / "resources"
Expand Down Expand Up @@ -68,7 +68,7 @@ def value_func(consumption, choice, params):
utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0]
)

ref_m, ref_c, ref_v = upenv.fues_jax(
ref_m, ref_c, ref_v = upenv_jax.fues_jax(
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
Expand All @@ -93,7 +93,7 @@ def value_func(consumption, choice, params):
m_max = float(np.max(policy_egm[0, 1:]))
m_grid = np.linspace(m_min, m_max, 500)

endog_out, policy_out, value_out = upenv.drued_jorg_jax(
endog_out, policy_out, value_out = upenv_jax.drued_jorg_jax(
endog_grid=jnp.asarray(policy_egm[0, 1:]),
policy=jnp.asarray(policy_egm[1, 1:]),
value=jnp.asarray(value_egm[1, 1:]),
Expand Down
Loading