From 3f527b81ff75c04db5ce8f47b42925a80422adf9 Mon Sep 17 00:00:00 2001 From: abrasive Date: Tue, 19 May 2026 13:20:33 +1000 Subject: [PATCH] support fields of type typing.Literal Closes #45 --- README.md | 9 +++++++++ src/dataclass_binder/_impl.py | 37 +++++++++++++++++++++++++++++++++-- tests/test_formatting.py | 11 +++++++++-- tests/test_parsing.py | 35 ++++++++++++++++++++++++++++++++- 4 files changed, 87 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f9962a5..29f991c 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,15 @@ class Config: The `float` type can be used to bind floating point numbers. Support for `Decimal` is not there at the moment but would be relatively easy to add, as `tomllib`/`tomli` has an option for that. +Where only one of a specific set of values is permitted, `typing.Literal` can be used, which supports any combination of strings, ints and bools: + +```py +@dataclass +class Config: + mode: Literal["boring", "fancy", "mid"] + limbs: Literal[0, 1, 2, 3, 4, 5] +``` + ### Defaults Fields can be made optional by assigning a default value. Using `None` as a default value is allowed too: diff --git a/src/dataclass_binder/_impl.py b/src/dataclass_binder/_impl.py index 6ef6cb6..ab70897 100644 --- a/src/dataclass_binder/_impl.py +++ b/src/dataclass_binder/_impl.py @@ -25,7 +25,20 @@ from pathlib import Path from textwrap import dedent from types import GenericAlias, MappingProxyType, ModuleType, NoneType, UnionType -from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + ClassVar, + Generic, + Literal, + TypeVar, + Union, + cast, + get_args, + get_origin, + overload, + ) from weakref import WeakKeyDictionary if sys.version_info < (3, 11): @@ -75,6 +88,8 @@ def _collect_type(field_type: type, context: str) -> type | GenericAlias | Binde return collected_types[0] else: return reduce(operator.__or__, collected_types) + elif origin is Literal: + return field_type elif issubclass(origin, Mapping): type_args = get_args(field_type) try: @@ -318,6 +333,11 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) -> type(value) is not bool or field_type is bool or field_type is object ): return value + elif origin is Literal: + if value not in get_args(field_type): + valid_args = map(repr, get_args(field_type)) + raise TypeError(f"Value for '{context}' has value '{value}', expected one of {', '.join(valid_args)}") + return value elif issubclass(origin, Mapping): if not isinstance(value, dict): raise TypeError(f"Value for '{context}' has type '{type(value).__name__}', expected table") @@ -509,7 +529,9 @@ def _format_toml_table( continue origin = get_origin(field_type) if origin is not None: - if issubclass(origin, Mapping): + if origin is Literal: + pass + elif issubclass(origin, Mapping): _key_type, value_type = get_args(field_type) if isinstance(value_type, Binder): if value is None: @@ -887,6 +909,15 @@ def format_template(class_or_instance: Any) -> Iterator[str]: """Deprecated: use `Binder.format_toml_template()` instead.""" yield from Binder(class_or_instance).format_toml_template() +def _format_literal(literal: Any) -> str: + if isinstance(literal, str): + return f"'{literal}'" + elif isinstance(literal, bool): + return str(literal).lower() + elif isinstance(literal, int): + return str(literal) + else: + raise TypeError(f"Can't represent '{literal}' as a TOML constant") def _format_value_for_type(field_type: GenericAlias | type[Any]) -> str: origin = get_origin(field_type) @@ -914,6 +945,8 @@ def _format_value_for_type(field_type: GenericAlias | type[Any]) -> str: raise AssertionError(field_type) elif origin in (UnionType, Union): return " | ".join(_format_value_for_type(arg) for arg in get_args(field_type)) + elif origin is Literal: + return " | ".join(_format_literal(arg) for arg in get_args(field_type)) elif issubclass(origin, Mapping): return "{}" elif issubclass(origin, Iterable): diff --git a/tests/test_formatting.py b/tests/test_formatting.py index febba12..2b53e0e 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -6,7 +6,7 @@ from io import BytesIO from pathlib import Path from types import ModuleType, NoneType, UnionType -from typing import Any, TypeVar, Union, cast, get_args, get_origin +from typing import Any, Literal, TypeVar, Union, cast, get_args, get_origin import pytest @@ -319,6 +319,8 @@ class TemplateConfig: multi_type: str | int + literal: Literal["foo", 3, True] + derived: int = field(init=False) """Excluded field.""" @@ -365,6 +367,9 @@ def test_format_template_full() -> None: # Mandatory. multi-type = '???' | 0 + +# Mandatory. +literal = 'foo' | 3 | true """.strip() ) @@ -392,6 +397,7 @@ def test_format_dataclass_inline(*, optional: bool, string: bool) -> None: expiry=timedelta(days=3), certificate=Path("secrets/copper.key"), multi_type=-1, + literal=3, ) formatted = format_toml_pair("value", value) assert formatted == ( @@ -403,7 +409,8 @@ def test_format_dataclass_inline(*, optional: bool, string: bool) -> None: "another-number = 0.5, " "expiry-days = 3, " "certificate = 'secrets/copper.key', " - "multi-type = -1}" + "multi-type = -1, " + "literal = 3}" ) dc = single_value_dataclass(TemplateConfig, optional=optional, string=string) assert parse_toml(dc, formatted).value == value diff --git a/tests/test_parsing.py b/tests/test_parsing.py index 74f8eaa..c49d825 100644 --- a/tests/test_parsing.py +++ b/tests/test_parsing.py @@ -8,7 +8,7 @@ from io import BytesIO from pathlib import Path from types import ModuleType -from typing import Any, BinaryIO, Generic, TypeVar +from typing import Any, BinaryIO, Generic, Literal, TypeVar import pytest @@ -526,6 +526,39 @@ class MiniConfig: assert config.nested == (("abc", 2), (True, "def")) +@dataclass(frozen=True) +class LiteralConfig: + magic_word: Literal["abracadabra", "opensesame"] + +def test_bind_literal() -> None: + """typing.Literal is a valid field type""" + + for word in ["abracadabra", "opensesame"]: + with stream_text( + f""" + magic-word = "{word}" + """ + ) as stream: + config = Binder(LiteralConfig).parse_toml(stream) + + assert config.magic_word == word + +def test_bind_invalid_literal() -> None: + """invalid values of typing.Literal are not accepted""" + + with ( + stream_text( + """ + magic-word = "dooverlacky" + """ + ) as stream, + pytest.raises( + TypeError, + match=r"^Value for 'LiteralConfig.magic_word' has value 'dooverlacky', " + "expected one of 'abracadabra', 'opensesame'$", + ), + ): + Binder(LiteralConfig).parse_toml(stream) @dataclass(frozen=True) class MappingConfig: