GRUCell

scalation.modeling.autograd.GRUCell
See the GRUCell companion object
class GRUCell(inputSize: Int, hiddenSize: Int)(using ops: AutogradOps)

The GRUCell class supports a gated recurrent unit cell:

    r_t = sigmoid(W_ir * x + b_ir + W_hr * h_{t-1} + b_hr)
    z_t = sigmoid(W_iz * x + b_iz + W_hz * h_{t-1} + b_hz)
    n_t = tanh(W_in * x + b_in + r_t ⊙ (W_hn * h_{t-1} + b_hn))
    h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}

This class defines the parameters and forward computation for a GRU cell.
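The four update equations can be traced with a small plain-Scala sketch. It does not use ScalaTion's Variabl or AutogradOps; the names GruStepSketch, Params, matVec, add, hadamard and step are hypothetical, and the weight shapes (hiddenSize x inputSize for W_i*, hiddenSize x hiddenSize for W_h*) are an assumption based on the usual GRU layout, not something this page states.

```scala
// Plain-Scala sketch of one GRU step; NOT the ScalaTion implementation.
object GruStepSketch:
  type Vec = Vector[Double]
  type Mat = Vector[Vector[Double]]   // one inner Vector per output unit (row)

  def sigmoid(v: Double): Double = 1.0 / (1.0 + math.exp(-v))

  // Matrix-vector product W * v
  def matVec(w: Mat, v: Vec): Vec =
    w.map(row => row.zip(v).map { case (a, b) => a * b }.sum)

  // Element-wise sum and element-wise (Hadamard, ⊙) product
  def add(a: Vec, b: Vec): Vec = a.zip(b).map { case (p, q) => p + q }
  def hadamard(a: Vec, b: Vec): Vec = a.zip(b).map { case (p, q) => p * q }

  // Hypothetical container mirroring the twelve parameter fields of the class.
  // Assumed shapes: wI* is hiddenSize x inputSize, wH* is hiddenSize x hiddenSize.
  final case class Params(
    wIr: Mat, wHr: Mat, bIr: Vec, bHr: Vec,
    wIz: Mat, wHz: Mat, bIz: Vec, bHz: Vec,
    wIn: Mat, wHn: Mat, bIn: Vec, bHn: Vec)

  // Compute h_t from the input x and the previous hidden state hPrev.
  def step(p: Params, x: Vec, hPrev: Vec): Vec =
    // r_t: reset gate
    val r = add(add(matVec(p.wIr, x), p.bIr), add(matVec(p.wHr, hPrev), p.bHr)).map(sigmoid)
    // z_t: update gate
    val z = add(add(matVec(p.wIz, x), p.bIz), add(matVec(p.wHz, hPrev), p.bHz)).map(sigmoid)
    // n_t: candidate state; the reset gate scales the hidden-to-hidden term
    val n = add(add(matVec(p.wIn, x), p.bIn), hadamard(r, add(matVec(p.wHn, hPrev), p.bHn)))
      .map(v => math.tanh(v))
    // h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}
    Vector.tabulate(hPrev.size)(i => (1.0 - z(i)) * n(i) + z(i) * hPrev(i))
```

With every weight and bias set to zero, both gates evaluate to 0.5 and n_t to 0, so the sketch returns half of the previous hidden state. That makes a convenient sanity check for the final interpolation line.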

Value parameters

hiddenSize

number of hidden units

inputSize

number of input features

Attributes

Companion
object
Supertypes
class SeqModule
class BaseModule
class Object
trait Matchable
class Any

Members list

Value members

Concrete methods

override def forward(inputs: IndexedSeq[Variabl]): IndexedSeq[Variabl]

Perform the forward pass for the GRU cell. Computes the next hidden state based on the input and the previous hidden state.

Value parameters

inputs

an indexed sequence containing:
- input: the input tensor at the current time step
- hPrev: the hidden state tensor from the previous time step

Attributes

Returns

an indexed sequence containing the next hidden state tensor

Throws
IllegalArgumentException

if the number of inputs is not exactly 2
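The calling convention (two tensors in, one tensor out, IllegalArgumentException otherwise) can be illustrated with a stand-in function. This is a sketch only: ForwardContractSketch and Tensor are hypothetical names, the body is a placeholder rather than the GRU update, and the use of require for the check is an assumption about how such a contract is commonly written in Scala, not a statement about ScalaTion's source.

```scala
// Stand-in illustrating the input/output contract of forward; NOT ScalaTion code.
object ForwardContractSketch:
  type Tensor = Vector[Double]   // hypothetical substitute for Variabl

  def forward(inputs: IndexedSeq[Tensor]): IndexedSeq[Tensor] =
    // require throws IllegalArgumentException when the condition is false
    require(inputs.size == 2, s"expected 2 inputs (input, hPrev), got ${inputs.size}")
    val hPrev = inputs(1)        // inputs(0) is the input at the current time step
    // Placeholder: a real cell would compute h_t from inputs(0) and hPrev here.
    val hNext = hPrev
    IndexedSeq(hNext)            // a single-element sequence holding the next hidden state
```

A caller would typically feed the single returned element back in as hPrev for the next time step.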

Definition Classes
override def numTrackingStates: Int

Return the number of tracking states for the GRU cell. For GRU, this is always 1.

Attributes

Definition Classes
RNNCellBase
override def parameters: IndexedSeq[Variabl]

Return the parameters of the GRU cell.

Attributes

Returns

an indexed sequence of Variabl objects representing the parameters

Definition Classes
RNNCellBase -> BaseModule

Inherited methods

def apply(inputs: IndexedSeq[Variabl]): IndexedSeq[Variabl]

Alias for forward, allows calling the module as a function: module(xs).

Attributes

Inherited from:
SeqModule
def eval(): Unit

Set the module to evaluation mode (and all submodules recursively).

Attributes

Inherited from:
BaseModule
def gradients: IndexedSeq[TensorD]

Return the gradients of all parameters.

Attributes

Inherited from:
BaseModule
def initialTrackingStates(batchSize: Int): IndexedSeq[Variabl]

Create zero-initialized tracking states for a batch. The batch size determines the shape of each returned tensor: (batchSize, hiddenSize, 1).

Attributes

Inherited from:
RNNCellBase (hidden)
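The stated shape (batchSize, hiddenSize, 1) can be pictured with nested arrays. The sketch below is plain Scala and does not call ScalaTion; TrackingStateSketch, Tensor3 and zeroStates are hypothetical names, and tying the number of returned states to numTrackingStates (1 for GRU) is an assumption.

```scala
// Plain-Scala picture of zero-initialized tracking states; NOT ScalaTion code.
object TrackingStateSketch:
  type Tensor3 = Array[Array[Array[Double]]]   // hypothetical stand-in for a 3-D tensor

  // Returns numTrackingStates zero tensors, each of shape (batchSize, hiddenSize, 1).
  def zeroStates(batchSize: Int, hiddenSize: Int, numTrackingStates: Int = 1): IndexedSeq[Tensor3] =
    // Array.ofDim allocates a zero-filled nested array for Double elements
    IndexedSeq.fill(numTrackingStates)(Array.ofDim[Double](batchSize, hiddenSize, 1))
```

For example, zeroStates(4, 3) yields one tensor with 4 batch entries, each holding 3 hidden units stored as single-element columns.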
def setParameters(newParams: IndexedSeq[Variabl]): Unit

Replace the current parameters with new ones. Useful for weight updates, loading saved models, etc.

Value parameters

newParams

The new parameter list to assign

Attributes

Inherited from:
BaseModule
def train(mode: Boolean = ...): Unit

Set the module to training mode (and all submodules recursively).

Attributes

Inherited from:
BaseModule
def zeroGrad()(using ops: AutogradOps): Unit

Zero out all gradients (in-place).

Attributes

Inherited from:
BaseModule

Concrete fields

val W_hn: Variabl

Weight matrix for the hidden-to-hidden connection in the new gate.

val W_hr: Variabl

Weight matrix for the hidden-to-hidden connection in the reset gate.

val W_hz: Variabl

Weight matrix for the hidden-to-hidden connection in the update gate.

val W_in: Variabl

Weight matrix for the input-to-hidden connection in the new gate.

val W_ir: Variabl

Weight matrix for the input-to-hidden connection in the reset gate.

val W_iz: Variabl

Weight matrix for the input-to-hidden connection in the update gate.

val b_hn: Variabl

Bias for the hidden-to-hidden connection in the new gate.

val b_hr: Variabl

Bias for the hidden-to-hidden connection in the reset gate.

val b_hz: Variabl

Bias for the hidden-to-hidden connection in the update gate.

val b_in: Variabl

Bias for the input-to-hidden connection in the new gate.

val b_ir: Variabl

Bias for the input-to-hidden connection in the reset gate.

val b_iz: Variabl

Bias for the input-to-hidden connection in the update gate.

Inherited fields

var inTrainingMode: Boolean

Flag to control training or evaluation behavior.

Attributes

Inherited from:
BaseModule
lazy val subModules: IndexedSeq[BaseModule]

Automatically detect submodules (other BaseModules) within this module.

Attributes

Inherited from:
BaseModule