Intro

First, we need to install the blurr module (the `ohmeow-blurr` Python package), which provides the Transformers integration.

# Install the Python package 'ohmeow-blurr' with pip through reticulate.
reticulate::py_install(packages = 'ohmeow-blurr', pip = TRUE)

Dataset

Get dataset from the link:

library(fastai)
library(magrittr)
library(zeallot)
# Read the SQuAD sample CSV straight from the blurr repository.
squad_df = data.table::fread(
  input = 'https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/squad_sample.csv'
)

And load pretrained BERT for question answering from transformers library:

# Checkpoint name and the model class used to load it.
# NOTE(review): `transformers` is not defined in this snippet; presumably it is
# provided by the fastai package or an earlier setup step -- confirm.
pretrained_model_name = 'bert-large-uncased-whole-word-masking-finetuned-squad'
hf_model_cls = transformers$BertForQuestionAnswering

# get_hf_objects() returns four objects, unpacked here with zeallot's %<-%.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(
  pretrained_model_name,
  model_cls = hf_model_cls
)

# Pre-bind the architecture and tokenizer so `preprocess` only needs a row.
preprocess = partial(
  pre_process_squad(),
  hf_arch = hf_arch,
  hf_tokenizer = hf_tokenizer
)

Preprocess

Prepare dataset for fastai:

# Upper bound on the tokenized input length that is kept for training.
max_seq_len = 128

# Run the preprocessing function over the data and convert the result
# back to a data.table.
squad_df = data.table::as.data.table(py_apply(squad_df, preprocess))

# Print the data as a tibble for inspection.
tibble::tibble(squad_df)

# Flatten columns 8 and 10-12 (addressed by position) into plain vectors.
# NOTE(review): presumably these are list-columns produced by the Python
# preprocessing step -- confirm against the printed column order.
squad_df[, c(8, 10:12)] = lapply(
  squad_df[, c(8, 10:12)],
  function(col) unlist(as.vector(col))
)

# Keep only answerable questions whose tokenized input is shorter than the limit.
squad_df = squad_df[is_impossible == FALSE & tokenized_input_len < max_seq_len]

# Category vocabulary: the integers 1..max_seq_len.
vocab = seq_len(max_seq_len)

Dataloader

Create the dataloader. But first, create the getters (which define how we will pick our columns):

# Choose the truncation strategy from the tokenizer's padding side.
# The condition is a single value, so a plain if/else is used instead of the
# vectorised ifelse(); it also fails loudly if `padding_side` is missing
# rather than silently producing a zero-length strategy.
trunc_strat = if (hf_tokenizer$padding_side == 'right') {
  'only_second'
} else {
  'only_first'
}

# Transform applied before batching: inputs are limited to `max_seq_len`
# tokens with the strategy chosen above, and the tokenizer is asked to also
# return the special-tokens mask.
before_batch_tfm = HF_QABeforeBatchTransform(hf_arch, hf_tokenizer,
                                             max_length = max_seq_len,
                                             truncation = trunc_strat,
                                             tok_kwargs = list('return_special_tokens_mask' = TRUE))

# One text input block followed by two category targets that share the
# same vocabulary.
blocks = list(
  HF_TextBlock(
    before_batch_tfms = before_batch_tfm,
    input_return_type = HF_QuestionAnswerInput
  ),
  CategoryBlock(vocab = vocab),
  CategoryBlock(vocab = vocab)
)

# Input getter: returns the question and the context of a row as a
# two-element list. The question comes first when the tokenizer pads on the
# right; otherwise the context comes first.
get_x = function(x) {
  question_first = hf_tokenizer$padding_side == 'right'
  fields = if (question_first) {
    c('question', 'context')
  } else {
    c('context', 'question')
  }
  list(x[[fields[1]]], x[[fields[2]]])
}

# Targets are read from the 'tok_answer_start' and 'tok_answer_end' columns.
answer_readers = list(ColReader('tok_answer_start'), ColReader('tok_answer_end'))

# Assemble the DataBlock: one input (n_inp = 1), two targets, random split.
dblock = DataBlock(
  blocks = blocks,
  get_x = get_x,
  get_y = answer_readers,
  splitter = RandomSplitter(),
  n_inp = 1
)

# Build the dataloaders with a batch size of 4.
dls = dataloaders(dblock, squad_df, bs = 4)

# Show one batch.
one_batch(dls)
[[1]]
[[1]]$input_ids
tensor([[  101, 20773,  2207,  1996,  2299,  1000,  1000,  4195,  1000,  1000,
          2006,  2029,  3784,  2189,  2326,  1029,   102,  2006,  2337,  1020,
          1010,  2355,  1010,  2028,  2154,  2077,  2014,  2836,  2012,  1996,
          3565,  4605,  1010, 20773,  2207,  1037,  2047,  2309,  7580,  2006,
          2189, 11058,  2326, 15065,  2170,  1000,  1000,  4195,  1000,  1000,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101,  2019,  2590,  4895,  6465,  2165,  2173,  2043,  1029,   102,
          1999,  2325, 20773,  2772,  2019,  2330,  3661,  2029,  1996,  2028,
          3049,  2018,  2042,  9334, 16442,  2005,  1025,  1996,  3661,  2001,
          8280,  2000, 10413, 21442, 11705,  1998, 25930,  8820, 13471,  2050,
         21469, 10631,  3490,  1011, 16950,  2863,  1010, 14328,  2068,  2000,
          3579,  2006,  2308,  2004,  2027,  3710,  2004,  1996,  2132,  1997,
          1996,  1043,  2581,  1999,  2762,  1998,  1996,  8740,  1999,  2148,
          3088,  4414,  1010,  2029,  2097,  2707,  2000,  2275,  1996, 18402,
          1999,  2458,  4804,  2077,  1037,  2364,  4895,  6465,  1999,  2244,
          2325,  2008,  2097,  5323,  2047,  2458,  3289,  2005,  1996,  4245,
          1012,   102],
        [  101, 20773,  2247,  2007,  6108,  1062,  2777,  2007,  3183,  1005,
          1055,  2155,  2044,  2037,  2331,  1029,   102,  2206,  1996,  2331,
          1997, 15528,  3897,  1010, 20773,  1998,  6108,  1011,  1062,  1010,
          2426,  2060,  3862,  4481,  1010,  2777,  2007,  2010,  2155,  1012,
          2044,  1996, 10219,  1997, 13337,  1997,  3897,  1005,  1055,  2331,
          1010, 20773,  1998,  6108,  1011,  1062,  6955,  5190,  1997,  6363,
          2000, 15358,  2068,  2041,  1012,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101,  2129,  2038,  2023,  3192,  2904,  1999,  3522,  2086,  1029,
           102,  2044,  7064, 16864,  1999,  2384,  1010, 20773,  1998, 20539,
          2631,  1996, 12084,  3192,  2000,  3073, 17459,  3847,  2005,  5694,
          1999,  1996,  5395,  2181,  1010,  2000,  2029, 20773,  5201,  2019,
          3988,  1002,  5539,  1010,  2199,  1012,  1996,  3192,  2038,  2144,
          4423,  2000,  2147,  2007,  2060, 15430,  1999,  1996,  2103,  1010,
          1998,  2036,  3024,  4335,  2206,  7064, 25209,  2093,  2086,  2101,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]], device='cuda:0')

[[1]]$token_type_ids
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], device='cuda:0')

[[1]]$special_tokens_mask
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], device='cuda:0')

[[1]]$cls_index
tensor([[0],
        [0],
        [0],
        [0]], device='cuda:0')

[[1]]$p_mask
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1]], device='cuda:0')


[[2]]
TensorCategory([42, 88, 20, 49], device='cuda:0')

[[3]]
TensorCategory([43, 90, 22, 55], device='cuda:0')

Train

Wrap the model and fit:

# Wrap the Hugging Face model so it can be used by the fastai Learner.
model = HF_BaseModelWrapper(hf_model)

# `decouple_wd` is passed as TRUE rather than the reassignable shorthand `T`.
learn = Learner(dls,
                model,
                opt_func=partial(Adam, decouple_wd=TRUE),
                cbs=HF_QstAndAnsModelCallback(),
                splitter=hf_splitter())

# The loss is set after construction; there are two targets per sample.
learn$loss_func=MultiTargetLoss()
learn$create_opt()                # -> will create your layer groups based on your "splitter" function
learn$freeze()

# Train for 4 epochs with the one-cycle policy and a maximum learning rate of 1e-3.
learn %>% fit_one_cycle(4, lr_max=1e-3)

Conclusion

Let's create a small dataset and predict on it with the trained learner:

# Single-row inference example: one question and the context to answer it from.
inf_df = data.frame(
  question = 'When was Star Wars made?',
  context = 'George Lucas created Star Wars in 1977. He directed and produced it.'
)
# Predict an answer for the question/context pair(s) in `inf_df`.
#
# Relies on the globals `dls`, `learn` and `hf_tokenizer`.
#
# inf_df: data frame with `question` and `context` columns.
# Returns the tokens of the first input row found at the positions given by
# the third element of the prediction result.
bert_answer = function(inf_df) {
  test_dl = dls$test_dl(inf_df)
  inp = test_dl$one_batch()[[1]]['input_ids']

  res = learn %>% predict(inf_df)

  # as_array turns a torch tensor into an R array. The positions are computed
  # once; +1 because tensor indices start from 0 but R indices start from 1.
  # NOTE(review): res[[3]] presumably holds the predicted start/end token
  # positions -- confirm against the predict() return value.
  token_positions = sapply(res[[3]], as_array) + 1

  tokens = hf_tokenizer$convert_ids_to_tokens(inp[[1]]$tolist()[[1]],
                                              skip_special_tokens=FALSE)
  tokens[token_positions]
}

Result:

# Print the answer tokens separated by spaces.
inf_df %>% bert_answer() %>% cat()
# in 1977