⚡️+🤗: Transformers + PyTorch Lightning Best Practices

Background

PyTorch Lightning is a high-level wrapper around PyTorch, much like Keras is to TensorFlow. We want to use this excellent framework for fast training, development, and testing of Hugging Face models.

Dataset

For any dataset, split each example into two supervised-training fields: src (the source text) and tgt (the target text).
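For example, a toy DataFrame with the src/tgt columns expected by the DataModule below might look like this (the rows are purely illustrative):

import pandas as pd

# Illustrative rows only: any parallel (source, target) text pairs will do.
train_df = pd.DataFrame({
    "src": ["summarize: the quick brown fox jumps over the lazy dog"],
    "tgt": ["a fox jumps over a dog"],
})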

# Shared imports for the snippets in this post (the Trainer flags used below
# target pytorch_lightning 1.x).
import os
from abc import ABC
from pathlib import Path
from typing import Any, Dict, Optional, Union

import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_logger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, RichProgressBar
from pytorch_lightning.strategies import DDPStrategy
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer, T5Config, T5ForConditionalGeneration


# Despite the name, this is a plain torch Dataset (not a LightningDataModule).
class DataModule(Dataset):
    def __init__(  
        self,  
        data: pd.DataFrame,  
        tokenizer: PreTrainedTokenizer,  
        source_max_token_len: int,  
        target_max_token_len: int,  
    ) -> None:  
        """ Pytorch Dataset Module for input data  
        Args:            data (pd.DataFrame): Dataframe containing input data            tokenizer (PreTrainedTokenizer): Tokenizer for encoding input data            source_max_token_len (int): Max token length for source text            target_max_token_len (int): Max token length for target text        """        self.data = data  
        self.tokenizer = tokenizer  
        self.source_max_token_len = source_max_token_len  
        self.target_max_token_len = target_max_token_len  
  
    def __len__(self):  
        return len(self.data)  
  
    def __getitem__(self, index):  
        data_row = self.data.iloc[index]  
  
        src_text_encoding = self.tokenizer(  
            data_row["src"],  
            max_length=self.source_max_token_len,  
            padding="max_length",  
            truncation=True,  
            return_tensors="pt",  
            return_attention_mask=True,  
            add_special_tokens=True  
        )  
        tgt_text_encoding = self.tokenizer(  
            data_row['tgt'],  
            max_length=self.target_max_token_len,  
            padding="max_length",  
            truncation=True,  
            return_attention_mask=True,  
            return_tensors="pt",  
            add_special_tokens=True  
        )  
        labels = tgt_text_encoding["input_ids"]  
        labels[labels == 0] = -100  
  
        return dict(  
            source_text_input_ids=src_text_encoding["input_ids"].flatten(),  
            source_text_attention_mask=src_text_encoding["attention_mask"].flatten(),  
            labels=labels.flatten(),  
            labels_attention_mask=tgt_text_encoding["attention_mask"].flatten(),  
        )

Next, build the LightningDataModule that combines the training data and validation data:

class DatasetModule(pl.LightningDataModule):  
    def __init__(  
        self,  
        train_df: pd.DataFrame,  
        valid_df: pd.DataFrame,  
        tokenizer: PreTrainedTokenizer,  
        batch_size: int = 4,  
        source_max_token_len: int = 512,  
        target_max_token_len: int = 512,  
        num_workers: int = 4,  
        shuffle: bool = True,  
    ) -> None:  
        super().__init__()  
        self.tokenizer = tokenizer  
        self.train_df = train_df  
        self.valid_df = valid_df  
        self.batch_size = batch_size  
        self.source_max_token_len = source_max_token_len  
        self.target_max_token_len = target_max_token_len  
        self.num_workers = num_workers  
        self.shuffle = shuffle  
  
    def setup(self, stage=None):  
        self.train_dataset = DataModule(  
            self.train_df,  
            self.tokenizer,  
            self.source_max_token_len,  
            self.target_max_token_len  
        )  
  
        self.valid_dataset = DataModule(  
            self.valid_df,  
            self.tokenizer,  
            self.source_max_token_len,  
            self.target_max_token_len  
        )  
  
    def train_dataloader(self):  
        return DataLoader(  
            self.train_dataset,  
            batch_size=self.batch_size,  
            shuffle=self.shuffle,  
            num_workers=self.num_workers,  
            pin_memory=True,  
        )  
  
    def val_dataloader(self):  
        return DataLoader(  
            self.valid_dataset,  
            batch_size=self.batch_size,  
            shuffle=False,  
            num_workers=self.num_workers,  
            pin_memory=True,  
        )  
  
    def test_dataloader(self):  
        return DataLoader(  
            self.valid_dataset,  
            batch_size=self.batch_size,  
            shuffle=False,  
            num_workers=self.num_workers,  
            pin_memory=True  
        )
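As a quick sanity check, the data module can be exercised on its own. A minimal sketch, reusing the toy train_df from above as both the training and validation split and assuming a T5 tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
dm = DatasetModule(train_df=train_df, valid_df=train_df, tokenizer=tokenizer,
                   batch_size=1, num_workers=0)
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["source_text_input_ids"].shape)  # torch.Size([1, 512])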

Model

We use T5 as an example:

# AutoEncoderOutput is a project-specific output dataclass (defined elsewhere);
# as used here it simply carries hidden_states and attention_mask.
class T5Seq2SeqLM(T5ForConditionalGeneration, ABC):
    def __init__(self, config: T5Config, *args, **kwargs):  
        super().__init__(config)  
  
    def hidden_state_sample(self, hidden_states, attention_mask) -> AutoEncoderOutput:  
        """ Sample from the hidden state distribution  
        """        return AutoEncoderOutput(  
            hidden_states=hidden_states,  
            attention_mask=attention_mask,  
        )  
  
    def compute_loss(self, lm_logits, labels, ae_output: AutoEncoderOutput):  
        """ Compute the loss  
        Args:            lm_logits (torch.FloatTensor): logits from the language model            labels (torch.LongTensor): labels for the language model            ae_output (AutoEncoderOutput): output from the autoencoder        Returns:            torch.FloatTensor: loss        """        loss = None  
        if labels is not None:  
            loss_fct = CrossEntropyLoss(ignore_index=-100)  
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))  
        return loss

    def prepare_inputs_for_generation(  
            self,  
            input_ids,  
            past=None,  
            attention_mask=None,  
            head_mask=None,  
            decoder_head_mask=None,  
            # cross_attn_head_mask=None,  
            use_cache=None,  
            encoder_outputs=None,  
            **kwargs  
    ):  
        if past is not None:  
            input_ids = input_ids[:, -1:]  
        return {  
            "decoder_input_ids": input_ids,  
            "past_key_values": past,  
            "encoder_outputs": encoder_outputs,  
            "attention_mask": attention_mask,  
            "head_mask": head_mask,  
            "decoder_head_mask": decoder_head_mask,  
            # "cross_attn_head_mask": cross_attn_head_mask,  
            "use_cache": use_cache,  
        }
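Because T5Seq2SeqLM inherits from T5ForConditionalGeneration, it can be instantiated straight from pretrained weights; a minimal sketch using the public t5-small checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5Seq2SeqLM.from_pretrained("t5-small")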

Training

First, we build a basic HFModel wrapper for PyTorch Lightning to drive:

class HFModel(pl.LightningModule):  
    def __init__(  
            self,  
            tokenizer,  
            model,  
            config: Dict,  
    ) -> None:  
        super().__init__()  
        self.model = model  
        self.tokenizer = tokenizer  
        self.average_training_loss = None  
        self.average_validation_loss = None  
        self.config = config  
  
    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):  
        output = self.model(  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            decoder_attention_mask=decoder_attention_mask,  
            labels=labels,  
        )  
        return output.loss, output.logits  
  
    def compute_loss(self, batch, batch_idx):
        input_ids = batch['source_text_input_ids']  
        attention_mask = batch['source_text_attention_mask']  
        labels_attention_mask = batch['labels_attention_mask']  
        labels = batch['labels']  
  
        loss, logits = self(  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            labels=labels,  
            decoder_attention_mask=labels_attention_mask  
        )  
        return loss  
  
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True)  
        return loss  
  
    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch, batch_idx)
        self.log("val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True)  
        return loss  
  
    def test_step(self, batch, batch_idx):
        loss = self.compute_loss(batch, batch_idx)
        self.log("test_loss", loss, prog_bar=True, logger=True, sync_dist=True)  
        return loss  
  
    def configure_optimizers(self):  
        optimizer = AdamW(  
            self.parameters(),  
            lr=self.config['train']['learning_rate']  
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=1, verbose=True  
        )  
        return {  
            "optimizer": optimizer,  
            "lr_scheduler": {  
                "scheduler": scheduler,  
                "monitor": "val_loss"  
            }  
        }
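HFModel itself only reads config['train']['learning_rate'] (in configure_optimizers); the remaining config keys are consumed by the train() function below. A minimal sketch, reusing the tokenizer and model from the previous snippet, with a placeholder learning rate:

hf_model = HFModel(
    tokenizer=tokenizer,
    model=model,
    config={"train": {"learning_rate": 1e-4}},
)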

The biggest difference with transformers is its checkpointing: models must be saved with save_pretrained and reloaded with from_pretrained, so we design a dedicated checkpointing hook, HFTrainer:

class HFTrainer(pl.Trainer):  
    def save_checkpoint(self, filepath: Union[str, Path], weights_only: bool = False, storage_options: Optional[Any] = None) -> None:  
        if self.is_global_zero:
            # Strip the ".ckpt" extension and the "=" characters from the
            # ModelCheckpoint filename so the path can be used as a directory.
            dpath = os.path.splitext(filepath)[0].replace("=", '')
            # Under the DDP strategy, self.model is the DDP-wrapped module and
            # its inner .module.module is the HFModel LightningModule.
            lightning_model = self.model.module.module
            lightning_model.model.save_pretrained(dpath)
            lightning_model.tokenizer.save_pretrained(dpath)
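Each saved checkpoint directory can then be reloaded with the standard Hugging Face API. A sketch, where the directory name is an example produced by the ModelCheckpoint filename pattern used below:

from transformers import AutoTokenizer, T5ForConditionalGeneration

# Example path only; it follows the "{epoch}-{step}-{val_loss}-{train_loss}"
# pattern (with "=" stripped) from the ModelCheckpoint configured in train().
ckpt_dir = "checkpoints/t5-small/epoch0-step500-val_loss0.85-train_loss0.90"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = T5ForConditionalGeneration.from_pretrained(ckpt_dir)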

Finally, we wire up the dataset construction and the training routine:

def train(  
        model,  
        tokenizer,  
        train_df: pd.DataFrame,  
        eval_df: pd.DataFrame,  
        config: Dict = None,  
):  
    hf_model = HFModel(  
        tokenizer=tokenizer,  
        model=model,  
        config=config  
    )  
    dataset = DatasetModule(  
        train_df=train_df,  
        valid_df=eval_df,  
        tokenizer=tokenizer,  
        batch_size=config['train']['batch_size'],  
        source_max_token_len=config['dataset']['source_max_token_len'],  
        target_max_token_len=config['dataset']['target_max_token_len'],  
        num_workers=config['train']['dataloader_num_workers'],  
    )  
  
    callbacks = [  
        RichProgressBar(refresh_rate=5),  
        EarlyStopping(  
            monitor="val_loss",  
            patience=config['train']['early_stop_patience'],  
            min_delta=0.00,  
            verbose=True,  
            mode='min',  
        )]  
    checkpoint_callback = ModelCheckpoint(  
        dirpath=os.path.join(  
            config['train']['save_dir'],  
            config['model']['name']  
        ),  
        filename="{epoch}-{step}-{val_loss:.2f}-{train_loss:.2f}",  
        monitor='val_loss',  
    )  
    callbacks.append(checkpoint_callback)  
  
    trainer = HFTrainer(  
        logger=pl_logger.TensorBoardLogger(save_dir=config['train']['log_dir']),  
        callbacks=callbacks,  
        max_epochs=config['train']['max_epoch'],  
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',  
        devices=-1,  
        precision=16,  
        auto_select_gpus=True,  
        log_every_n_steps=1,  
        strategy=DDPStrategy(find_unused_parameters=False),  
        accumulate_grad_batches=config['train']['accumulate_grad_batches'],  
        auto_lr_find=True,  
    )  
    trainer.fit(hf_model, dataset)
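Putting it all together, a sketch of an end-to-end run; the file paths, model name, and all config values below are placeholders to adapt to your own setup:

import pandas as pd
from transformers import AutoTokenizer

config = {
    "model": {"name": "t5-small"},
    "dataset": {"source_max_token_len": 512, "target_max_token_len": 512},
    "train": {
        "learning_rate": 1e-4,
        "batch_size": 4,
        "dataloader_num_workers": 4,
        "early_stop_patience": 3,
        "max_epoch": 10,
        "accumulate_grad_batches": 1,
        "save_dir": "checkpoints",
        "log_dir": "logs",
    },
}

tokenizer = AutoTokenizer.from_pretrained(config["model"]["name"])
model = T5Seq2SeqLM.from_pretrained(config["model"]["name"])

train_df = pd.read_csv("train.csv")  # must contain "src" and "tgt" columns
eval_df = pd.read_csv("valid.csv")

train(model=model, tokenizer=tokenizer, train_df=train_df, eval_df=eval_df, config=config)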