Pytorch在加载模型参数时指定设备

将参数加载到CPU

在使用torch.load()方法加载函数的时候,会遇到CUDA OUT OF MEMORY的问题,这是由于训练的时候是在GPU上进行训练,因此在加载的时候默认也是加载到GPU上。
根据torch.load的官方文档
torch.load('my_file.pt', map_location=lambda storage, loc: storage)语句可以将参数保存到CPU上。

更多操作

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt') as f:
buffer = io.BytesIO(f.read())
>>> torch.load(buffer)