R/seq2seq.R
sampler_scheduled_embedding_training.Rd
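A training sampler that adds scheduled sampling: at each decoding step it either reads the next input directly or feeds back an id sampled from the outputs.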
sampler_scheduled_embedding_training( sampling_probability, embedding_fn = NULL, time_major = FALSE, seed = NULL, scheduling_seed = NULL )
Argument | Description
---|---
sampling_probability | A `float32` 0-D or 1-D tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs.
embedding_fn | A callable that takes a vector tensor of `ids` (argmax ids), or the `params` argument for `embedding_lookup`.
time_major | Logical. Whether the tensors in `inputs` are time major. If `FALSE` (default), they are assumed to be batch major.
seed | The seed used when sampling ids from the outputs.
scheduling_seed | The seed used by the schedule decision rule, i.e. when deciding whether to sample at all.
Returns `-1` in `sample_ids` wherever no sampling took place, and valid sample ids elsewhere.
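To make the `-1` convention concrete, here is a minimal base-R sketch of the per-step decision, with no TensorFlow involved. It assumes one Bernoulli draw per batch element with probability `sampling_probability` (seeded by `scheduling_seed`), a categorical draw from the softmax of the output logits (seeded by `seed`), and 0-based ids. The helper `scheduled_sample_ids()` is hypothetical and is not part of tfaddons; it only illustrates the rule described above.

```r
# Hypothetical helper (not part of tfaddons): a plain base-R sketch of the
# per-step decision a scheduled embedding training sampler makes.
# logits: batch_size x vocab_size matrix of decoder outputs.
scheduled_sample_ids <- function(logits, sampling_probability,
                                 seed = NULL, scheduling_seed = NULL) {
  batch_size <- nrow(logits)

  # Schedule decision: one Bernoulli draw per batch element. A scalar
  # probability is recycled; a length-batch_size vector gives one per row.
  if (!is.null(scheduling_seed)) set.seed(scheduling_seed)
  select_sample <- runif(batch_size) < sampling_probability

  # Categorical draw from softmax(logits) for every row, returned as
  # 0-based ids (assumed to match TensorFlow's indexing).
  if (!is.null(seed)) set.seed(seed)
  sampled <- apply(logits, 1, function(row) {
    p <- exp(row - max(row))  # shift by the max for numerical stability
    sample.int(length(row), 1, prob = p / sum(p)) - 1L
  })

  # -1 marks rows that keep reading the ground-truth input at this step.
  ifelse(select_sample, sampled, -1L)
}

logits <- rbind(c(0, 100, 0), c(100, 0, 0), c(0, 0, 100))
scheduled_sample_ids(logits, sampling_probability = 0.5,
                     seed = 1, scheduling_seed = 2)
```

Note that `set.seed()` resets R's global random number generator, which is a simplification; a real sampler would keep its two random streams separate.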