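This is an implementation of the AdamW optimizer described in "Decoupled Weight Decay Regularization" by Loshchilov & Hutter (https://arxiv.org/abs/1711.05101, PDF: https://arxiv.org/pdf/1711.05101.pdf). It computes the update step of tf.keras.optimizers.Adam and additionally decays the variable. Note that this is different from adding L2 regularization on the variables to the loss. With L2 regularization, the decay term passes through Adam's per-parameter normalization, which shrinks it for parameters with large gradients. Decoupled weight decay therefore regularizes variables with large gradients more than L2 regularization would, which the paper above showed yields better training loss and generalization error. When applying a decay schedule to the learning rate, be sure to apply the same decay to weight_decay manually, as in the examples below.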
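To make the difference from L2 regularization concrete, here is a minimal base-R sketch of a single AdamW step on a numeric vector. The function `adamw_step` and the `state` list are hypothetical illustrations, not part of tfaddons. The sketch assumes the variable is first decayed by `weight_decay * theta` (not scaled by the learning rate) and then receives a standard Adam step with `epsilon` in the "epsilon hat" position.

```r
# Hypothetical sketch of one AdamW step on a numeric vector; not the tfaddons API.
# `state` holds the step count t, the moments m and v, and v_hat for AMSGrad.
adamw_step <- function(theta, grad, state, weight_decay, learning_rate = 0.001,
                       beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-07,
                       amsgrad = FALSE) {
  # Decoupled weight decay: shrink the variable directly. The decay is not
  # added to grad, so it never passes through the division by sqrt(v_hat).
  theta <- theta - weight_decay * theta

  # Adam moment estimates computed from the raw gradient.
  t <- state$t + 1
  m <- beta_1 * state$m + (1 - beta_1) * grad
  v <- beta_2 * state$v + (1 - beta_2) * grad^2
  # AMSGrad keeps the running maximum of v; plain Adam uses v itself.
  v_hat <- if (amsgrad) pmax(state$v_hat, v) else v

  # Bias correction folded into the step size; epsilon is "epsilon hat".
  lr_t <- learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)
  theta <- theta - lr_t * m / (sqrt(v_hat) + epsilon)

  list(theta = theta, state = list(t = t, m = m, v = v, v_hat = v_hat))
}

# Usage: one step from a zero state, with one large and one zero gradient.
state <- list(t = 0, m = c(0, 0), v = c(0, 0), v_hat = c(0, 0))
out <- adamw_step(theta = c(1, 1), grad = c(10, 0), state = state,
                  weight_decay = 1e-4)
out$theta  # approximately 0.9989 0.9999
```

Both components are decayed by the same factor regardless of their gradient. Under L2 regularization, `weight_decay * theta` would instead be added to `grad` and then divided by `sqrt(v_hat)`, so the component with the large gradient would receive a much smaller effective decay.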

optimizer_decay_adamw(
  weight_decay,
  learning_rate = 0.001,
  beta_1 = 0.9,
  beta_2 = 0.999,
  epsilon = 1e-07,
  amsgrad = FALSE,
  name = "AdamW",
  clipnorm = NULL,
  clipvalue = NULL,
  decay = NULL,
  lr = NULL
)

Arguments

weight_decay

A Tensor or a floating point value. The weight decay factor: at each step, the variable is reduced by weight_decay times its current value.

learning_rate

A Tensor or a floating point value. The learning rate.

beta_1

A float value or a constant float tensor. The exponential decay rate for the 1st moment estimates.

beta_2

A float value or a constant float tensor. The exponential decay rate for the 2nd moment estimates.

epsilon

A small constant for numerical stability. This epsilon is "epsilon hat" in the Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper.

amsgrad

Boolean. Whether to apply the AMSGrad variant of this algorithm from the paper "On the Convergence of Adam and Beyond".

name

Optional name for the operations created when applying gradients. Defaults to "AdamW".

clipnorm

Float. If set, the gradient of each variable is individually clipped so that its L2 norm is no higher than this value.

clipvalue

Float. If set, each element of the gradient is clipped to lie within [-clipvalue, clipvalue].

decay

Included for backward compatibility; allows time-inverse decay of the learning rate.

lr

Included for backward compatibility; use learning_rate instead.
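The clipnorm, clipvalue and decay arguments each transform a single step before the Adam update is applied. The base-R helpers below are a hypothetical sketch of those transformations (the names `clip_by_norm`, `clip_by_value` and `inverse_time_lr` are illustrative, not tfaddons functions). They assume, as in Keras, that clipnorm acts on each variable's gradient separately and that time-inverse decay divides the learning rate by `1 + decay * iterations`, where `iterations` counts the steps already taken.

```r
# clipnorm: rescale one variable's gradient so its L2 norm is at most clipnorm.
clip_by_norm <- function(grad, clipnorm) {
  norm <- sqrt(sum(grad^2))
  grad * clipnorm / max(norm, clipnorm)
}

# clipvalue: clamp every gradient element to [-clipvalue, clipvalue].
clip_by_value <- function(grad, clipvalue) {
  pmin(pmax(grad, -clipvalue), clipvalue)
}

# decay: time-inverse decay of the learning rate after `iterations` steps.
inverse_time_lr <- function(learning_rate, decay, iterations) {
  learning_rate / (1 + decay * iterations)
}

g <- c(3, 4)                        # L2 norm is 5
clip_by_norm(g, 1)                  # 0.6 0.8
clip_by_value(g, 3.5)               # 3.0 3.5
inverse_time_lr(0.001, 0.01, 100)   # 5e-04
```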

Value

An optimizer for use with `keras::compile()`.

Examples

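if (FALSE) {
library(tensorflow)
library(tfaddons)
step = tf$Variable(0L, trainable = FALSE)
schedule = tf$optimizers$schedules$PiecewiseConstantDecay(
  boundaries = list(10000L, 15000L), values = list(1e-0, 1e-1, 1e-2))
# lr and wd can be a tensor or a function; decay both with the same schedule
lr = 1e-1 * schedule(step)
wd = function() 1e-4 * schedule(step)
optimizer = optimizer_decay_adamw(weight_decay = wd, learning_rate = lr)
}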
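Because the weight decay step is not multiplied by the learning rate, a schedule applied to the learning rate does not automatically reach weight_decay; the example above therefore scales both by the same schedule. The base-R sketch below mirrors that idea without TensorFlow. The helper `piecewise_constant` is a hypothetical stand-in for PiecewiseConstantDecay, assuming it returns `values[1]` while `step <= boundaries[1]`, `values[2]` while `step <= boundaries[2]`, and `values[3]` afterwards.

```r
# Hypothetical base-R stand-in for PiecewiseConstantDecay(boundaries, values).
# findInterval(..., left.open = TRUE) counts the boundaries strictly below step.
piecewise_constant <- function(step, boundaries, values) {
  values[findInterval(step, boundaries, left.open = TRUE) + 1]
}

steps <- c(0, 10000, 10001, 20000)
schedule <- piecewise_constant(steps, boundaries = c(10000, 15000),
                               values = c(1e-0, 1e-1, 1e-2))
lr <- 1e-1 * schedule  # learning rate follows the schedule
wd <- 1e-4 * schedule  # weight decay is scaled by the same factor
```

Here `lr` steps down from 0.1 to 0.01 to 0.001, and `wd` keeps a fixed ratio of 1e-3 to `lr` at every step.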