在使用Pytorch时提前分配显存

Pytorch与Tensorflow在程序运行时的一个不同点是:tensorflow会在程序刚开始运行时就自动占掉所有可用显存;而pytorch会根据当前情况实时调整显存占用。在多人共用GPU训练神经网络的时候,往往会出现这样的情况:pytorch程序运行若干个epoch之后报错out of memory——也就是被人挤掉了。
这篇文章介绍对付这种情况比较hack的一种方法。
首先要说明这是 针对特殊情况的权宜之计,仅适用于特定场景。
具体实现的思路:在程序开始运行的时候在GPU上声明一个巨大的变量,占掉所有可用显存,而后当神经网络需要使用显存的时候,释放一些显存出来。
代码如下:


实现起来非常简单,只需声明完变量再将其删除即可,被删除的变量所占据的显存不会立即被释放。
当后续代码需要使用显存的时候,系统会自动地释放出这部分显存。