diff --git a/LSTM.lua b/LSTM.lua index 8367adff..6be280c6 100644 --- a/LSTM.lua +++ b/LSTM.lua @@ -143,7 +143,7 @@ function layer:updateOutput(input) end end - local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H) + local bias_expand = self.bias:view(1, 4 * H):expand(N * T, 4 * H) local Wx = self.weight[{{1, D}}] local Wh = self.weight[{{D + 1, D + H}}] @@ -152,12 +152,11 @@ function layer:updateOutput(input) c:resize(N, T, H):zero() local prev_h, prev_c = h0, c0 self.gates:resize(N, T, 4 * H):zero() + self.gates:view(N * T, 4 * H):addmm(bias_expand, x:view(N * T, D), Wx) for t = 1, T do - local cur_x = x[{{}, t}] local next_h = h[{{}, t}] local next_c = c[{{}, t}] local cur_gates = self.gates[{{}, t}] - cur_gates:addmm(bias_expand, cur_x, Wx) cur_gates:addmm(prev_h, Wh) cur_gates[{{}, {1, 3 * H}}]:sigmoid() cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()