BeamSearch sampling decoder
decoder_beam_search( object, cell, beam_width, embedding_fn = NULL, output_layer = NULL, length_penalty_weight = 0, coverage_penalty_weight = 0, reorder_tensor_arrays = TRUE, ... )
Argument | Description |
---|---|
object | Model or layer object. |
cell | An RNNCell instance. |
beam_width | Integer, the number of beams. |
embedding_fn | A callable that takes a vector tensor of ids (argmax ids). |
output_layer | (Optional) An instance of tf.keras.layers.Layer, e.g., tf$keras$layers$Dense. The layer is applied to the RNN output prior to storing the result or sampling. |
length_penalty_weight | Float weight to penalize length. Disabled with 0.0. |
coverage_penalty_weight | Float weight to penalize the coverage of source sentence. Disabled with 0.0. |
reorder_tensor_arrays | If `TRUE`, the elements of any TensorArrays within the cell state are reordered according to the beam search path. If a TensorArray can be reordered, its stacked form is returned; otherwise, the TensorArray is returned as is. Set this flag to `FALSE` if the cell state contains TensorArrays that are not amenable to reordering. |
... | A list, other keyword arguments for initialization. |
None
If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`, then you must ensure that:

- The encoder output has been tiled to `beam_width` via `tile_batch()` (NOT `tf$tile`).
- The `batch_size` argument passed to the `get_initial_state` method of this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `get_initial_state` above contains a `cell_state` value containing properly tiled final state from the encoder.
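
The difference between `tile_batch()` and a plain tile along the batch axis is the order of the resulting rows. The sketch below only mimics the two orderings with base R on a toy matrix; it does not call `tile_batch()` or `tf$tile`, and the names `enc`, `tile_batch_like`, and `plain_tile_like` are hypothetical, chosen for illustration.

```r
# Toy stand-in for an encoder output: 2 batch entries (rows), 3 features each.
# (Hypothetical data; a real encoder output would be a tensor.)
enc <- matrix(c(1, 1, 1,
                2, 2, 2), nrow = 2, byrow = TRUE)
beam_width <- 3L

# Ordering that tile_batch() produces: each batch entry is repeated
# beam_width times in a row, so the beams of one entry stay adjacent.
tile_batch_like <- enc[rep(seq_len(nrow(enc)), each = beam_width), , drop = FALSE]

# Ordering of a plain tile along the batch axis: the whole batch is
# repeated, so the beams of one entry end up interleaved with other entries.
plain_tile_like <- enc[rep(seq_len(nrow(enc)), times = beam_width), , drop = FALSE]

tile_batch_like[, 1]  # 1 1 1 2 2 2
plain_tile_like[, 1]  # 1 2 1 2 1 2

# The batch size to pass to get_initial_state is the tiled one.
true_batch_size <- nrow(enc)
tiled_batch_size <- true_batch_size * beam_width  # 6
```

Both results have `true_batch_size * beam_width` rows, so a shape check alone will not catch the mistake; only the row order shows whether the beams of each batch entry are grouped together.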