Skip to content

Commit 502cf62

Browse files
committed
allow unwrapping in decorated functions
1 parent 995d9cc commit 502cf62

File tree

2 files changed

+142
-20
lines changed

2 files changed

+142
-20
lines changed

src/result/result.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,11 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[R, TBE]:
486486
return Ok(f(*args, **kwargs))
487487
except exceptions as exc:
488488
return Err(exc)
489+
except UnwrapError as ue:
490+
if ue.__cause__ is not None and isinstance(ue.__cause__, exceptions):
491+
return Err(ue.__cause__)
492+
493+
raise
489494

490495
return wrapper
491496

@@ -519,6 +524,11 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[R, TBE]:
519524
return Ok(await f(*args, **kwargs))
520525
except exceptions as exc:
521526
return Err(exc)
527+
except UnwrapError as ue:
528+
if ue.__cause__ is not None and isinstance(ue.__cause__, exceptions):
529+
return Err(ue.__cause__)
530+
531+
raise
522532

523533
return async_wrapper
524534

@@ -553,6 +563,12 @@ def wrapper(
553563
except exceptions as exc:
554564
yield Err(exc)
555565
return None
566+
except UnwrapError as ue:
567+
if ue.__cause__ is not None and isinstance(ue.__cause__, exceptions):
568+
yield Err(ue.__cause__)
569+
return None
570+
571+
raise
556572

557573
send_value = yield Ok(first_bit)
558574

@@ -564,6 +580,14 @@ def wrapper(
564580
except exceptions as exc:
565581
send_value = yield Err(exc)
566582
return None
583+
except UnwrapError as ue:
584+
if ue.__cause__ is not None and isinstance(
585+
ue.__cause__, exceptions
586+
):
587+
yield Err(ue.__cause__)
588+
return None
589+
590+
raise
567591

568592
send_value = yield Ok(yield_value)
569593

@@ -600,6 +624,12 @@ async def async_wrapper(
600624
except exceptions as exc:
601625
yield Err(exc)
602626
return
627+
except UnwrapError as ue:
628+
if ue.__cause__ is not None and isinstance(ue.__cause__, exceptions):
629+
yield Err(ue.__cause__)
630+
return
631+
632+
raise
603633

604634
send_value = yield Ok(first_bit)
605635

@@ -611,6 +641,14 @@ async def async_wrapper(
611641
except exceptions as exc:
612642
send_value = yield Err(exc)
613643
return
644+
except UnwrapError as ue:
645+
if ue.__cause__ is not None and isinstance(
646+
ue.__cause__, exceptions
647+
):
648+
yield Err(ue.__cause__)
649+
return
650+
651+
raise
614652

615653
send_value = yield Ok(yield_value)
616654

tests/test_result.py

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,31 @@
1717
)
1818

1919

20+
def sq(i: int) -> Result[int, int]:
21+
return Ok(i * i)
22+
23+
24+
async def sq_async(i: int) -> Result[int, int]:
25+
return Ok(i * i)
26+
27+
28+
def to_err(i: int) -> Result[int, int]:
29+
return Err(i)
30+
31+
32+
async def to_err_async(i: int) -> Result[int, int]:
33+
return Err(i)
34+
35+
36+
# Lambda versions of the same functions, just for test/type coverage
37+
def sq_lambda(i: int) -> Result[int, int]:
38+
return Ok(i * i)
39+
40+
41+
def to_err_lambda(i: int) -> Result[int, int]:
42+
return Err(i)
43+
44+
2045
def test_ok_factories() -> None:
2146
instance = Ok(1)
2247
assert instance._value == 1
@@ -367,6 +392,24 @@ def f() -> int:
367392
f()
368393

369394

395+
def test_as_result_unwraps() -> None:
396+
@as_result(ValueError)
397+
def raises_unwrapping_error(value: int) -> int:
398+
Err(IndexError("Test Error")).unwrap()
399+
return value
400+
401+
@as_result(IndexError, ValueError)
402+
def does_not_raise_unwrapping_error(value: int) -> int:
403+
Err(IndexError("Test Error")).unwrap()
404+
return value
405+
406+
with pytest.raises(UnwrapError):
407+
raises_unwrapping_error(123)
408+
409+
err_response = does_not_raise_unwrapping_error(123)
410+
assert isinstance(err_response.unwrap_err(), IndexError)
411+
412+
370413
def test_as_result_invalid_usage() -> None:
371414
"""
372415
Invalid use of ``as_result()`` raises reasonable errors.
@@ -423,29 +466,23 @@ async def bad(value: int) -> int:
423466
assert isinstance(bad_result.unwrap_err(), ValueError)
424467

425468

426-
def sq(i: int) -> Result[int, int]:
427-
return Ok(i * i)
428-
429-
430-
async def sq_async(i: int) -> Result[int, int]:
431-
return Ok(i * i)
432-
433-
434-
def to_err(i: int) -> Result[int, int]:
435-
return Err(i)
436-
469+
@pytest.mark.asyncio
470+
async def test_as_async_result_unwraps() -> None:
471+
@as_async_result(ValueError)
472+
async def raises_unwrapping_error(value: int) -> int:
473+
Err(IndexError("Test Error")).unwrap()
474+
return value
437475

438-
async def to_err_async(i: int) -> Result[int, int]:
439-
return Err(i)
476+
@as_async_result(IndexError, ValueError)
477+
async def does_not_raise_unwrapping_error(value: int) -> int:
478+
Err(IndexError("Test Error")).unwrap()
479+
return value
440480

481+
with pytest.raises(UnwrapError):
482+
await raises_unwrapping_error(123)
441483

442-
# Lambda versions of the same functions, just for test/type coverage
443-
def sq_lambda(i: int) -> Result[int, int]:
444-
return Ok(i * i)
445-
446-
447-
def to_err_lambda(i: int) -> Result[int, int]:
448-
return Err(i)
484+
err_response = await does_not_raise_unwrapping_error(123)
485+
assert isinstance(err_response.unwrap_err(), IndexError)
449486

450487

451488
def test_as_generator_result_ok() -> None:
@@ -478,6 +515,29 @@ def my_generator(val: int) -> Generator[int, None, None]:
478515
next(result)
479516

480517

518+
def test_as_generator_result_unwraps() -> None:
519+
@as_generator_result(ValueError)
520+
def raises_unwrapping_error(value: int) -> Generator[int, None, None]:
521+
Err(IndexError("Test Error")).unwrap()
522+
yield 5
523+
524+
@as_generator_result(IndexError, ValueError)
525+
def does_not_raise_unwrapping_error(value: int) -> Generator[int, None, None]:
526+
yield 3
527+
Err(IndexError("Test Error")).unwrap()
528+
yield 5
529+
530+
raising_generator = raises_unwrapping_error(123)
531+
with pytest.raises(UnwrapError):
532+
next(raising_generator)
533+
534+
running_generator = does_not_raise_unwrapping_error(123)
535+
assert next(running_generator) == Ok(3)
536+
assert next(running_generator).unwrap_err().args[0] == "Test Error"
537+
with pytest.raises(StopIteration):
538+
next(running_generator)
539+
540+
481541
def test_as_generator_result_with_send() -> None:
482542
@as_generator_result(ValueError)
483543
def my_generator() -> Generator[int, int, None]:
@@ -551,6 +611,30 @@ async def my_generator(val: int) -> AsyncGenerator[int, None]:
551611
await anext(result)
552612

553613

614+
@pytest.mark.asyncio
615+
async def test_as_async_generator_result_unwraps() -> None:
616+
@as_async_generator_result(ValueError)
617+
async def raises_unwrapping_error(value: int) -> AsyncGenerator[int, None]:
618+
Err(IndexError("Test Error")).unwrap()
619+
yield 5
620+
621+
@as_async_generator_result(IndexError, ValueError)
622+
async def does_not_raise_unwrapping_error(value: int) -> AsyncGenerator[int, None]:
623+
yield 3
624+
Err(IndexError("Test Error")).unwrap()
625+
yield 5
626+
627+
raising_generator = raises_unwrapping_error(123)
628+
with pytest.raises(UnwrapError):
629+
await anext(raising_generator)
630+
631+
running_generator = does_not_raise_unwrapping_error(123)
632+
assert await anext(running_generator) == Ok(3)
633+
assert (await anext(running_generator)).unwrap_err().args[0] == "Test Error"
634+
with pytest.raises(StopAsyncIteration):
635+
await anext(running_generator)
636+
637+
554638
@pytest.mark.asyncio
555639
async def test_as_async_generator_result_with_send() -> None:
556640
@as_async_generator_result(ValueError)

0 commit comments

Comments
 (0)