Commit a7a61bb

Replace recursive unrolling with manual unrolling
1 parent ed017fa commit a7a61bb

7 files changed: +267 -109 lines changed

docs/src/developer_guide.md

Lines changed: 29 additions & 12 deletions
@@ -5,7 +5,7 @@ CurrentModule = UnrolledUtilities
 ## How to Unroll

 There are two general ways to implement loop unrolling in Julia—recursively
-splatting iterator contents and manually generating unrolled expressions. For
+splatting iterator contents and manually constructing unrolled expressions. For
 example, a recursively unrolled version of the `foreach` function is

 ```julia
@@ -14,23 +14,42 @@ _unrolled_foreach(f) = nothing
 _unrolled_foreach(f, item, items...) = (f(item); _unrolled_foreach(f, items...))
 ```

-In contrast, a generatively unrolled implementation of this function looks like
+In contrast, a manually unrolled implementation of this function looks like

 ```julia
 unrolled_foreach(f, itr) = _unrolled_foreach(Val(length(itr)), f, itr)
 @generated _unrolled_foreach(::Val{N}, f, itr) where {N} =
     Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing)
 ```

-To switch between recursive and generative unrolling, this package defines the
-following function:
+Julia's compiler can only pass up to 32 values through function arguments
+without allocating heap memory, so recursive unrolling is not type-stable for
+iterators with lengths greater than 32. However, automatically generating
+functions often requires more time and memory resources during compilation than
+writing hard-coded functions. Recursive inlining adds overhead to compilation
+as well, but this is typically smaller than the overhead of generated functions
+for short iterators. To avoid sacrificing latency by using generated functions,
+several hard-coded methods can be added to the manually unrolled implementation:

-```@docs
-rec_unroll
+```julia
+_unrolled_foreach(::Val{0}, f, itr) = nothing
+_unrolled_foreach(::Val{1}, f, itr) = (f(generic_getindex(itr, 1)); nothing)
+_unrolled_foreach(::Val{2}, f, itr) =
+    (f(generic_getindex(itr, 1)); f(generic_getindex(itr, 2)); nothing)
+_unrolled_foreach(::Val{3}, f, itr) =
+    (f(generic_getindex(itr, 1)); f(generic_getindex(itr, 2)); f(generic_getindex(itr, 3)); nothing)
 ```

-The default choice for `rec_unroll` is motivated by the benchmarks for
-[Generative vs. Recursive Unrolling](@ref).
+With this modification, manual unrolling does not exceed the compilation
+requirements of recursive unrolling across a wide range of use cases. Since it
+also avoids type instabilities for arbitrarily large iterators, a combination
+of hard-coded and generated functions with manual unrolling serves as the basis
+of all unrolled functions defined in this package. Similarly, the
+[`ntuple(f, ::Val{N})`](https://github.com/JuliaLang/julia/blob/v1.11.0/base/ntuple.jl)
+function in `Base` uses this strategy to implement loop unrolling.
+
+For benchmarks that compare these two implementations, see
+[Manual vs. Recursive Unrolling](@ref).

 ## Interface API

@@ -54,10 +73,8 @@ StaticSequence

 To unroll over a statically sized iterator of some user-defined type `T`, follow
 these steps:
-- To enable recursive unrolling, add a method for `iterate(::T, [state])`
-- To enable generative unrolling, add a method for `getindex(::T, n)` (or for
-  `generic_getindex(::T, n)` if `getindex` should not be defined for iterators
-  of type `T`)
+- Add a method for `getindex(::T, n)`, or for `generic_getindex(::T, n)` if
+  `getindex` should not be defined for iterators of type `T`
 - If every unrolled function that needs to construct an iterator when given an
   iterator of type `T` can return a `Tuple` instead, stop here
 - Otherwise, to return a non-`Tuple` iterator whenever it is efficient to do so,
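
The strategy described in the updated guide (hard-coded methods for short iterators, a generated fallback for everything else) can be tried in isolation. What follows is a minimal sketch, not the package's code: `generic_getindex` is replaced by a hypothetical stand-in that forwards to `getindex`, the `sketch_` names are invented for illustration, and only lengths 0 and 1 are hard-coded to keep it short.

```julia
# Hypothetical stand-in for the package's generic_getindex (an assumption made
# for this sketch); here it simply forwards to getindex.
generic_getindex(itr, n) = getindex(itr, n)

# Hard-coded methods: dispatch on Val{0} and Val{1} never runs the generator.
_sketch_foreach(::Val{0}, f, itr) = nothing
_sketch_foreach(::Val{1}, f, itr) = (f(generic_getindex(itr, 1)); nothing)

# Generated fallback for any other length N, in the same form as the guide:
# a block with one call per index, followed by `nothing`.
@generated _sketch_foreach(::Val{N}, f, itr) where {N} =
    Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing)

# Entry point: lift the length into the type domain so dispatch can select
# between the hard-coded methods and the generated fallback.
sketch_foreach(f, itr) = _sketch_foreach(Val(length(itr)), f, itr)

visited = Int[]
sketch_foreach(x -> push!(visited, x), (1, 2, 3, 4, 5))
```

Because the `Val{0}` and `Val{1}` methods are more specific than the `Val{N}` method, Julia's dispatch prefers them, so the generator should only be invoked for lengths that have no hard-coded method.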

src/UnrolledUtilities.jl

Lines changed: 17 additions & 21 deletions
@@ -53,6 +53,14 @@ include("unrollable_iterator_interface.jl")
 # Analogue of the non-public Base._InitialValue for reduction and accumulation.
 struct NoInit end

+@inline empty_reduction_value(::NoInit) =
+    error("unrolled_reduce requires an init value for empty iterators")
+@inline empty_reduction_value(init) = init
+
+@inline first_reduction_value(op, itr, ::NoInit) = generic_getindex(itr, 1)
+@inline first_reduction_value(op, itr, init) =
+    op(init, generic_getindex(itr, 1))
+
 # Analogue of ∘, but with only one function argument and guaranteed inlining.
 # Base's ∘ leads to type instabilities in unit tests on Julia 1.10 and 1.11.
 @inline (f1::F1, f2::F2) where {F1, F2} = x -> (@inline f1(f2(x)))
@@ -134,17 +142,17 @@ include("StaticBitVector.jl")
     unrolled_drop_into(inferred_output_type(itr), itr, val_N)

 ##
-## Functions unrolled using either recursion or generated expressions
+## Functions unrolled manually (hard-coded for N <= 3, auto-generated for N > 3)
 ##

-include("recursively_unrolled_functions.jl")
+include("hard_coded_unrolled_functions.jl")
 include("generatively_unrolled_functions.jl")

 # The unrolled_map function could also be implemented in terms of ntuple, but
 # then it would be subject to the same recursion limit as ntuple. On Julia 1.10,
 # this leads to type instabilities in several unit tests for nested iterators.
 @inline unrolled_map_into_tuple(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr)
+    _unrolled_map(Val(length(itr)), f, itr)
 @inline unrolled_map_into(output_type, f::F, itr) where {F} =
     constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itr))
 @inline unrolled_map(f::F, itr) where {F} =
@@ -154,32 +162,26 @@ include("generatively_unrolled_functions.jl")

 @inline unrolled_any(itr) = unrolled_any(identity, itr)
 @inline unrolled_any(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr)
+    _unrolled_any(Val(length(itr)), f, itr)

 @inline unrolled_all(itr) = unrolled_all(identity, itr)
 @inline unrolled_all(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr)
+    _unrolled_all(Val(length(itr)), f, itr)

 @inline unrolled_foreach(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr)
+    _unrolled_foreach(Val(length(itr)), f, itr)
 @inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...))

 @inline unrolled_reduce(op::O, itr, init) where {O} =
-    isempty(itr) && init isa NoInit ?
-    error("unrolled_reduce requires an init value for empty iterators") :
-    (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
+    _unrolled_reduce(Val(length(itr)), op, itr, init)
 @inline unrolled_reduce(op::O, itr; init = NoInit()) where {O} =
     unrolled_reduce(op, itr, init)

 @inline unrolled_mapreduce(f::F, op::O, itrs...; init = NoInit()) where {F, O} =
     unrolled_reduce(op, unrolled_map(f, itrs...), init)

 @inline unrolled_accumulate_into_tuple(op::O, itr, init) where {O} =
-    (rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)(
-        op,
-        itr,
-        init,
-    )
+    _unrolled_accumulate(Val(length(itr)), op, itr, init)
 @inline unrolled_accumulate_into(output_type, op::O, itr, init) where {O} =
     constructor_from_tuple(output_type)(
         unrolled_accumulate_into_tuple(op, itr, init),
@@ -209,13 +211,7 @@ include("generatively_unrolled_functions.jl")
     itr,
     itrs...,
 ) where {F, I, E} =
-    (rec_unroll(itr) ? rec_unrolled_ifelse : gen_unrolled_ifelse)(
-        f,
-        get_if,
-        get_else,
-        itr,
-        itrs...,
-    )
+    _unrolled_ifelse(Val(length(itr)), f, get_if, get_else, itr, itrs...)

 ##
 ## Unrolled functions without any analogues in Base
Lines changed: 14 additions & 27 deletions
@@ -1,54 +1,45 @@
-@generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = quote
+@generated _unrolled_map(::Val{N}, f, itr) where {N} = quote
     @inline
     return Base.Cartesian.@ntuple $N n -> f(generic_getindex(itr, n))
 end
-@inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr)

-@generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = quote
+@generated _unrolled_any(::Val{N}, f, itr) where {N} = quote
     @inline
     return Base.Cartesian.@nany $N n -> f(generic_getindex(itr, n))
 end
-@inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr)

-@generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = quote
+@generated _unrolled_all(::Val{N}, f, itr) where {N} = quote
     @inline
     return Base.Cartesian.@nall $N n -> f(generic_getindex(itr, n))
 end
-@inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr)

-@generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = quote
+@generated _unrolled_foreach(::Val{N}, f, itr) where {N} = quote
     @inline
     Base.Cartesian.@nexprs $N n -> f(generic_getindex(itr, n))
     return nothing
 end
-@inline gen_unrolled_foreach(f, itr) =
-    _gen_unrolled_foreach(Val(length(itr)), f, itr)

-@generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = quote
+@generated _unrolled_reduce(::Val{N}, op, itr, init) where {N} = quote
     @inline
-    $N == 0 && return init
-    first_itr_item = generic_getindex(itr, 1)
-    value_1 = init isa NoInit ? first_itr_item : op(init, first_itr_item)
+    # $N == 0 && return empty_reduction_value(init)
+    # The N == 0 case is handled separately in a hard-coded method.
+    value_1 = first_reduction_value(op, itr, init)
     Base.Cartesian.@nexprs $(N - 1) n ->
         (value_{n + 1} = op(value_n, generic_getindex(itr, n + 1)))
     return $(Symbol(:value_, N))
 end
-@inline gen_unrolled_reduce(op, itr, init) =
-    _gen_unrolled_reduce(Val(length(itr)), op, itr, init)

-@generated _gen_unrolled_accumulate(::Val{N}, op, itr, init) where {N} = quote
+@generated _unrolled_accumulate(::Val{N}, op, itr, init) where {N} = quote
     @inline
-    $N == 0 && return ()
-    first_itr_item = generic_getindex(itr, 1)
-    value_1 = init isa NoInit ? first_itr_item : op(init, first_itr_item)
+    # $N == 0 && return ()
+    # The N == 0 case is handled separately in a hard-coded method.
+    value_1 = first_reduction_value(op, itr, init)
     Base.Cartesian.@nexprs $(N - 1) n ->
         (value_{n + 1} = op(value_n, generic_getindex(itr, n + 1)))
     return Base.Cartesian.@ntuple $N n -> value_n
 end
-@inline gen_unrolled_accumulate(op, itr, init) =
-    _gen_unrolled_accumulate(Val(length(itr)), op, itr, init)

-@generated _gen_unrolled_ifelse(::Val{N}, f, get_if, get_else, itr) where {N} =
+@generated _unrolled_ifelse(::Val{N}, f, get_if, get_else, itr) where {N} =
     quote
         @inline
         Base.Cartesian.@nexprs $N n -> begin
@@ -57,10 +48,8 @@ end
         end
         return get_else()
     end
-@inline gen_unrolled_ifelse(f, get_if, get_else, itr) =
-    _gen_unrolled_ifelse(Val(length(itr)), f, get_if, get_else, itr)

-@generated _gen_unrolled_ifelse2(
+@generated _unrolled_ifelse(
     ::Val{N},
     f,
     get_if,
@@ -74,5 +63,3 @@ end
     end
     return get_else()
 end
-@inline gen_unrolled_ifelse(f, get_if, get_else, itr1, itr2) =
-    _gen_unrolled_ifelse2(Val(length(itr1)), f, get_if, get_else, itr1, itr2)
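
The generated methods above rely on the `Base.Cartesian` macros `@ntuple` and `@nexprs`. As a rough illustration of what those macros expand to, here is a small sketch with a literal count in place of the interpolated `$N`; the functions `squares` and `running_sum` are hypothetical and exist only for this example.

```julia
# `@ntuple N n -> expr` builds the tuple (expr[n=1], ..., expr[n=N]).
squares() = Base.Cartesian.@ntuple 3 n -> n^2

# `@nexprs N n -> expr` emits `expr` once for each n in 1:N. Symbols written as
# value_n or value_{n + 1} are rewritten to value_1, value_2, and so on, which
# is how the generated reduce and accumulate methods chain intermediate values.
function running_sum(itr)
    value_1 = itr[1]
    Base.Cartesian.@nexprs 2 n -> (value_{n + 1} = value_n + itr[n + 1])
    return Base.Cartesian.@ntuple 3 n -> value_n
end
```

Under these definitions, `squares()` evaluates to `(1, 4, 9)` and `running_sum((1, 2, 3))` to `(1, 3, 6)`. The second function mirrors the shape of `_unrolled_accumulate` for a fixed length of 3.
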
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+@inline _unrolled_map(::Val{0}, f, itr) = ()
+@inline _unrolled_map(::Val{1}, f, itr) = (f(generic_getindex(itr, 1)),)
+@inline _unrolled_map(::Val{2}, f, itr) =
+    (f(generic_getindex(itr, 1)), f(generic_getindex(itr, 2)))
+@inline _unrolled_map(::Val{3}, f, itr) = (
+    f(generic_getindex(itr, 1)),
+    f(generic_getindex(itr, 2)),
+    f(generic_getindex(itr, 3)),
+)
+
+@inline _unrolled_any(::Val{0}, f, itr) = false
+@inline _unrolled_any(::Val{1}, f, itr) = f(generic_getindex(itr, 1))
+@inline _unrolled_any(::Val{2}, f, itr) =
+    f(generic_getindex(itr, 1)) || f(generic_getindex(itr, 2))
+@inline _unrolled_any(::Val{3}, f, itr) =
+    f(generic_getindex(itr, 1)) ||
+    f(generic_getindex(itr, 2)) ||
+    f(generic_getindex(itr, 3))
+
+@inline _unrolled_all(::Val{0}, f, itr) = true
+@inline _unrolled_all(::Val{1}, f, itr) = f(generic_getindex(itr, 1))
+@inline _unrolled_all(::Val{2}, f, itr) =
+    f(generic_getindex(itr, 1)) && f(generic_getindex(itr, 2))
+@inline _unrolled_all(::Val{3}, f, itr) =
+    f(generic_getindex(itr, 1)) &&
+    f(generic_getindex(itr, 2)) &&
+    f(generic_getindex(itr, 3))
+
+@inline _unrolled_foreach(::Val{0}, f, itr) = nothing
+@inline function _unrolled_foreach(::Val{1}, f, itr)
+    f(generic_getindex(itr, 1))
+    return nothing
+end
+@inline function _unrolled_foreach(::Val{2}, f, itr)
+    f(generic_getindex(itr, 1))
+    f(generic_getindex(itr, 2))
+    return nothing
+end
+@inline function _unrolled_foreach(::Val{3}, f, itr)
+    f(generic_getindex(itr, 1))
+    f(generic_getindex(itr, 2))
+    f(generic_getindex(itr, 3))
+    return nothing
+end
+
+@inline _unrolled_reduce(::Val{0}, op, itr, init) = empty_reduction_value(init)
+@inline _unrolled_reduce(::Val{1}, op, itr, init) =
+    first_reduction_value(op, itr, init)
+@inline _unrolled_reduce(::Val{2}, op, itr, init) =
+    op(first_reduction_value(op, itr, init), generic_getindex(itr, 2))
+@inline _unrolled_reduce(::Val{3}, op, itr, init) = op(
+    op(first_reduction_value(op, itr, init), generic_getindex(itr, 2)),
+    generic_getindex(itr, 3),
+)
+
+@inline _unrolled_accumulate(::Val{0}, op, itr, init) = ()
+@inline function _unrolled_accumulate(::Val{1}, op, itr, init)
+    value_1 = first_reduction_value(op, itr, init)
+    return (value_1,)
+end
+@inline function _unrolled_accumulate(::Val{2}, op, itr, init)
+    value_1 = first_reduction_value(op, itr, init)
+    value_2 = op(value_1, generic_getindex(itr, 2))
+    return (value_1, value_2)
+end
+@inline function _unrolled_accumulate(::Val{3}, op, itr, init)
+    value_1 = first_reduction_value(op, itr, init)
+    value_2 = op(value_1, generic_getindex(itr, 2))
+    value_3 = op(value_2, generic_getindex(itr, 3))
+    return (value_1, value_2, value_3)
+end
+
+@inline _unrolled_ifelse(::Val{0}, f, get_if, get_else, itr) = get_else()
+@inline function _unrolled_ifelse(::Val{1}, f, get_if, get_else, itr)
+    item_1 = generic_getindex(itr, 1)
+    f(item_1) && return get_if(item_1)
+    return get_else()
+end
+@inline function _unrolled_ifelse(::Val{2}, f, get_if, get_else, itr)
+    item_1 = generic_getindex(itr, 1)
+    f(item_1) && return get_if(item_1)
+    item_2 = generic_getindex(itr, 2)
+    f(item_2) && return get_if(item_2)
+    return get_else()
+end
+@inline function _unrolled_ifelse(::Val{3}, f, get_if, get_else, itr)
+    item_1 = generic_getindex(itr, 1)
+    f(item_1) && return get_if(item_1)
+    item_2 = generic_getindex(itr, 2)
+    f(item_2) && return get_if(item_2)
+    item_3 = generic_getindex(itr, 3)
+    f(item_3) && return get_if(item_3)
+    return get_else()
+end
+
+@inline _unrolled_ifelse(::Val{0}, f, get_if, get_else, itr1, itr2) = get_else()
+@inline function _unrolled_ifelse(::Val{1}, f, get_if, get_else, itr1, itr2)
+    f(generic_getindex(itr1, 1)) && return get_if(generic_getindex(itr2, 1))
+    return get_else()
+end
+@inline function _unrolled_ifelse(::Val{2}, f, get_if, get_else, itr1, itr2)
+    f(generic_getindex(itr1, 1)) && return get_if(generic_getindex(itr2, 1))
+    f(generic_getindex(itr1, 2)) && return get_if(generic_getindex(itr2, 2))
+    return get_else()
+end
+@inline function _unrolled_ifelse(::Val{3}, f, get_if, get_else, itr1, itr2)
+    f(generic_getindex(itr1, 1)) && return get_if(generic_getindex(itr2, 1))
+    f(generic_getindex(itr1, 2)) && return get_if(generic_getindex(itr2, 2))
+    f(generic_getindex(itr1, 3)) && return get_if(generic_getindex(itr2, 3))
+    return get_else()
+end
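
The hard-coded reduce methods above replace the old runtime check (`isempty(itr) && init isa NoInit`) with dispatch on `Val{0}` and on the type of `init`. The following self-contained sketch reproduces that pattern under stated assumptions: `generic_getindex` is a hypothetical stand-in that forwards to `getindex`, the `sketch_` names are invented, and only lengths 0 through 2 are covered.

```julia
# Sentinel type meaning "no init value was given", mirroring the diff.
struct NoInit end

# Hypothetical stand-in for the package's generic_getindex (an assumption).
generic_getindex(itr, n) = getindex(itr, n)

# Empty iterators: return init if there is one, otherwise raise an error.
empty_reduction_value(::NoInit) =
    error("unrolled_reduce requires an init value for empty iterators")
empty_reduction_value(init) = init

# First step of the reduction: either the first item or op(init, first item).
first_reduction_value(op, itr, ::NoInit) = generic_getindex(itr, 1)
first_reduction_value(op, itr, init) = op(init, generic_getindex(itr, 1))

# Hard-coded methods selected by the length carried in Val{N}.
_sketch_reduce(::Val{0}, op, itr, init) = empty_reduction_value(init)
_sketch_reduce(::Val{1}, op, itr, init) = first_reduction_value(op, itr, init)
_sketch_reduce(::Val{2}, op, itr, init) =
    op(first_reduction_value(op, itr, init), generic_getindex(itr, 2))

sketch_reduce(op, itr; init = NoInit()) =
    _sketch_reduce(Val(length(itr)), op, itr, init)
```

With these definitions, `sketch_reduce(+, (1, 2))` gives `3`, `sketch_reduce(+, (1, 2); init = 10)` gives `13`, and `sketch_reduce(+, ())` raises the error from `empty_reduction_value`. Resolving the `NoInit` case through method dispatch rather than an `isa` branch keeps each method body free of conditionals.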

src/unrollable_iterator_interface.jl

Lines changed: 0 additions & 11 deletions
@@ -1,14 +1,3 @@
-"""
-    rec_unroll(itr)
-
-Whether to use recursive loop unrolling instead of generative loop unrolling for
-the iterator `itr`. Recursive unrolling can lead to suboptimal LLVM code for
-iterators of more than 32 items, but it is typically faster than generative
-unrolling for short iterators. By default, recursive unrolling is used for
-iterators up to length 2, and generative unrolling is used for longer iterators.
-"""
-@inline rec_unroll(itr) = length(itr) <= 2
-
 """
     generic_getindex(itr, n)

Lines changed: 2 additions & 10 deletions
@@ -1,5 +1,4 @@
-# TODO: Replace all of these with manually unrolled functions, which should be
-# faster to compile. That is the pattern used in Base for ntuple, map, etc.
+using UnrolledUtilities: NoInit, generic_getindex, unrolled_drop

 @inline _rec_unrolled_map(f) = ()
 @inline _rec_unrolled_map(f, item, items...) =
@@ -53,11 +52,4 @@
     f(item1) ? get_if(item2) :
     _rec_unrolled_ifelse2(f, get_if, get_else, items...)
 @inline rec_unrolled_ifelse(f, get_if, get_else, itr1, itr2) =
-    _rec_unrolled_ifelse2(f, get_if, get_else, _unrolled_zip(itr1, itr2)...)
-# Using zip here triggers the recursion limit for one unit test on Julia 1.10.
-
-@inline _unrolled_zip(itr1, itr2) =
-    ntuple(Val(min(length(itr1), length(itr2)))) do n
-        @inline
-        (generic_getindex(itr1, n), generic_getindex(itr2, n))
-    end
+    _rec_unrolled_ifelse2(f, get_if, get_else, zip(itr1, itr2)...)
