pytorch保存模型和导入模型,torch.load出错是什么原因?

Torch使用save都能成功,但是load不能成功是什么原因?
22.png

接下来我们看一下pytorch保存模型和导入模型
 
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))


 

3 个评论

同遇到这个问题,请问你是怎么解决的?
我是这样解决的,也不知道对不对,反正不报错了:把我的网络模型又写了一遍,也就是从声明网络class Net(nn.Module):这一行开始,一直到前向传播最后的return x。然后另起一行,写你的torch.load()那行代码
看看我的回答能不能解决你的问题

要回复文章请先登录注册