Calculates the full beams for `TensorArray`s by following the beam-search parent ids backward from the final time step.

gather_tree_from_array(t, parent_ids, sequence_length)

Arguments

t

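A stacked `TensorArray` of size `max_time` that contains `Tensor`s of shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`, where `s` is the depth shape. Once stacked, `t` has shape `[max_time, batch_size, beam_width, s]` or `[max_time, batch_size * beam_width, s]`.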

parent_ids

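The beam-search parent ids, of shape `[max_time, batch_size, beam_width]`. The entry at a given time step is the index of the beam at the previous step that each beam extends.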

sequence_length

The length of each beam, of shape `[batch_size, beam_width]`. Time steps at or beyond a beam's length are copied without reordering.

Value

A `Tensor` (a stacked `TensorArray`) with the same size and type as `t`, in which the beams at each time step are reordered according to `parent_ids`.
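The reordering can be illustrated without TensorFlow. The base R sketch below is a hypothetical illustration, not the package's implementation: the function name `gather_tree_sketch` and the toy data are made up, and it assumes `t` has already been stacked into a plain 4-D array `[max_time, batch_size, beam_width, s]` and that `parent_ids` holds valid 0-based beam indices. For each beam it starts at the last step of the longest sequence in the batch entry and follows the parent pointers backward; steps past a beam's own `sequence_length` keep their original values.

```r
# Hypothetical pure-R sketch of the backtracking behind gather_tree_from_array().
# Assumes: t is a 4-D array [max_time, batch_size, beam_width, s];
#          parent_ids is [max_time, batch_size, beam_width] with 0-based beam ids;
#          sequence_length is a [batch_size, beam_width] matrix.
gather_tree_sketch <- function(t, parent_ids, sequence_length) {
  batch_size <- dim(parent_ids)[2]
  beam_width <- dim(parent_ids)[3]
  out <- t  # out-of-range steps keep their original (unreordered) values
  for (b in seq_len(batch_size)) {
    max_len <- max(sequence_length[b, ])
    if (max_len <= 0) next
    for (k in seq_len(beam_width)) {
      beam <- k  # start from beam k at the last step of the batch entry
      for (step in max_len:1) {
        # only steps inside this beam's own length are reordered
        if (step <= sequence_length[b, k]) {
          out[step, b, k, ] <- t[step, b, beam, ]
        }
        # move to the parent beam at the previous step (0-based -> 1-based)
        beam <- parent_ids[step, b, beam] + 1L
      }
    }
  }
  out
}

# Toy example: max_time = 3, batch_size = 1, beam_width = 2, s = 1.
t <- array(0, dim = c(3, 1, 2, 1))
for (s in 1:3) for (k in 1:2) t[s, 1, k, 1] <- s * 10 + k  # value encodes step and beam
parent_ids <- array(0L, dim = c(3, 1, 2))
parent_ids[2, 1, ] <- c(1L, 0L)
parent_ids[3, 1, ] <- c(1L, 1L)
seq_lengths <- matrix(c(3L, 3L), nrow = 1)

res <- gather_tree_sketch(t, parent_ids, seq_lengths)
print(res[, 1, 1, 1])  # 11 22 31
print(res[, 1, 2, 1])  # 11 22 32
```

Both final beams descend from the second beam at step 2, which in turn descends from the first beam at step 1, so after reordering both share the prefix `11, 22`.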