First, install the blurr module, which provides the Transformers integration:
reticulate::py_install('ohmeow-blurr', pip = TRUE)
Grab data for binary classification:
Define task:
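library(fastai)    # Learner, DataBlock, and the HF_* wrappers
library(magrittr)  # the %>% pipe
library(zeallot)   # the %<-% unpacking operator
HF_TASKS_AUTO = HF_TASKS_AUTO()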
task = HF_TASKS_AUTO$SequenceClassification
pretrained_model_name = "roberta-base" # alternatives: "distilbert-base-uncased", "bert-base-uncased"
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name, task=task)
Downloading: 100%|██████████| 481/481 [00:00<00:00, 277kB/s]
Downloading: 100%|██████████| 899k/899k [00:01<00:00, 580kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 471kB/s]
Downloading: 100%|██████████| 501M/501M [03:11<00:00, 2.62MB/s]
Create DataLoaders with Hugging Face data blocks:
imdb_df = data.table::fread('imdb_sample/texts.csv')
blocks = list(HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer), CategoryBlock())
dblock = DataBlock(blocks=blocks,
get_x=ColReader('text'),
get_y=ColReader('label'),
splitter=ColSplitter(col='is_valid'))
dls = dblock %>% dataloaders(imdb_df, bs=4)
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 0, 4833, 3009, ..., 1916, 6, 2],
[ 0, 1876, 13856, ..., 7, 47, 2],
[ 0, 2647, 6, ..., 6, 61, 2],
[ 0, 20, 2091, ..., 5779, 30, 2]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='cuda:0')
[[2]]
TensorCategory([0, 1, 0, 0], device='cuda:0')
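In the batch above, every row of input_ids starts with 0 and ends with 2, and the attention mask is 1 at every position shown. The following base R sketch illustrates how a batch of this shape can be built from sequences of unequal length. It is not blurr's implementation: the token ids are made up, and the special-token ids (0 for <s>, 2 for </s>, 1 for <pad>) are the ones assumed here for the RoBERTa tokenizer.

```r
# Hypothetical token-id sequences of unequal length (values are made up).
# Assumed RoBERTa special tokens: 0 = <s>, 2 = </s>, 1 = <pad>.
seqs <- list(
  c(0, 4833, 3009, 2),
  c(0, 1876, 2),
  c(0, 20, 2091, 5779, 30, 2)
)

pad_id  <- 1
max_len <- max(lengths(seqs))

# Right-pad every sequence to the length of the longest one.
# sapply() returns one column per sequence, so transpose to get one row each.
input_ids <- t(sapply(seqs, function(s) {
  c(s, rep(pad_id, max_len - length(s)))
}))

# The attention mask is 1 for real tokens and 0 for padding.
attention_mask <- t(sapply(seqs, function(s) {
  c(rep(1, length(s)), rep(0, max_len - length(s)))
}))

print(input_ids)
print(attention_mask)
```

Each row of attention_mask sums to the original length of its sequence, which is how the model can tell real tokens from padding.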
Wrap the model and create the Learner:
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
model,
opt_func=partial(Adam, decouple_wd=TRUE),
loss_func=CrossEntropyLossFlat(),
metrics=accuracy,
cbs=HF_BaseModelCallback(),
splitter=hf_splitter())
learn$create_opt()
learn$freeze()
learn %>% summary()
epoch train_loss valid_loss accuracy time
------ ----------- ----------- --------- ------
HF_BaseModelWrapper (Input shape: 4 x 512)
================================================================
Layer (type) Output Shape Param # Trainable
================================================================
Embedding 4 x 512 x 768 38,603,520 False
________________________________________________________________
Embedding 4 x 512 x 768 394,752 False
________________________________________________________________
Embedding 4 x 512 x 768 768 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 768 590,592 True
________________________________________________________________
Dropout 4 x 768 0 False
________________________________________________________________
Linear 4 x 2 1,538 True
________________________________________________________________
Total params: 124,647,170
Total trainable params: 630,530
Total non-trainable params: 124,016,640
Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fd850db18c8>, decouple_wd=True)
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group #2
Callbacks:
- TrainEvalCallback
- Recorder
- ProgressCallback
- HF_BaseModelCallback
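The parameter counts in the summary can be checked by hand. The base R sketch below recomputes them from the layer shapes; the vocabulary size (50265) and the number of position embeddings (514) are not stated in the summary and are inferred here by dividing the embedding parameter counts by the hidden size of 768. The helper names are hypothetical.

```r
# Sizes read off (or inferred from) the summary table above.
hidden      <- 768
ffn         <- 3072
n_layers    <- 12
vocab       <- 50265  # inferred: 38,603,520 / 768
positions   <- 514    # inferred: 394,752 / 768
token_types <- 1
n_classes   <- 2

# Hypothetical helpers: a Linear layer has a weight matrix plus a bias,
# a LayerNorm has one scale and one shift per feature.
linear_params     <- function(n_in, n_out) n_in * n_out + n_out
layer_norm_params <- function(n) 2 * n

embeddings <- (vocab + positions + token_types) * hidden +
  layer_norm_params(hidden)

# Per encoder layer: query, key, value, attention output,
# two feed-forward Linear layers, and two LayerNorms.
encoder_layer <- 4 * linear_params(hidden, hidden) +
  linear_params(hidden, ffn) +
  linear_params(ffn, hidden) +
  2 * layer_norm_params(hidden)

# Classification head: one dense layer and the output projection.
clf_head <- linear_params(hidden, hidden) + linear_params(hidden, n_classes)

total <- embeddings + n_layers * encoder_layer + clf_head

# Rows marked Trainable = True in the frozen model:
# all 25 LayerNorms plus the classification head.
trainable <- (1 + 2 * n_layers) * layer_norm_params(hidden) + clf_head

cat("Total params:        ", format(total, big.mark = ","), "\n")
cat("Trainable params:    ", format(trainable, big.mark = ","), "\n")
cat("Non-trainable params:", format(total - trainable, big.mark = ","), "\n")
```

The three values agree with the totals reported by the summary, which confirms that after learn$freeze() only the LayerNorm layers and the classification head are updated.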
Train and predict:
result = learn %>% fit_one_cycle(3, lr_max=1e-3)
learn %>% predict(imdb_df$text[1:4])
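To give a feel for what fit_one_cycle does with lr_max=1e-3, here is a minimal base R sketch of a one-cycle learning-rate schedule: a cosine warm-up from lr_max / div to lr_max, followed by a cosine decay to lr_max / div_final. The function names are hypothetical, and the values div = 25, div_final = 1e5 and pct_start = 0.25 are assumptions based on the defaults of the Python fastai library, not something shown in this document.

```r
# Cosine interpolation from `start` to `end` as pos goes from 0 to 1.
sched_cos <- function(start, end, pos) {
  start + (1 + cos(pi * (1 - pos))) * (end - start) / 2
}

# Hypothetical one-cycle schedule; `pct` is the fraction of training done.
# div, div_final and pct_start are assumed defaults (see the text above).
one_cycle_lr <- function(pct, lr_max = 1e-3, div = 25,
                         div_final = 1e5, pct_start = 0.25) {
  if (pct < pct_start) {
    # Warm-up phase: lr_max / div -> lr_max
    sched_cos(lr_max / div, lr_max, pct / pct_start)
  } else {
    # Annealing phase: lr_max -> lr_max / div_final
    sched_cos(lr_max, lr_max / div_final,
              (pct - pct_start) / (1 - pct_start))
  }
}

progress <- seq(0, 1, by = 0.125)
lrs <- sapply(progress, one_cycle_lr)
print(data.frame(progress = progress, lr = signif(lrs, 3)))
```

Under these assumptions the learning rate starts at 4e-5, peaks at 1e-3 a quarter of the way through training, and decays to roughly 1e-8 by the end.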