Intro

First, install the blurr module, which provides the integration with Transformers.

# Install the blurr Python package with pip through reticulate.
reticulate::py_install("ohmeow-blurr", pip = TRUE)

Dataset

Download the dataset from the link below:

library(fastai)
library(magrittr)
library(zeallot)

# Read the sample CSV directly from its URL.
cnndm_df <- data.table::fread(
  "https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/cnndm_sample.csv"
)

Pretrained model

Get the pretrained BART model from transformers:

# Keep the result of transformers() and read the BART model class from it.
transformers <- transformers()
BartForConditionalGeneration <- transformers$BartForConditionalGeneration

# Name of the pretrained checkpoint to load.
pretrained_model_name <- "facebook/bart-large-cnn"

# get_hf_objects() yields four values; %<-% unpacks them in order.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(
  pretrained_model_name,
  model_cls = BartForConditionalGeneration
)

Datablock

Create a DataBlock and inspect one batch:

# Before-batch transform built from the architecture and tokenizer;
# max_length is given as the pair c(256, 130).
before_batch_tfm <- HF_SummarizationBeforeBatchTransform(
  hf_arch,
  hf_tokenizer,
  max_length = c(256, 130)
)

# Two blocks: the text-to-text block carrying the transform, then noop().
blocks <- list(
  HF_Text2TextBlock(
    before_batch_tfms = before_batch_tfm,
    input_return_type = HF_SummarizationInput
  ),
  noop()
)

# x is read from the "article" column, y from the "highlights" column.
dblock <- DataBlock(
  blocks = blocks,
  get_x = ColReader("article"),
  get_y = ColReader("highlights"),
  splitter = RandomSplitter()
)

# Build the dataloaders from the data frame with a batch size of 2.
dls <- dblock %>% dataloaders(cnndm_df, bs = 2)

# Print a single batch.
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[    0,    36, 16256,    43,   480,  2193,     7,    62,     7,   158,
           135,     9,    70,   684,  4707,     6,  1625,    16,  4984,    25,
            65,     9,     5,   144, 39865, 32273,  3806,    15,     5,  5518,
             4,    20,  9544,  3455,     9,  2147,   464,     8,  1050, 29779,
         26284,    15,  1632, 11534,    32,     6,   959,     6,  5608,     5,
          8066,     9,     5,   247,    18,  4066,  7892,     4,   178,    89,
            16,    10,   372,   432,     7,  2217,     4,    96,     5,   315,
          3076,  9356,  4928,    36,  4154,  9662,    43,   623, 12978, 23853,
          2521,    18,   889,     9, 10721,   625, 32273,   749,  1625,  5546,
           365,   212,     4,    20,   889,  3372,    10,   333,     9,   601,
           749,    14, 19655,     5,  1647,     9,     5,  3875,    18,  4707,
             8,    32,  3891,  1687,  2778, 39865, 32273,     4,  1740,    63,
         23491, 28919,    11,     5,  7599,  3939,     7,    63, 10602, 41166,
          1634,    11, 12718,  1115,   281,     8,     5,   854,  3964, 21785,
         14113,     8,    63, 38905,     8, 21950, 19947,    11,     5,  1926,
             6,  1625, 10902,    41,  5784,  4066,  3143,     9, 39456,     8,
           856, 25729,     4,   993,   195,  5243,    66,     9,   262,  1360,
         38950,  1848,  4707,   303,    11,  1625,   480,     5,   144,    11,
           143,   247,   480,    64,   129,    28, 13590,   624,    63,  7562,
             4,    85,    16,   184,     7, 39179,  3505,     9, 27772,     6,
         28482,  4707,     9,  7723,     6,   112,     6,  6115, 17576,     9,
          7723,     8,   973,     6,   151,  1380, 14868,     9,  3451,     4,
           221,  2839,  8367,   102,     6,    10,   786,    12,  7699,  1651,
            14,  1364,     7,  3720,  8360,     8,  5068,   709,    11,  1625,
             6,    34,  3919,   411,  4707,    61,    24,   161,  7648,  2072,
             5,  1272,  2713,    30,     5,     2],
        [    0,    36, 16256,    43,   480,  6871,    24,  1655, 22049,    50,
         41571,  1896,    54, 24509,    13,   244,     5,   363,     5,   601,
            12,   180,    12,   279,  1896,    21,   738,  1462,   116,   280,
           115,  6723,    15,    61,   985,     5,  3940,  2046,     4,  1868,
         22049,    18,     8,  1896,    18,  8826,  2327,   117, 28946,   273,
            11,  2559,   461,  4961,    25,     7,  1060, 28604,  2236,    16,
          1317, 11347,   148,    10,  7158,   486,    31,    14,   902,   973,
             6,  1125,     6,   363,    11, 23058,     6,  1261,    35,  4028,
            26,    24,    21,    69,   979,     4,   280, 34920,   480,    19,
          5767,  3809,  1243, 21605, 13875,    24,    21,    69,   979,     6,
         41571,     6,    54, 16670,    66,     6,   150, 15151,  2459, 22049,
            26,    24,    21,    69,   979,     6,  1655,     6,    54,    21,
         16600,    71,   145,  4487,    30,     5,  6066,   480,    21,  1353,
             7,   273,    18,   461,  7069,     6,     8,  1353,     7,     5,
           200,    12,  5743,  1900,   403, 25451,    11,  1353,  1261,     4,
         22049,    34,  4407,    45,  2181,     8,  1695,    37,   738,     5,
          7044,    11,  1403,    12, 21409,     4,    20,  7158,   486,   702,
          2330,    11,   461,    15,   273,     6,    39,  3969,  2026,     6,
           124,    62,    49, 19395,    14,    24,    21,  1896,     6,     8,
            45,    49,  3653,     6,    54,    21,     5, 29593,   368,     4,
          4500,  4945,   628,   273,  1390,     6, 15151,  2459, 22049,    26,
            79,    21,   686,  1655,    21,     5,    65, 16600,     4,  2612,
           116, 41039, 10105,    37,    18,   127,   979, 39732,   264,  7173,
         41039,  1250,     9,     5,  1065, 48149,    77,   553,   549,    79,
            56,   655,   137,  1317,    69,  1655, 22049,  7923, 24120,    50,
          8930,    66,    13,   244,     4,     2]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')

[[1]]$decoder_input_ids
tensor([[    0,  1625,  4452,     7,    62,     7,   158,   135,     9,    70,
           684,  4707,    15,  3875,   479, 50118,   243,    16,   184,     7,
         39179,  3505,     9, 27772,     6, 28482,  5103,  4707,     8,   973,
             6,   151,  3505,     9,  3451,   479, 50118, 33837,   709,     8,
          2147,   464,    16,  9405,    10,   380,  8793,    15,    63, 28809,
           479, 50118,   133,  3274,  9587,    16,   223,  1856,    11, 14117,
             9,   145,     5,   247,    18,   632,  7648,   479,     2,     1,
             1,     1,     1,     1,     1,     1],
        [    0,  5178,    35,    83,  1443,  2470,   161,    97,  1283,     6,
            45,     5,  7158,   486,     6,    40,  3094,     5,   403,   479,
         50118,  5341,    35,    83,  2470,    13,  1896,    18,   284,   161,
            37,  4265,     5,  3940,    40,   465, 22049,  2181,   479, 50118,
           534, 15356,  2459, 22049,   161,    79,  2215,     5, 28604,  2236,
            16,    14,     9,    69,   979,   479, 50118, 30541,     6, 41571,
          1896,    18, 18631,  1843,    26,    14,    24,    69,   979,    18,
          2236,    15,     5,  7158,   486,   479]], device='cuda:0')

[[1]]$labels
tensor([[ 1625,  4452,     7,    62,     7,   158,   135,     9,    70,   684,
          4707,    15,  3875,   479, 50118,   243,    16,   184,     7, 39179,
          3505,     9, 27772,     6, 28482,  5103,  4707,     8,   973,     6,
           151,  3505,     9,  3451,   479, 50118, 33837,   709,     8,  2147,
           464,    16,  9405,    10,   380,  8793,    15,    63, 28809,   479,
         50118,   133,  3274,  9587,    16,   223,  1856,    11, 14117,     9,
           145,     5,   247,    18,   632,  7648,   479,     2,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100],
        [ 5178,    35,    83,  1443,  2470,   161,    97,  1283,     6,    45,
             5,  7158,   486,     6,    40,  3094,     5,   403,   479, 50118,
          5341,    35,    83,  2470,    13,  1896,    18,   284,   161,    37,
          4265,     5,  3940,    40,   465, 22049,  2181,   479, 50118,   534,
         15356,  2459, 22049,   161,    79,  2215,     5, 28604,  2236,    16,
            14,     9,    69,   979,   479, 50118, 30541,     6, 41571,  1896,
            18, 18631,  1843,    26,    14,    24,    69,   979,    18,  2236,
            15,     5,  7158,   486,   479,     2]], device='cuda:0')


[[2]]
tensor([[ 1625,  4452,     7,    62,     7,   158,   135,     9,    70,   684,
          4707,    15,  3875,   479, 50118,   243,    16,   184,     7, 39179,
          3505,     9, 27772,     6, 28482,  5103,  4707,     8,   973,     6,
           151,  3505,     9,  3451,   479, 50118, 33837,   709,     8,  2147,
           464,    16,  9405,    10,   380,  8793,    15,    63, 28809,   479,
         50118,   133,  3274,  9587,    16,   223,  1856,    11, 14117,     9,
           145,     5,   247,    18,   632,  7648,   479,     2,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100],
        [ 5178,    35,    83,  1443,  2470,   161,    97,  1283,     6,    45,
             5,  7158,   486,     6,    40,  3094,     5,   403,   479, 50118,
          5341,    35,    83,  2470,    13,  1896,    18,   284,   161,    37,
          4265,     5,  3940,    40,   465, 22049,  2181,   479, 50118,   534,
         15356,  2459, 22049,   161,    79,  2215,     5, 28604,  2236,    16,
            14,     9,    69,   979,   479, 50118, 30541,     6, 41571,  1896,
            18, 18631,  1843,    26,    14,    24,    69,   979,    18,  2236,
            15,     5,  7158,   486,   479,     2]], device='cuda:0')

Train

Define a Learner object and train the model:

# Start from the "summarization" entry of the config's task-specific
# parameters, then override the generated length limits.
text_gen_kwargs <- hf_config$task_specific_params["summarization"][[1]]
text_gen_kwargs["max_length"] <- 130L
text_gen_kwargs["min_length"] <- 30L

# Print the resulting generation arguments.
text_gen_kwargs

model <- HF_BaseModelWrapper(hf_model)
model_cb <- HF_SummarizationModelCallback(text_gen_kwargs = text_gen_kwargs)

# Alternative loss left disabled by the author: HF_PreCalculatedLoss()
# Mixed-precision variants left disabled: .to_native_fp16() / .to_fp16()
learn <- Learner(
  dls,
  model,
  opt_func = partial(Adam),
  loss_func = CrossEntropyLossFlat(),
  cbs = model_cb,
  splitter = partial(summarization_splitter, arch = hf_arch)
)

# Create the optimizer, then freeze before fitting.
learn$create_opt()
learn$freeze()

# One epoch of one-cycle training with a maximum learning rate of 4e-5.
learn %>% fit_one_cycle(1, lr_max = 4e-5)

Conclusion

Generate summaries for a new text:

test_article = c([1448 chars quoted with '"'])
# Request three returned sequences for the article and print them.
outputs <- learn$blurr_summarize(test_article, num_return_sequences = 3L)
cat(outputs)

 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, police chief says .


 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, an officer says.


 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .