A training sampler that adds scheduled sampling

sampler_scheduled_embedding_training(
  sampling_probability,
  embedding_fn = NULL,
  time_major = FALSE,
  seed = NULL,
  scheduling_seed = NULL
)

Arguments

sampling_probability

A float32 0-D or 1-D tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs.
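The base-R sketch below shows how `sampling_probability` is used. It assumes a scalar probability and uses a hypothetical helper, `scheduling_decision()`, which is not part of this package. At each decoding step, every batch entry gets an independent Bernoulli draw. Entries that come up `TRUE` take their next input from a sampled output id instead of reading it from the inputs.

```r
# Hypothetical helper (not the package API): one scheduling decision per batch entry.
# TRUE  -> next input comes from a sampled output id
# FALSE -> next input is read directly from the inputs
scheduling_decision <- function(batch_size, sampling_probability,
                                scheduling_seed = NULL) {
  # Seeding resets R's global RNG; this only mimics the role of `scheduling_seed`.
  if (!is.null(scheduling_seed)) set.seed(scheduling_seed)
  # runif() never returns exactly 0 or 1, so p = 0 never samples and p = 1 always does.
  runif(batch_size) < sampling_probability
}

select <- scheduling_decision(batch_size = 8, sampling_probability = 0.25,
                              scheduling_seed = 1)
print(select)
```

Raising `sampling_probability` over the course of training (the "schedule") gradually moves the decoder away from teacher forcing and toward consuming its own predictions.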

embedding_fn

A callable that takes a vector tensor of sampled ids and returns their embeddings, or the `params` argument for `embedding_lookup` (an embedding matrix with one row per id).
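Here is a base-R sketch of the two forms `embedding_fn` can take. It uses a hypothetical 5 x 3 embedding matrix (the `params` form) and a callable built from that matrix. The ids are 0-based, following TensorFlow's convention, so the `+ 1` converts them to R's 1-based row indices. This illustrates the idea only and does not show how the package performs the lookup.

```r
# Hypothetical embedding matrix: one row per vocabulary id (vocab_size x embedding_dim).
vocab_size <- 5
embedding_dim <- 3
embedding_matrix <- matrix(seq_len(vocab_size * embedding_dim),
                           nrow = vocab_size, ncol = embedding_dim)

# A callable in the spirit of `embedding_fn`: maps a vector of 0-based ids
# to the corresponding rows of the matrix (one embedding per id).
embedding_fn <- function(ids) embedding_matrix[ids + 1, , drop = FALSE]

emb <- embedding_fn(c(0, 4, 2))
dim(emb)  # 3 3
```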

time_major

bool. Whether the tensors in `inputs` are time major, with shape `[max_time, batch_size, ...]`. If `FALSE` (default), they are assumed to be batch major, with shape `[batch_size, max_time, ...]`.
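The two layouts differ only in the order of the first two dimensions. The following base-R illustration uses made-up sizes and converts a batch-major array to time-major with `aperm()`:

```r
batch_size <- 2
max_time <- 4
depth <- 3

# Batch-major layout: [batch_size, max_time, depth]
batch_major <- array(seq_len(batch_size * max_time * depth),
                     dim = c(batch_size, max_time, depth))

# Time-major layout: swap the first two dimensions -> [max_time, batch_size, depth]
time_major <- aperm(batch_major, c(2, 1, 3))
dim(time_major)  # 4 2 3
```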

seed

The seed for categorically sampling ids from the outputs.

scheduling_seed

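The seed for the scheduling decision, that is, the per-step draw that decides whether to sample from the outputs or read directly from the inputs.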

Value

Returns `-1` in `sample_ids` wherever no sampling took place, and valid sampled ids elsewhere.
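The following base-R sketch mimics one sampling step and rests on a few assumptions. The helper names (`categorical_sample()`, `sample_step()`, `next_inputs()`) are hypothetical. `sampling_probability` is a scalar. R's global RNG stands in for TensorFlow's seeded ops. Entries that are not selected get id `-1`, and only entries with a valid id have their next input replaced by an embedding.

```r
# Hypothetical sketch of one scheduled-sampling step (not the package API).

# Draw one 0-based id per row of `logits` (batch_size x vocab_size) from softmax(logits).
categorical_sample <- function(logits) {
  apply(logits, 1, function(row) {
    p <- exp(row - max(row))                  # unnormalised softmax; sample.int() normalises
    sample.int(length(row), 1, prob = p) - 1L # convert to a 0-based id
  })
}

sample_step <- function(logits, sampling_probability,
                        seed = NULL, scheduling_seed = NULL) {
  batch_size <- nrow(logits)
  if (!is.null(scheduling_seed)) set.seed(scheduling_seed)
  select <- runif(batch_size) < sampling_probability  # scheduling decision
  if (!is.null(seed)) set.seed(seed)
  sampled <- categorical_sample(logits)               # categorical sampling
  ifelse(select, sampled, -1L)                        # -1 where no sampling took place
}

# Replace ground-truth inputs by embeddings only where a valid id was sampled.
next_inputs <- function(sample_ids, ground_truth, embedding_fn) {
  out <- ground_truth
  hit <- sample_ids >= 0
  if (any(hit)) out[hit, ] <- embedding_fn(sample_ids[hit])
  out
}

embedding_matrix <- diag(3)  # hypothetical 3-word vocabulary with one-hot embeddings
embedding_fn <- function(ids) embedding_matrix[ids + 1, , drop = FALSE]

# Sharply peaked logits so each row's sampled id is effectively deterministic.
logits <- rbind(c(0, 1000, 0), c(1000, 0, 0), c(0, 0, 1000), c(0, 1000, 0))
ids <- sample_step(logits, sampling_probability = 0.5, seed = 42, scheduling_seed = 7)
inputs <- next_inputs(ids, ground_truth = matrix(0, nrow = 4, ncol = 3), embedding_fn)
```

Because `seed` and `scheduling_seed` are separate, the decision of whether to sample and the choice of which id to sample can each be reproduced independently.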