You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/ode_solve.jl
+14-29Lines changed: 14 additions & 29 deletions
Original file line number
Diff line number
Diff line change
@@ -15,7 +15,7 @@ of the physics-informed neural network which is used as a solver for a standard
15
15
## Positional Arguments
16
16
17
17
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
18
-
`Flux.Chain` will be converted to `Lux` using `Lux.transform`.
18
+
`Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
19
19
* `opt`: The optimizer to train the neural network.
20
20
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
21
21
which thus uses the random initialization provided by the neural network library.
@@ -27,11 +27,10 @@ of the physics-informed neural network which is used as a solver for a standard
27
27
the PDE operators. The reverse mode of the loss function is always
28
28
automatic differentiation (via Zygote), this is only for the derivative
29
29
in the loss function (the derivative with respect to time).
30
-
* `batch`: The batch size to use for the internal quadrature. Defaults to `0`, which
31
-
means the application of the neural network is done at individual time points one
32
-
at a time. `batch>0` means the neural network is applied at a row vector of values
33
-
`t` simultaneously, i.e. it's the batch size for the neural network evaluations.
34
-
This requires a neural network compatible with batched data.
30
+
* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values
31
+
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
32
+
`false` means which means the application of the neural network is done at individual time points one at a time.
33
+
This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand.
35
34
* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
36
35
* `strategy`: The training strategy used to choose the points for the evaluations.
37
36
Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no
@@ -88,8 +87,8 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
(param_estim &&isnothing(additional_loss)) &&throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true)."))
0 commit comments