A sampler for use during inference.

sampler_greedy_embedding(embedding_fn = NULL)

Arguments

embedding_fn

An optional callable that takes a vector tensor of ids (the argmax ids), or the params argument for embedding_lookup. The tensor it returns is passed to the decoder as its next input. Defaults to tf$nn$embedding_lookup.

Value

A greedy embedding sampler object.

Details

At each decoding step, takes the argmax of the decoder output (treated as logits) and passes the resulting ids through an embedding layer to obtain the next decoder input.
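
The greedy step described above can be sketched in base R, without TensorFlow. This is only an illustration of the idea, not the package's implementation: the helper greedy_embedding_step, the toy embedding_matrix, and the logits values are all hypothetical. Note that R indices are 1-based, whereas TensorFlow ids are 0-based.

```r
# Hypothetical toy values: a 4-token vocabulary with 3-dimensional embeddings.
embedding_matrix <- matrix(
  c(0.1, 0.2, 0.3,
    0.4, 0.5, 0.6,
    0.7, 0.8, 0.9,
    1.0, 1.1, 1.2),
  nrow = 4, byrow = TRUE
)

# Decoder output for a batch of 2 sequences (rows), one score per token (columns).
logits <- matrix(
  c(0.5, 2.0, -1.0, 0.0,
    1.5, 0.2,  0.3, 3.1),
  nrow = 2, byrow = TRUE
)

# Hypothetical helper (not part of the package): one greedy sampling step.
greedy_embedding_step <- function(logits, embedding_matrix) {
  # Greedy choice: column index of the largest logit in each row (1-based).
  sample_ids <- max.col(logits, ties.method = "first")
  # Embedding lookup: select the embedding row for each chosen id.
  next_inputs <- embedding_matrix[sample_ids, , drop = FALSE]
  list(sample_ids = sample_ids, next_inputs = next_inputs)
}

step <- greedy_embedding_step(logits, embedding_matrix)
step$sample_ids   # 2 4
step$next_inputs  # rows 2 and 4 of embedding_matrix
```

Because the choice is always the argmax, the result is deterministic: the same logits always yield the same ids and therefore the same next input.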