Note: this page is a translation of a popular StackOverflow question and its answers, provided under the CC BY-SA 4.0 license. If you use or share it, you must do so under the same license and attribute the original authors (not the translator). Original: http://stackoverflow.com/questions/42703500/

Best way to save a trained model in PyTorch?

Tags: python, serialization, deep-learning, pytorch, tensor

Asked by Wasi Ahmad

I was looking for alternative ways to save a trained model in PyTorch. So far, I have found two alternatives.

  1. torch.save() to save a model and torch.load() to load a model.
  2. model.state_dict() to save a trained model and model.load_state_dict() to load the saved model.

I have come across this discussion where approach 2 is recommended over approach 1.

My question is, why is the second approach preferred? Is it only because torch.nn modules have those two functions and we are encouraged to use them?

Answered by dontloo

I've found this page on their GitHub repo; I'll just paste the content here.



Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
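
As a hedged illustration of that binding (the module name here is hypothetical): pickle records the import path of the model's class, not its source code, so loading only works while that path still resolves:

# Saving pickles a reference like "my_project.models.TheModelClass",
# not the class definition itself:
torch.save(the_model, PATH)

# Loading re-imports that path; if my_project.models has since been
# renamed or moved, this raises an error during unpickling:
the_model = torch.load(PATH)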

Answered by Jadiel de Armas

It depends on what you want to do.

Case #1: Save the model to use it yourself for inference: You save the model, you restore it, and then you change the model to evaluation mode. This is done because you usually have BatchNorm and Dropout layers that by default are in train mode on construction:

torch.save(model.state_dict(), filepath)

# Later, to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
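
As a usage sketch after restoring (hedged; input_tensor stands for a hypothetical input of the shape the model expects):

with torch.no_grad():            # skip gradient tracking during inference
    output = model(input_tensor)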

Case #2: Save the model to resume training later: If you need to keep training the model that you are about to save, you need to save more than just the model. You also need to save the state of the optimizer, epochs, score, etc. You would do it like this:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

To resume training you would do things like: state = torch.load(filepath), and then, to restore the state of each individual object, something like this:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Since you are resuming training, DO NOT call model.eval() once you restore the states when loading.
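
Putting it together, a hedged sketch of a full resume step (start_epoch is an illustrative name, not from the original answer):

checkpoint = torch.load(filepath)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch
model.train()  # stay in train mode since we are resuming training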

Case #3: Model to be used by someone else with no access to your code: In TensorFlow you can create a .pb file that defines both the architecture and the weights of the model. This is very handy, especially when using TensorFlow Serving. The equivalent way to do this in PyTorch would be:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

This way is still not bulletproof, and since PyTorch is still undergoing a lot of changes, I wouldn't recommend it.

Answered by prosti

The pickle Python library implements binary protocols for serializing and de-serializing a Python object.

When you import torch (or when you use PyTorch) it will import pickle for you, and you don't need to call pickle.dump() and pickle.load() directly, which are the methods to save and to load the object.

In fact, torch.save() and torch.load() will wrap pickle.dump() and pickle.load() for you.
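
As a minimal sketch of the underlying mechanism (torch.save() adds its own file format on top, so this is only an illustration, not what it literally writes):

import pickle
import torch

t = torch.tensor([1.0, 2.0, 3.0])
with open('tensor.pkl', 'wb') as f:
    pickle.dump(t, f)    # roughly what torch.save() does for you
with open('tensor.pkl', 'rb') as f:
    t2 = pickle.load(f)  # roughly what torch.load() does for you
print(torch.equal(t, t2))  # True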

The state_dict the other answer mentioned deserves just a few more notes.

What state_dicts do we have inside PyTorch? There are actually two state_dicts.

The PyTorch model is a torch.nn.Module, which has a model.parameters() call to get learnable parameters (w and b). These learnable parameters, once randomly set, will update over time as we learn. Learnable parameters are the first state_dict.

The second state_dict is the optimizer state dict. You recall that the optimizer is used to improve our learnable parameters. But the optimizer state_dict is fixed. Nothing to learn in there.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

Let's create a super simple model to explain this:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

This code will output the following:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Note this is a minimal model. You may try to add a stack of sequential layers:

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C),
          torch.nn.Linear(H, D_out),
        )
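
A runnable variant of that sketch, with hypothetical dimensions filled in and a ReLU substituted for the Conv2d so the layer shapes line up:

D_in, H, D_out = 8, 4, 2  # hypothetical dimensions
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),           # has no learnable parameters
    torch.nn.Linear(H, D_out),
)
print(list(model.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias'] -- the ReLU (index 1) contributes nothing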

Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm layers) have entries in the model's state_dict.
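
A quick check of the buffer part (BatchNorm1d is just one example of a layer with registered buffers):

bn = torch.nn.BatchNorm1d(4)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']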

Non-learnable things belong to the optimizer object's state_dict, which contains information about the optimizer's state, as well as the hyperparameters used.

The rest of the story is the same: in the inference phase (the phase when we use the model after training to predict), we predict based on the parameters we learned. So for inference, we just need to save the parameters with model.state_dict().

torch.save(model.state_dict(), filepath)

And to use later:

model.load_state_dict(torch.load(filepath))
model.eval()

Note: Don't forget the last line, model.eval(); this is crucial after loading the model.

Also, don't try to save torch.save(model.parameters(), filepath). model.parameters() is just a generator object.

On the other side, torch.save(model, filepath) saves the model object itself, but keep in mind the model doesn't have the optimizer's state_dict. Check the other excellent answer by @Jadiel de Armas for saving the optimizer's state dict.
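
A quick check of why saving model.parameters() is pointless:

print(type(model.parameters()))
# <class 'generator'> -- an iterator over the parameters, not their data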

Answered by harsh

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Save/Load Entire Model

Save:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Load:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

Answered by Joy Mazumder

If you want to save the model and resume the training later:

Single GPU:

Save:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath = 'checkpoint.t7'
torch.save(state, savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Multiple GPU:

Save:

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath = 'checkpoint.t7'
torch.save(state, savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

# Don't call DataParallel before loading the model, otherwise you will get an error
model = nn.DataParallel(model)  # ignore this line if you want to load on a single GPU
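
Putting the ordering together as a hedged sketch (TheModelClass and its arguments are placeholders for your own model, and import torch.nn as nn is assumed):

model = TheModelClass(*args, **kwargs)           # build the bare, unwrapped model first
checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])  # the checkpoint holds module-stripped keys
optimizer.load_state_dict(checkpoint['optimizer'])
model = nn.DataParallel(model)                   # wrap with DataParallel only after loading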