Commit 7373be9

Merge pull request #25 from CliMA/dy/manual_unrolling
Replace recursive unrolling with hard-coded unrolling
2 parents ed017fa + 311b554 commit 7373be9

8 files changed: 302 additions, 162 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "UnrolledUtilities"
 uuid = "0fe1646c-419e-43be-ac14-22321958931b"
 authors = ["CliMA Contributors <[email protected]>"]
-version = "0.1.8"
+version = "0.1.9"
 
 [compat]
 julia = "1.9"

docs/src/developer_guide.md

Lines changed: 29 additions & 12 deletions
@@ -5,7 +5,7 @@ CurrentModule = UnrolledUtilities
 ## How to Unroll
 
 There are two general ways to implement loop unrolling in Julia—recursively
-splatting iterator contents and manually generating unrolled expressions. For
+splatting iterator contents and manually constructing unrolled expressions. For
 example, a recursively unrolled version of the `foreach` function is
 
 ```julia
@@ -14,23 +14,42 @@ _unrolled_foreach(f) = nothing
 _unrolled_foreach(f, item, items...) = (f(item); _unrolled_foreach(f, items...))
 ```
 
-In contrast, a generatively unrolled implementation of this function looks like
+In contrast, a manually unrolled implementation of this function looks like
 
 ```julia
 unrolled_foreach(f, itr) = _unrolled_foreach(Val(length(itr)), f, itr)
 @generated _unrolled_foreach(::Val{N}, f, itr) where {N} =
     Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing)
 ```
 
-To switch between recursive and generative unrolling, this package defines the
-following function:
+Julia's compiler can only pass up to 32 values through function arguments
+without allocating heap memory, so recursive unrolling is not type-stable for
+iterators with lengths greater than 32. However, automatically generating
+functions often requires more time and memory resources during compilation than
+writing hard-coded functions. Recursive inlining adds overhead to compilation
+as well, but this is typically smaller than the overhead of generated functions
+for short iterators. To avoid sacrificing latency by using generated functions,
+several hard-coded methods can be added to the manually unrolled implementation:
 
-```@docs
-rec_unroll
+```julia
+_unrolled_foreach(::Val{0}, f, itr) = nothing
+_unrolled_foreach(::Val{1}, f, itr) = (f(generic_getindex(itr, 1)); nothing)
+_unrolled_foreach(::Val{2}, f, itr) =
+    (f(generic_getindex(itr, 1)); f(generic_getindex(itr, 2)); nothing)
+_unrolled_foreach(::Val{3}, f, itr) =
+    (f(generic_getindex(itr, 1)); f(generic_getindex(itr, 2)); f(generic_getindex(itr, 3)); nothing)
 ```
 
-The default choice for `rec_unroll` is motivated by the benchmarks for
-[Generative vs. Recursive Unrolling](@ref).
+With this modification, manual unrolling does not exceed the compilation
+requirements of recursive unrolling across a wide range of use cases. Since it
+also avoids type instabilities for arbitrarily large iterators, a combination
+of hard-coded and generated functions with manual unrolling serves as the basis
+of all unrolled functions defined in this package. Similarly, the
+[`ntuple(f, ::Val{N})`](https://github.com/JuliaLang/julia/blob/v1.11.0/base/ntuple.jl)
+function in `Base` uses this strategy to implement loop unrolling.
+
+For benchmarks that compare these two implementations, see
+[Manual vs. Recursive Unrolling](@ref).
 
 ## Interface API
 
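The strategy described in the hunk above (hard-coded methods for short iterators, with a generated method as the fallback) can be sketched as a standalone file. This is only an illustration under stated assumptions: `generic_getindex` is replaced by a local stand-in that forwards to `getindex`, and only lengths 0 through 2 are hard-coded, which is not necessarily the cutoff the package uses.

```julia
# Stand-in for the package's generic_getindex (assumption: forwards to getindex).
generic_getindex(itr, n) = getindex(itr, n)

# Hard-coded methods: short iterators never reach the generated method.
_unrolled_foreach(::Val{0}, f, itr) = nothing
_unrolled_foreach(::Val{1}, f, itr) = (f(generic_getindex(itr, 1)); nothing)
_unrolled_foreach(::Val{2}, f, itr) =
    (f(generic_getindex(itr, 1)); f(generic_getindex(itr, 2)); nothing)

# Generated fallback: builds one call per element for any other length N.
@generated _unrolled_foreach(::Val{N}, f, itr) where {N} =
    Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing)

# Entry point: lift the length into the type domain and dispatch on it.
unrolled_foreach(f, itr) = _unrolled_foreach(Val(length(itr)), f, itr)

# Usage: a length-2 tuple hits a hard-coded method, a length-5 tuple the fallback.
seen = Int[]
unrolled_foreach(x -> push!(seen, x), (1, 2))
unrolled_foreach(x -> push!(seen, x), (3, 4, 5, 6, 7))
```

Because `Val{0}`, `Val{1}`, and `Val{2}` are more specific than `Val{N}`, dispatch prefers the hard-coded methods whenever they apply, so the generator only runs for lengths without a hand-written method.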
@@ -54,10 +73,8 @@ StaticSequence
 
 To unroll over a statically sized iterator of some user-defined type `T`, follow
 these steps:
-- To enable recursive unrolling, add a method for `iterate(::T, [state])`
-- To enable generative unrolling, add a method for `getindex(::T, n)` (or for
-  `generic_getindex(::T, n)` if `getindex` should not be defined for iterators
-  of type `T`)
+- Add a method for `getindex(::T, n)`, or for `generic_getindex(::T, n)` if
+  `getindex` should not be defined for iterators of type `T`
 - If every unrolled function that needs to construct an iterator when given an
   iterator of type `T` can return a `Tuple` instead, stop here
 - Otherwise, to return a non-`Tuple` iterator whenever it is efficient to do so,
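As a rough illustration of the first step in the revised list, the sketch below defines `getindex` for a hypothetical user-defined type and unrolls a map over it. The `Triple` type, the local `generic_getindex` stand-in, and the simplified `unrolled_map` are all assumptions made for this example; the real package may require more of the iterator than `length` and `getindex`.

```julia
# Hypothetical statically sized iterator, used only for illustration.
struct Triple{A, B, C}
    a::A
    b::B
    c::C
end

Base.length(::Triple) = 3

# Indexed access is what the manually unrolled code relies on.
function Base.getindex(t::Triple, n::Int)
    n == 1 && return t.a
    n == 2 && return t.b
    n == 3 && return t.c
    throw(BoundsError(t, n))
end

# Stand-in for the package's generic_getindex (assumption: forwards to getindex).
generic_getindex(itr, n) = getindex(itr, n)

# Simplified unrolled map that always returns a Tuple.
@generated _unrolled_map(::Val{N}, f, itr) where {N} =
    Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
unrolled_map(f, itr) = _unrolled_map(Val(length(itr)), f, itr)

doubled = unrolled_map(x -> 2x, Triple(1, 2.5, 3))
```

Here the result is a `Tuple`, which corresponds to the "stop here" case in the list: no further methods are needed unless a non-`Tuple` output type is wanted.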

src/UnrolledUtilities.jl

Lines changed: 17 additions & 22 deletions
@@ -53,6 +53,14 @@ include("unrollable_iterator_interface.jl")
 # Analogue of the non-public Base._InitialValue for reduction and accumulation.
 struct NoInit end
 
+@inline empty_reduction_value(::NoInit) =
+    error("unrolled_reduce requires an init value for empty iterators")
+@inline empty_reduction_value(init) = init
+
+@inline first_reduction_value(op, itr, ::NoInit) = generic_getindex(itr, 1)
+@inline first_reduction_value(op, itr, init) =
+    op(init, generic_getindex(itr, 1))
+
 # Analogue of ∘, but with only one function argument and guaranteed inlining.
 # Base's ∘ leads to type instabilities in unit tests on Julia 1.10 and 1.11.
 @inline ⋅(f1::F1, f2::F2) where {F1, F2} = x -> (@inline f1(f2(x)))
@@ -134,17 +142,16 @@ include("StaticBitVector.jl")
     unrolled_drop_into(inferred_output_type(itr), itr, val_N)
 
 ##
-## Functions unrolled using either recursion or generated expressions
+## Functions unrolled using either hard-coded or generated expressions
 ##
 
-include("recursively_unrolled_functions.jl")
-include("generatively_unrolled_functions.jl")
+include("manually_unrolled_functions.jl")
 
 # The unrolled_map function could also be implemented in terms of ntuple, but
 # then it would be subject to the same recursion limit as ntuple. On Julia 1.10,
 # this leads to type instabilities in several unit tests for nested iterators.
 @inline unrolled_map_into_tuple(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr)
+    _unrolled_map(Val(length(itr)), f, itr)
 @inline unrolled_map_into(output_type, f::F, itr) where {F} =
     constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itr))
 @inline unrolled_map(f::F, itr) where {F} =
@@ -154,32 +161,26 @@ include("generatively_unrolled_functions.jl")
 
 @inline unrolled_any(itr) = unrolled_any(identity, itr)
 @inline unrolled_any(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr)
+    _unrolled_any(Val(length(itr)), f, itr)
 
 @inline unrolled_all(itr) = unrolled_all(identity, itr)
 @inline unrolled_all(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr)
+    _unrolled_all(Val(length(itr)), f, itr)
 
 @inline unrolled_foreach(f::F, itr) where {F} =
-    (rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr)
+    _unrolled_foreach(Val(length(itr)), f, itr)
 @inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...))
 
 @inline unrolled_reduce(op::O, itr, init) where {O} =
-    isempty(itr) && init isa NoInit ?
-    error("unrolled_reduce requires an init value for empty iterators") :
-    (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
+    _unrolled_reduce(Val(length(itr)), op, itr, init)
 @inline unrolled_reduce(op::O, itr; init = NoInit()) where {O} =
     unrolled_reduce(op, itr, init)
 
 @inline unrolled_mapreduce(f::F, op::O, itrs...; init = NoInit()) where {F, O} =
     unrolled_reduce(op, unrolled_map(f, itrs...), init)
 
 @inline unrolled_accumulate_into_tuple(op::O, itr, init) where {O} =
-    (rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)(
-        op,
-        itr,
-        init,
-    )
+    _unrolled_accumulate(Val(length(itr)), op, itr, init)
 @inline unrolled_accumulate_into(output_type, op::O, itr, init) where {O} =
     constructor_from_tuple(output_type)(
         unrolled_accumulate_into_tuple(op, itr, init),
@@ -209,13 +210,7 @@ include("generatively_unrolled_functions.jl")
     itr,
     itrs...,
 ) where {F, I, E} =
-    (rec_unroll(itr) ? rec_unrolled_ifelse : gen_unrolled_ifelse)(
-        f,
-        get_if,
-        get_else,
-        itr,
-        itrs...,
-    )
+    _unrolled_ifelse(Val(length(itr)), f, get_if, get_else, itr, itrs...)
 
 ##
 ## Unrolled functions without any analogues in Base
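The new `empty_reduction_value` and `first_reduction_value` helpers move the empty-iterator check out of `unrolled_reduce` and into dispatch on `Val(length(itr))`. Since `manually_unrolled_functions.jl` is not shown in this view, the `_unrolled_reduce` methods below are a guess at how the helpers could be consumed, not the package's actual definitions; `generic_getindex` is again a local stand-in.

```julia
struct NoInit end

empty_reduction_value(::NoInit) =
    error("unrolled_reduce requires an init value for empty iterators")
empty_reduction_value(init) = init

# Stand-in for the package's generic_getindex (assumption: forwards to getindex).
generic_getindex(itr, n) = getindex(itr, n)

first_reduction_value(op, itr, ::NoInit) = generic_getindex(itr, 1)
first_reduction_value(op, itr, init) = op(init, generic_getindex(itr, 1))

# Hypothetical Val-dispatched reduction: the empty case is its own method,
# so no runtime isempty check is needed.
_unrolled_reduce(::Val{0}, op, itr, init) = empty_reduction_value(init)
_unrolled_reduce(::Val{1}, op, itr, init) = first_reduction_value(op, itr, init)
@generated function _unrolled_reduce(::Val{N}, op, itr, init) where {N}
    # Build the left fold op(op(first, itr[2]), itr[3]) ... as one nested expression.
    ex = :(first_reduction_value(op, itr, init))
    for n in 2:N
        ex = :(op($ex, generic_getindex(itr, $n)))
    end
    return ex
end

unrolled_reduce(op, itr; init = NoInit()) =
    _unrolled_reduce(Val(length(itr)), op, itr, init)

total = unrolled_reduce(+, (1, 2, 3, 4))
shifted = unrolled_reduce(-, (1, 2, 3); init = 10)
```

With this layout, calling `unrolled_reduce(+, ())` without `init` still raises the same error message as before, but the error path is selected by method dispatch rather than by a branch in the entry point.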

src/generatively_unrolled_functions.jl

Lines changed: 0 additions & 78 deletions
This file was deleted.
