Intro

First, install the blurr module (the ohmeow-blurr package on PyPI), which integrates Hugging Face Transformers with fastai:

reticulate::py_install('ohmeow-blurr',pip = TRUE)

Binary task

Grab data for binary classification:

Define the task and download the pretrained model objects:

HF_TASKS_AUTO = HF_TASKS_AUTO()
task = HF_TASKS_AUTO$SequenceClassification

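pretrained_model_name = "roberta-base" # alternatives: "distilbert-base-uncased", "bert-base-uncased"
# %<-% unpacks the four objects returned by get_hf_objects() into separate variables
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name, task=task)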
Downloading: 100%|██████████| 481/481 [00:00<00:00, 277kB/s]
Downloading: 100%|██████████| 899k/899k [00:01<00:00, 580kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 471kB/s]
Downloading: 100%|██████████| 501M/501M [03:11<00:00, 2.62MB/s]

Dataloader

Build the DataLoaders with a Hugging Face text block (HF_TextBlock) for the reviews and a CategoryBlock for the labels:

imdb_df = data.table::fread('imdb_sample/texts.csv')

blocks = list(HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer), CategoryBlock())

dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('text'),
                   get_y=ColReader('label'),
                   splitter=ColSplitter(col='is_valid'))

dls = dblock %>% dataloaders(imdb_df, bs=4)
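ColSplitter(col='is_valid') sends each row to the training or validation set according to a logical column. As a rough illustration of that behaviour (not fastai's implementation), the base R sketch below splits a toy data frame the same way; col_split and toy_df are hypothetical names, and the indices are 1-based as usual in R.

```r
# Hypothetical helper mimicking ColSplitter: rows where the flag column is TRUE
# go to the validation set, all other rows go to the training set.
col_split <- function(df, col = "is_valid") {
  flag <- as.logical(df[[col]])   # also accepts "TRUE"/"True"/"FALSE"/"False" strings
  list(train = which(!flag), valid = which(flag))
}

# Toy data, not the IMDB sample.
toy_df <- data.frame(
  text     = c("great film", "dull plot", "loved it", "waste of time"),
  label    = c("positive", "negative", "positive", "negative"),
  is_valid = c(FALSE, FALSE, TRUE, FALSE)
)

idx <- col_split(toy_df)
print(idx$train)  # 1 2 4
print(idx$valid)  # 3
```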
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[    0,  4833,  3009,  ...,  1916,     6,     2],
        [    0,  1876, 13856,  ...,     7,    47,     2],
        [    0,  2647,     6,  ...,     6,    61,     2],
        [    0,    20,  2091,  ...,  5779,    30,     2]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')


[[2]]
TensorCategory([0, 1, 0, 0], device='cuda:0')
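In the batch above, the first element holds the tokenizer output (input_ids and attention_mask, each batch size x sequence length) and the second holds the targets. For roberta-base the <s> token has id 0, </s> has id 2 and padding has id 1, which is why each row starts with 0 and ends with 2; padded positions would get 0 in the attention mask. The base R sketch below shows how right-padding and the mask relate; pad_batch is a hypothetical helper and the token ids are made up.

```r
# Hypothetical helper: right-pad token id sequences to a common length and
# build the matching attention mask (1 = real token, 0 = padding).
pad_batch <- function(seqs, pad_id = 1L) {
  max_len <- max(lengths(seqs))
  ids <- t(vapply(seqs, function(s) c(s, rep(pad_id, max_len - length(s))),
                  integer(max_len)))
  mask <- t(vapply(seqs, function(s) c(rep(1L, length(s)), rep(0L, max_len - length(s))),
                   integer(max_len)))
  list(input_ids = ids, attention_mask = mask)
}

# Made-up ids: 0 = <s>, 2 = </s>, values in between stand for word pieces.
batch <- pad_batch(list(c(0L, 4833L, 3009L, 2L), c(0L, 20L, 2L)))
print(batch$input_ids)
print(batch$attention_mask)
```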

RoBERTa model

Wrap the Hugging Face model for fastai, build the Learner, and freeze the pretrained layers:

model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls,
                model,
                opt_func=partial(Adam, decouple_wd=TRUE),
                loss_func=CrossEntropyLossFlat(),
                metrics=accuracy,
                cbs=HF_BaseModelCallback(),
                splitter=hf_splitter())
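The partial(Adam, decouple_wd=TRUE) argument asks fastai to apply weight decay directly to the weights (the AdamW variant) instead of adding it to the gradient. The base R sketch below shows one such step for a single parameter; it is an illustration, not fastai's code. adamw_step is a hypothetical name, and the defaults (mom = 0.9, sqr_mom = 0.99, eps = 1e-5, wd = 0.01) are assumed to mirror fastai's Adam defaults.

```r
# Hypothetical sketch of one Adam step with decoupled weight decay (AdamW style).
# `state` carries the running averages and the step count between calls.
adamw_step <- function(p, grad, state, lr, mom = 0.9, sqr_mom = 0.99,
                       eps = 1e-5, wd = 0.01) {
  # Decoupled weight decay: shrink the weights directly, independent of the gradient.
  p <- p * (1 - lr * wd)
  state$step     <- state$step + 1
  state$grad_avg <- mom * state$grad_avg + (1 - mom) * grad
  state$sqr_avg  <- sqr_mom * state$sqr_avg + (1 - sqr_mom) * grad^2
  # Bias corrections for the zero-initialised running averages.
  debias1 <- 1 - mom^state$step
  debias2 <- 1 - sqr_mom^state$step
  p <- p - (lr / debias1) * state$grad_avg / (sqrt(state$sqr_avg / debias2) + eps)
  list(p = p, state = state)
}

state <- list(step = 0, grad_avg = 0, sqr_avg = 0)
out <- adamw_step(p = 1, grad = 0.5, state = state, lr = 0.1)
print(out$p)  # roughly 0.899
```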

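learn$create_opt()   # build the optimizer and its parameter groups
learn$freeze()       # train only the last parameter group (the classification head)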

learn %>% summary()
HF_BaseModelWrapper (Input shape: 4 x 512)
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Embedding            4 x 512 x 768        38,603,520 False     
________________________________________________________________
Embedding            4 x 512 x 768        394,752    False     
________________________________________________________________
Embedding            4 x 512 x 768        768        False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 768              590,592    True      
________________________________________________________________
Dropout              4 x 768              0          False     
________________________________________________________________
Linear               4 x 2                1,538      True      
________________________________________________________________

Total params: 124,647,170
Total trainable params: 630,530
Total non-trainable params: 124,016,640

Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fd850db18c8>, decouple_wd=True)
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback
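The parameter counts in the summary follow from roberta-base's shapes (hidden size 768, feed-forward size 3072, 12 layers, a 50,265-token vocabulary and 514 position slots) plus the 2-class head. The base R arithmetic below reproduces the totals. It also shows where the 630,530 trainable parameters come from: the head (a 768 x 768 dense layer and a 768 x 2 output layer) plus the 25 LayerNorm layers, which stay trainable after freeze(), presumably because fastai keeps normalisation layers trainable by default.

```r
# Parameter counts for roberta-base with a 2-class head, derived from layer shapes.
hidden <- 768; ffn <- 3072; n_layers <- 12; vocab <- 50265; max_pos <- 514; n_classes <- 2

linear <- function(n_in, n_out) n_in * n_out + n_out   # weights + bias
layer_norm <- 2 * hidden                                # scale + shift

embeddings <- vocab * hidden + max_pos * hidden + 1 * hidden + layer_norm
per_layer  <- 4 * linear(hidden, hidden) +              # query, key, value, output projection
              linear(hidden, ffn) + linear(ffn, hidden) +
              2 * layer_norm
head       <- linear(hidden, hidden) + linear(hidden, n_classes)

total     <- embeddings + n_layers * per_layer + head
trainable <- head + (1 + 2 * n_layers) * layer_norm     # head + all 25 LayerNorms

format(c(total = total, trainable = trainable), big.mark = ",")
```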

Conclusion

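Train for three epochs with the one-cycle policy (peak learning rate 1e-3), then predict the labels of the first four reviews: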

result = learn %>% fit_one_cycle(3, lr_max=1e-3)
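fit_one_cycle() varies the learning rate during training: it warms up from lr_max / div to lr_max, then anneals down to lr_max / div_final, both along cosine curves. The base R sketch below computes that schedule, assuming fastai's defaults div = 25, div_final = 1e5 and pct_start = 0.25; one_cycle_lr and cos_interp are hypothetical helpers, and fastai also cycles the momentum in the opposite direction, which this sketch omits.

```r
# Cosine interpolation from start to end as pos goes from 0 to 1.
cos_interp <- function(start, end, pos) start + (1 + cos(pi * (1 - pos))) * (end - start) / 2

# Hypothetical sketch of the one-cycle learning-rate schedule.
# pos is the fraction of training completed, in [0, 1].
one_cycle_lr <- function(pos, lr_max, div = 25, div_final = 1e5, pct_start = 0.25) {
  ifelse(pos < pct_start,
         cos_interp(lr_max / div, lr_max, pos / pct_start),                           # warm-up
         cos_interp(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start)))  # annealing
}

lrs <- one_cycle_lr(seq(0, 1, by = 0.25), lr_max = 1e-3)
print(signif(lrs, 3))
```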

learn %>% predict(imdb_df$text[1:4])
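The final Linear layer in the summary outputs two logits per review, and the predicted label is the class with the higher softmax probability. The sketch below turns made-up logits into probabilities and labels in base R; the logits, softmax_rows, and the label order are assumptions for illustration, not output from the model above.

```r
# Made-up logits for two reviews from a 2-unit classification head.
logits <- rbind(c(-1.2, 2.3), c(0.8, -0.4))

softmax_rows <- function(x) {
  z <- exp(x - apply(x, 1, max))   # subtract each row's max for numerical stability
  z / rowSums(z)
}

probs   <- softmax_rows(logits)
classes <- c("negative", "positive")   # assumed label order
pred    <- classes[max.col(probs)]
print(round(probs, 3))
print(pred)
```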