pytorch中如果使用DataParallel,那么保存的模型key值前面会多处’modules.’,这样如果训练的时候使用的是多GPU,而测试的时候使用的是单GPU,模型载入就会出现问题。
一个解决方法是测试的时候强行DataParallel,但是有时候情况较为复杂,可以使用如下的方法:
(参考来源:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/13)
original saved file with DataParallel1
state_dict = torch.load(‘myfile.pth.tar’)
create new OrderedDict that does not contain module.1
2
3
4
5from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove module.
new_state_dict[name] = v
load params1
model.load_state_dict(new_state_dict)
简而言之,就是重新创建一个OrderedDict,然后将它载入模型就行了。