MultiHead Attention layer.

layer_multi_head_attention(
  object,
  head_size,
  num_heads,
  output_size = NULL,
  dropout = 0,
  use_projection_bias = TRUE,
  return_attn_coef = FALSE,
  kernel_initializer = "glorot_uniform",
  kernel_regularizer = NULL,
  kernel_constraint = NULL,
  bias_initializer = "zeros",
  bias_regularizer = NULL,
  bias_constraint = NULL,
  ...
)

Arguments

object

Model or layer object

head_size

int, dimensionality of the `query`, `key` and `value` tensors after the linear transformation.

num_heads

int, number of attention heads.

output_size

int, dimensionality of the output space, if `NULL` then the input dimension of `value` or `key` will be used, default `NULL`.

dropout

float, `rate` parameter for the dropout layer that is applied to attention after softmax, default `0`.

use_projection_bias

bool, whether to use a bias term after the linear output projection.

return_attn_coef

bool, if `TRUE`, return the attention coefficients as an additional output argument.

kernel_initializer

initializer, initializer for the kernel weights.

kernel_regularizer

regularizer, regularizer for the kernel weights.

kernel_constraint

constraint, constraint for the kernel weights.

bias_initializer

initializer, initializer for the bias weights.

bias_regularizer

regularizer, regularizer for the bias weights.

bias_constraint

constraint, constraint for the bias weights.

...

additional parameters passed on to the underlying layer.

Value

The output tensor; if `return_attn_coef = TRUE`, the attention coefficients are returned as an additional output.

Details

Defines the multi-head attention operation described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762): it takes `query`, `key` and `value` tensors and returns the dot-product attention between them.
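
For intuition, here is a minimal sketch of the scaled dot-product attention each head computes internally. `scaled_dot_product_attention` is a hypothetical helper written for illustration only; the actual layer additionally applies the learned per-head projections, the attention dropout and the output projection.

library(tensorflow)

# Hypothetical helper (not part of the package): the core attention step.
scaled_dot_product_attention <- function(query, key, value) {
  d_k <- as.numeric(tail(dim(key), 1))                        # key depth (plays the role of head_size)
  logits <- tf$matmul(query, key, transpose_b = TRUE) / sqrt(d_k)
  coef <- tf$nn$softmax(logits, axis = -1L)                   # attention coefficients
  tf$matmul(coef, value)                                      # weighted sum of the values
}

query <- tf$random$uniform(list(2L, 5L, 16L))
key   <- tf$random$uniform(list(2L, 7L, 16L))
value <- tf$random$uniform(list(2L, 7L, 16L))
out <- scaled_dot_product_attention(query, key, value)        # shape (2, 5, 16)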

Examples

if (FALSE) {
mha = layer_multi_head_attention(head_size = 128, num_heads = 128)
query = tf$random$uniform(list(32L, 20L, 200L)) # (batch_size, query_elements, query_depth)
key = tf$random$uniform(list(32L, 15L, 300L))   # (batch_size, key_elements, key_depth)
value = tf$random$uniform(list(32L, 15L, 400L)) # (batch_size, key_elements, value_depth)
attention = mha(list(query, key, value))        # (batch_size, query_elements, value_depth)

# If `value` is not given then internally `value = key` will be used:
mha = layer_multi_head_attention(head_size = 128, num_heads = 128)
query = tf$random$uniform(list(32L, 20L, 200L)) # (batch_size, query_elements, query_depth)
key = tf$random$uniform(list(32L, 15L, 300L))   # (batch_size, key_elements, key_depth)
attention = mha(list(query, key))               # (batch_size, query_elements, value_depth)
}
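
As a further hedged sketch (not taken from the original page): with `return_attn_coef = TRUE` the layer returns the attention coefficients as an additional output, so the call yields two tensors. The coefficient shape shown below is an assumption based on the usual per-head layout.

if (FALSE) {
mha = layer_multi_head_attention(head_size = 128, num_heads = 8,
                                 return_attn_coef = TRUE)
query = tf$random$uniform(list(32L, 20L, 200L)) # (batch_size, query_elements, query_depth)
key = tf$random$uniform(list(32L, 15L, 300L))   # (batch_size, key_elements, key_depth)
res = mha(list(query, key))
attention = res[[1]]  # (batch_size, query_elements, key_depth)
attn_coef = res[[2]]  # assumed (batch_size, num_heads, query_elements, key_elements)
}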