Inference Sampler
sampler_inference( sample_fn, sample_shape, sample_dtype = tf$int32, end_fn, next_inputs_fn = NULL, ... )
Argument | Description |
---|---|
sample_fn | A callable that takes outputs and emits tensor sample_ids. |
sample_shape | Either a list of integers or a 1-D Tensor of type int32 giving the shape of each sample in the batch returned by sample_fn. |
sample_dtype | The dtype of the samples returned by sample_fn. |
end_fn | A callable that takes sample_ids and emits a bool vector shaped [batch_size] indicating whether each sample is an end token. |
next_inputs_fn | (Optional) A callable that takes sample_ids and returns the next batch of inputs. If not provided, sample_ids is used as the next batch of inputs. |
... | Other common arguments for layer creation. |
Value: None
A helper to use during inference with a custom sampling function.
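To show how the three callables relate, here is a minimal sketch in base R. It does not call TensorFlow or tfaddons. `simulate_inference_step`, `greedy_sample` and `is_end` are hypothetical names used only for illustration. The sketch assumes one decoding step proceeds as the argument table describes: sample_fn maps outputs to sample_ids, end_fn flags end tokens, and next_inputs_fn (or sample_ids itself when it is NULL) supplies the next inputs. Ids here are 1-based R column indices. Real TensorFlow sample ids are 0-based.

```r
# Hypothetical base-R simulation of one inference-sampler step.
# Mimics the documented contract only; no TensorFlow involved.
simulate_inference_step <- function(outputs, sample_fn, end_fn,
                                    next_inputs_fn = NULL) {
  # 1. Turn decoder outputs into one sample id per batch entry.
  sample_ids <- sample_fn(outputs)
  # 2. Flag which batch entries emitted an end token.
  finished <- end_fn(sample_ids)
  # 3. Build the next inputs; fall back to sample_ids when
  #    next_inputs_fn is not provided.
  next_inputs <- if (is.null(next_inputs_fn)) sample_ids else next_inputs_fn(sample_ids)
  list(sample_ids = sample_ids, finished = finished, next_inputs = next_inputs)
}

# Greedy sample_fn: index of the highest score in each row (1-based).
greedy_sample <- function(outputs) max.col(outputs, ties.method = "first")

# end_fn: an entry is finished when it samples the (assumed) end token id 3.
end_token <- 3L
is_end <- function(sample_ids) sample_ids == end_token

# A batch of 2 rows over a 4-entry vocabulary.
logits <- rbind(c(0.1, 2.0, 0.3, 0.0),
                c(0.2, 0.1, 5.0, 0.4))

step <- simulate_inference_step(logits, greedy_sample, is_end)
print(step$sample_ids)  # ids 2 and 3
print(step$finished)    # FALSE for row 1, TRUE for row 2
```

With the real function, a greedy sample_fn would more likely be written as something like `function(outputs) tf$argmax(outputs, axis = -1L, output_type = tf$int32)`, which matches the default `sample_dtype = tf$int32`. Treat that as an illustrative assumption rather than the package's prescribed usage.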