Attention Wrapper

attention_wrapper(
  object,
  cell,
  attention_mechanism,
  attention_layer_size = NULL,
  alignment_history = FALSE,
  cell_input_fn = NULL,
  output_attention = TRUE,
  initial_cell_state = NULL,
  name = NULL,
  attention_layer = NULL,
  attention_fn = NULL,
  ...
)

Arguments

object

Model or layer object

cell

An instance of RNNCell.

attention_mechanism

A list of AttentionMechanism instances or a single instance.

attention_layer_size

A list of Python integers or a single Python integer, the depth of the attention (output) layer(s). If `NULL` (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If attention_layer is set, this must be `NULL`. If attention_fn is set, it must be guaranteed that the outputs of `attention_fn` also meet the above requirements.

alignment_history

Python boolean, whether to store the alignment history from all time steps in the final output state (currently stored as a time-major TensorArray on which you must call stack()).

cell_input_fn

(optional) A callable that takes the inputs and the attention and returns the input passed to the cell. The default is: function(inputs, attention) tf$concat(list(inputs, attention), -1L).
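
As a rough illustration of what the default does to shapes, the sketch below mimics the concatenation along the last axis with base R matrices instead of tensors. The names `concat_inputs`, `inputs`, and `attention` are illustrative assumptions and not part of the package; a real `cell_input_fn` would operate on tensors.

```r
# Base R stand-in for the default cell_input_fn: join the step inputs and the
# attention along the last (feature) axis.
# `concat_inputs` is a hypothetical helper, not a tfaddons function.
concat_inputs <- function(inputs, attention) {
  stopifnot(nrow(inputs) == nrow(attention))  # batch sizes must agree
  cbind(inputs, attention)
}

inputs    <- matrix(1, nrow = 4, ncol = 3)  # [batch = 4, input_depth = 3]
attention <- matrix(0, nrow = 4, ncol = 2)  # [batch = 4, attention_depth = 2]

cell_inputs <- concat_inputs(inputs, attention)
dim(cell_inputs)  # 4 5
```

The cell therefore sees an input whose last dimension is the input depth plus the attention depth, which is worth keeping in mind when supplying a custom callable.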

output_attention

Python bool. If TRUE (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If FALSE, the output at each time step is the output of cell. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the attention tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output.

initial_cell_state

The initial state value to use for the cell when the user calls get_initial_state(). Note that if this value is provided now, and the user uses a batch_size argument of get_initial_state which does not match the batch size of initial_cell_state, proper behavior is not guaranteed.

name

Name to use when creating ops.

attention_layer

A list of tf$keras$layers$Layer instances or a single tf$keras$layers$Layer instance taking the context and cell output as inputs to generate attention at each time step. If `NULL` (default), use the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layer_size is set, this must be `NULL`.
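
The constraints that tie attention_mechanism, attention_layer_size, and attention_layer together can be summarised as a small pre-flight check. This is only a sketch in base R: `check_attention_args` is a hypothetical helper, not a tfaddons function, and it assumes that several mechanisms are passed as a plain R list. The strings below stand in for real mechanism and layer objects.

```r
# Hypothetical check for the argument rules described above.
# Not part of tfaddons; placeholder values stand in for real objects.
check_attention_args <- function(attention_mechanism,
                                 attention_layer_size = NULL,
                                 attention_layer = NULL) {
  # attention_layer_size and attention_layer are mutually exclusive.
  if (!is.null(attention_layer_size) && !is.null(attention_layer)) {
    stop("set only one of attention_layer_size and attention_layer")
  }
  # With a list of mechanisms, whichever argument is set must be a list
  # of the same length.
  if (is.list(attention_mechanism)) {
    n <- length(attention_mechanism)
    for (arg in list(attention_layer_size, attention_layer)) {
      if (!is.null(arg) && !(is.list(arg) && length(arg) == n)) {
        stop("with a list of mechanisms, pass a list of the same length")
      }
    }
  }
  invisible(TRUE)
}

# Two mechanisms, two layer sizes: passes the check.
check_attention_args(list("m1", "m2"), attention_layer_size = list(64L, 32L))
```

Calling the helper with both attention_layer_size and attention_layer set, or with a single size for a list of two mechanisms, raises an error.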

attention_fn

An optional callable function that allows users to provide their own customized attention function, which takes input (attention_mechanism, cell_output, attention_state, attention_layer) and outputs (attention, alignments, next_attention_state). If provided, the attention_layer_size should be the size of the outputs of attention_fn.

...

Other keyword arguments to pass

Value

None

Note

If you are using the `decoder_beam_search` with a cell wrapped in `AttentionWrapper`, then you must ensure that:

- The encoder output has been tiled to `beam_width` via `tile_batch` (NOT `tf$tile`).
- The `batch_size` argument passed to the `get_initial_state` method of this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `get_initial_state` above contains a `cell_state` value containing properly tiled final state from the encoder.
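
The reason `tile_batch` is required rather than `tf$tile` is the ordering of the tiled entries: `tile_batch` repeats each batch entry `beam_width` times consecutively, while tiling along the batch axis repeats the whole batch. The sketch below reproduces the two orderings with base R row indexing on a toy matrix. It does not call TensorFlow, and `encoder_output`, `tiled_per_entry`, and `tiled_whole` are illustrative names only.

```r
beam_width <- 2

# Toy encoder output: 3 batch entries (rows), depth 2.
encoder_output <- matrix(c(1, 1,
                           2, 2,
                           3, 3), ncol = 2, byrow = TRUE)
batch_idx <- seq_len(nrow(encoder_output))

# tile_batch-style ordering: each entry repeated beam_width times in a row.
tiled_per_entry <- encoder_output[rep(batch_idx, each = beam_width), , drop = FALSE]
tiled_per_entry[, 1]  # 1 1 2 2 3 3

# Plain tiling along the batch axis: the whole batch repeated.
tiled_whole <- encoder_output[rep(batch_idx, times = beam_width), , drop = FALSE]
tiled_whole[, 1]  # 1 2 3 1 2 3

nrow(tiled_per_entry)  # 6, i.e. true_batch_size * beam_width
```

Both results have `true_batch_size * beam_width` rows, so a shape check alone would not catch the wrong ordering.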