解决pytorch中DataParallel后模型参数出现问题的方法

pytorch中如果使用DataParallel,那么保存的模型key值前面会多处’modules.’,这样如果训练的时候使用的是多GPU,而测试的时候使用的是单GPU,模型载入就会出现问题。
一个解决方法是测试的时候强行DataParallel,但是有时候情况较为复杂,可以使用如下的方法:
(参考来源:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/13)

original saved file with DataParallel

1
state_dict = torch.load(‘myfile.pth.tar’)

create new OrderedDict that does not contain module.

1
2
3
4
5
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove module.
new_state_dict[name] = v

load params

1
model.load_state_dict(new_state_dict)

简而言之,就是重新创建一个OrderedDict,然后将它载入模型就行了。