Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.4
with:
pixi-version: v0.64.0
pixi-version: v0.65.0
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
frozen: true
Expand Down
6 changes: 4 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ releases are available on [conda-forge](https://anaconda.org/conda-forge/dags).

## 0.5.0

- :gh:`76` Add a couple of tests to bring coverage to 100% (:ghuser:`hmgaudecker`).

- :gh:`75` Streamline public API (:ghuser:`hmgaudecker`).

- Deprecate `one_function_without_tree_logic`,
`functions_without_tree_logic`, and `fail_if_paths_are_invalid`
in favor of `get_one_function_without_tree_logic`,
`get_functions_without_tree_logic`, and
`dags.tree.validation.fail_if_paths_are_invalid`.
`get_functions_without_tree_logic`
- Deprecate `dags.tree.fail_if_paths_are_invalid` (call from `dags.tree.validation`)

- :gh:`65` Update docs and use Jupyter Book for documentation (:ghuser:`hmgaudecker`).

Expand Down
63 changes: 63 additions & 0 deletions tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,66 @@ def f2():

assert aggregated() is False
assert "return" not in inspect.get_annotations(aggregated)


def test_concatenate_functions_invalid_return_type_raises() -> None:
with pytest.raises(DagsError, match="Invalid return type"):
concatenate_functions(
functions=[_leisure, _consumption],
targets=["_leisure", "_consumption"],
return_type="set", # type: ignore[arg-type]
)


def test_get_ancestors_include_targets() -> None:
calculated = get_ancestors(
functions=[_utility, _unrelated, _leisure, _consumption],
targets="_utility",
include_targets=True,
)
expected = {
"_utility",
"_consumption",
"_leisure",
"working_hours",
"wage",
"leisure_weight",
}
assert calculated == expected


def test_concatenate_functions_non_string_targets() -> None:
with pytest.raises(DagsError, match="Targets must be strings"):
concatenate_functions(
functions={"f": lambda: 1},
targets=[1], # type: ignore[list-item]
)


def test_aggregator_exception_in_get_annotations() -> None:
"""Test that an aggregator whose annotations cause an exception is handled."""

def f1() -> int:
return 1

def f2() -> int:
return 2

# Create an object that is callable but raises on get_annotations
class BadAggregator:
def __call__(self, a: int, b: int) -> int:
return a + b

@property
def __annotations__(self) -> dict[str, Any]:
msg = "bad annotations"
raise TypeError(msg)

aggregator = BadAggregator()
result = concatenate_functions(
functions={"f1": f1, "f2": f2},
targets=["f1", "f2"],
aggregator=aggregator,
set_annotations=True,
)
assert result() == 3
14 changes: 14 additions & 0 deletions tests/test_dag_tree/test_dag_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from dags.tree import concatenate_functions_tree
from dags.tree.dag_tree import create_dag_tree
from dags.tree.typing import (
NestedFunctionDict,
NestedInputDict,
Expand Down Expand Up @@ -176,3 +177,16 @@ def f(a, b):
enforce_signature=True,
)
assert concatenated_func({"a": 1}) == {"f": 2}


def test_create_dag_tree(functions_simple: NestedFunctionDict) -> None:
inputs: NestedInputDict = {
"n1": {"a": 1, "b": 2},
"n2": {"a": 3, "b": {"g": 4}},
}
targets: NestedTargetDict = {"n1": {"f": None}}

dag = create_dag_tree(functions=functions_simple, inputs=inputs, targets=targets)

assert "n1__f" in dag.nodes
assert "n1__g" in dag.nodes
27 changes: 27 additions & 0 deletions tests/test_process_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from dags.exceptions import DagsError
from dags.output import (
aggregated_output,
dict_output,
Expand Down Expand Up @@ -88,3 +89,29 @@ def f():
return (1, 2)

assert f() == [1, 2]


def test_dict_output_keys_none() -> None:
with pytest.raises(DagsError, match="'keys' parameter is required"):
dict_output(keys=None) # ty: ignore[invalid-argument-type]


def test_aggregated_output_aggregator_none() -> None:
with pytest.raises(DagsError, match="'aggregator' parameter is required"):
aggregated_output(aggregator=None) # ty: ignore[invalid-argument-type]


def test_aggregated_output_direct_call() -> None:
def f():
return (10, 20)

g = aggregated_output(f, aggregator=lambda x, y: x + y)
assert g() == 30


def test_aggregated_output_decorator_usage() -> None:
@aggregated_output(aggregator=lambda x, y: x + y)
def f():
return (10, 20)

assert f() == 30
8 changes: 8 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,11 @@ def test_with_signature_invalid_args_type() -> None:
@with_signature(args="invalid")
def f(*args, **kwargs):
pass


def test_with_signature_invalid_args_type_int() -> None:
with pytest.raises(DagsError, match="Invalid type for arg"):

@with_signature(args=42) # type: ignore[arg-type]
def f(*args, **kwargs):
pass
Loading