指点成金-最美分享吧

登录

transformers自定义模型的保存和加载

佚名 举报

篇首语:本文由小编为大家整理,主要介绍了transformers自定义模型的保存和加载相关的知识,希望对你有一定的参考价值。

step1 保存 (my_plbart.py)

#如果一开始用了并行训练最好加上这句model_to_save = model.module if hasattr(model, "module") else model#这样保存的是模型参数,记得格式是.pttorch.save(model_to_save.state_dict(),output_model_dir+"model-2.pt")

step2 加载 (use_plbart.py)

#因为是自定义模型呀model = Model()#拿到保存的参数model_static_dict = torch.load(output_model_dir+"model-2.pt")#把参数加载到模型中model.load_state_dict(model_static_dict)

注意:

两个文件中的 output_model_dir 路径和Model类应该是一致的。

话外:

如果你的模型不是自定义的,而是直接用的transformers中from_pretrained得到的,那么可以直接用save_pretrained进行保存。以上提供的是更一般化的方法,即torch对模型参数保存和加载的支持。

附上完整的模型文件 only_model.py

import torchfrom transformers import PLBartConfig, PLBartModel, PLBartTokenizerplbart_hf_path = "uclanlp/plbart-multi_task-java"plbart_local_path = "your_path/plbart_files"output_model_dir = "your_path/PLBART_huggingface/finetuned_models/"checkpoint = plbart_local_pathmyTokenizer = PLBartTokenizer.from_pretrained(checkpoint)class Model(torch.nn.Module):    def __init__(self):        super().__init__()                self.pretrained = PLBartModel.from_pretrained(checkpoint)        # 定义一组值全为0的常量        self.register_buffer(            "final_logits_bias",            torch.zeros(1, myTokenizer.vocab_size)        )        self.fc = torch.nn.Linear(768, myTokenizer.vocab_size, bias=False)        # 加载预训练模型的参数        parameters = PLBartConfig()        # self.fc.load_state_dict(parameters.lm_head.state_dict())        self.criterion = torch.nn.CrossEntropyLoss()    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):        logits = self.pretrained(            input_ids=input_ids,            attention_mask=attention_mask,            decoder_input_ids=decoder_input_ids        )        logits = logits.last_hidden_state        logits = self.fc(logits)+self.final_logits_bias        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())        return "loss": loss, "logits": logits

(only_model.py被其他两个py引用,单拎出来形成一个模型文件的好处是,如果直接用use_plbart.py引用my_plbart.py,还会引用进很多无关的代码,Maybe非常耗时甚至卡住)

以上是关于transformers自定义模型的保存和加载的主要内容,如果未能解决你的问题,请参考以下文章