Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ class Config:
The `float` type can be used to bind floating point numbers.
Support for `Decimal` is not there at the moment but would be relatively easy to add, as `tomllib`/`tomli` has an option for that.

Where only one of a specific set of values is permitted, `typing.Literal` can be used, which supports any combination of strings, ints and bools:

```py
@dataclass
class Config:
mode: Literal["boring", "fancy", "mid"]
limbs: Literal[0, 1, 2, 3, 4, 5]
```

### Defaults

Fields can be made optional by assigning a default value. Using `None` as a default value is allowed too:
Expand Down
37 changes: 35 additions & 2 deletions src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,20 @@
from pathlib import Path
from textwrap import dedent
from types import GenericAlias, MappingProxyType, ModuleType, NoneType, UnionType
from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
ClassVar,
Generic,
Literal,
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
)
from weakref import WeakKeyDictionary

if sys.version_info < (3, 11):
Expand Down Expand Up @@ -75,6 +88,8 @@ def _collect_type(field_type: type, context: str) -> type | GenericAlias | Binde
return collected_types[0]
else:
return reduce(operator.__or__, collected_types)
elif origin is Literal:
return field_type
elif issubclass(origin, Mapping):
type_args = get_args(field_type)
try:
Expand Down Expand Up @@ -318,6 +333,11 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) ->
type(value) is not bool or field_type is bool or field_type is object
):
return value
elif origin is Literal:
if value not in get_args(field_type):
valid_args = map(repr, get_args(field_type))
raise TypeError(f"Value for '{context}' has value '{value}', expected one of {', '.join(valid_args)}")
return value
elif issubclass(origin, Mapping):
if not isinstance(value, dict):
raise TypeError(f"Value for '{context}' has type '{type(value).__name__}', expected table")
Expand Down Expand Up @@ -509,7 +529,9 @@ def _format_toml_table(
continue
origin = get_origin(field_type)
if origin is not None:
if issubclass(origin, Mapping):
if origin is Literal:
pass
elif issubclass(origin, Mapping):
_key_type, value_type = get_args(field_type)
if isinstance(value_type, Binder):
if value is None:
Expand Down Expand Up @@ -887,6 +909,15 @@ def format_template(class_or_instance: Any) -> Iterator[str]:
"""Deprecated: use `Binder.format_toml_template()` instead."""
yield from Binder(class_or_instance).format_toml_template()

def _format_literal(literal: Any) -> str:
if isinstance(literal, str):
return f"'{literal}'"
elif isinstance(literal, bool):
return str(literal).lower()
elif isinstance(literal, int):
return str(literal)
else:
raise TypeError(f"Can't represent '{literal}' as a TOML constant")

def _format_value_for_type(field_type: GenericAlias | type[Any]) -> str:
origin = get_origin(field_type)
Expand Down Expand Up @@ -914,6 +945,8 @@ def _format_value_for_type(field_type: GenericAlias | type[Any]) -> str:
raise AssertionError(field_type)
elif origin in (UnionType, Union):
return " | ".join(_format_value_for_type(arg) for arg in get_args(field_type))
elif origin is Literal:
return " | ".join(_format_literal(arg) for arg in get_args(field_type))
elif issubclass(origin, Mapping):
return "{}"
elif issubclass(origin, Iterable):
Expand Down
11 changes: 9 additions & 2 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from io import BytesIO
from pathlib import Path
from types import ModuleType, NoneType, UnionType
from typing import Any, TypeVar, Union, cast, get_args, get_origin
from typing import Any, Literal, TypeVar, Union, cast, get_args, get_origin

import pytest

Expand Down Expand Up @@ -319,6 +319,8 @@ class TemplateConfig:

multi_type: str | int

literal: Literal["foo", 3, True]

derived: int = field(init=False)
"""Excluded field."""

Expand Down Expand Up @@ -365,6 +367,9 @@ def test_format_template_full() -> None:

# Mandatory.
multi-type = '???' | 0

# Mandatory.
literal = 'foo' | 3 | true
""".strip()
)

Expand Down Expand Up @@ -392,6 +397,7 @@ def test_format_dataclass_inline(*, optional: bool, string: bool) -> None:
expiry=timedelta(days=3),
certificate=Path("secrets/copper.key"),
multi_type=-1,
literal=3,
)
formatted = format_toml_pair("value", value)
assert formatted == (
Expand All @@ -403,7 +409,8 @@ def test_format_dataclass_inline(*, optional: bool, string: bool) -> None:
"another-number = 0.5, "
"expiry-days = 3, "
"certificate = 'secrets/copper.key', "
"multi-type = -1}"
"multi-type = -1, "
"literal = 3}"
)
dc = single_value_dataclass(TemplateConfig, optional=optional, string=string)
assert parse_toml(dc, formatted).value == value
Expand Down
35 changes: 34 additions & 1 deletion tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from io import BytesIO
from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Generic, TypeVar
from typing import Any, BinaryIO, Generic, Literal, TypeVar

import pytest

Expand Down Expand Up @@ -526,6 +526,39 @@ class MiniConfig:

assert config.nested == (("abc", 2), (True, "def"))

@dataclass(frozen=True)
class LiteralConfig:
magic_word: Literal["abracadabra", "opensesame"]

def test_bind_literal() -> None:
"""typing.Literal is a valid field type"""

for word in ["abracadabra", "opensesame"]:
with stream_text(
f"""
magic-word = "{word}"
"""
) as stream:
config = Binder(LiteralConfig).parse_toml(stream)

assert config.magic_word == word

def test_bind_invalid_literal() -> None:
"""invalid values of typing.Literal are not accepted"""

with (
stream_text(
"""
magic-word = "dooverlacky"
"""
) as stream,
pytest.raises(
TypeError,
match=r"^Value for 'LiteralConfig.magic_word' has value 'dooverlacky', "
"expected one of 'abracadabra', 'opensesame'$",
),
):
Binder(LiteralConfig).parse_toml(stream)

@dataclass(frozen=True)
class MappingConfig:
Expand Down