Skip to content

Commit ed017fa

Browse files
Merge pull request #23 from CliMA/dy/ntuple_bugfix
Fix recursion instability in unrolled_product
2 parents 76dec6a + bfa0727 commit ed017fa

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "UnrolledUtilities"
22
uuid = "0fe1646c-419e-43be-ac14-22321958931b"
33
authors = ["CliMA Contributors <[email protected]>"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[compat]
77
julia = "1.9"

src/UnrolledUtilities.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,10 @@ include("generatively_unrolled_functions.jl")
399399
@inline unrolled_product(itrs...) =
400400
ntuple(Val(unrolled_prod(length, itrs))) do n
401401
@inline
402+
Base.@assume_effects :foldable
402403
items = ntuple(Val(length(itrs))) do itr_index
403404
@inline
405+
Base.@assume_effects :foldable
404406
cur_length = length(itrs[itr_index])
405407
prev_length = unrolled_prod(length, itrs[1:(itr_index - 1)])
406408
item_index = (n - 1) ÷ prev_length % cur_length + 1

test/test_and_analyze.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,16 @@ macro test_unrolled(
289289
isempty(JET.get_reports(@report_opt reference_func($(args...))))
290290
$(esc(skip_type_stability_test)) || @test_opt unrolled_func($(args...))
291291

292+
unrolled_code = code_instance(unrolled_func, $(args...))
293+
reference_code = code_instance(reference_func, $(args...))
294+
292295
# Test for constant propagation.
293296
is_unrolled_const =
294-
isdefined(code_instance(unrolled_func, $(args...)), :rettype_const)
297+
isdefined(unrolled_code, :rettype_const) &&
298+
isbits(unrolled_code.rettype_const)
295299
is_reference_const =
296-
isdefined(code_instance(reference_func, $(args...)), :rettype_const)
300+
isdefined(reference_code, :rettype_const) &&
301+
isbits(reference_code.rettype_const)
297302

298303
buffer = IOBuffer()
299304
args_type = Tuple{map(typeof, ($(args...),))...}
@@ -761,12 +766,15 @@ for itr in (
761766
str,
762767
)
763768

769+
# TODO: Testing with coverage triggers allocations for unrolled_product!
764770
if length(itr) <= 32
765771
@test_unrolled(
766772
(itr,),
767773
unrolled_product(itr, itr),
768774
Tuple(Iterators.product(itr, itr)),
769775
str,
776+
"fast_mode" in ARGS,
777+
"fast_mode" in ARGS,
770778
)
771779
end # This can take several minutes to compile when the length is 128.
772780
if length(itr) <= 8
@@ -775,6 +783,8 @@ for itr in (
775783
unrolled_product(itr, itr, itr),
776784
Tuple(Iterators.product(itr, itr, itr)),
777785
str,
786+
"fast_mode" in ARGS,
787+
"fast_mode" in ARGS,
778788
)
779789
end # This can take several minutes to compile when the length is 32.
780790

0 commit comments

Comments
 (0)