训练tensorflow GAN时的ResourceExhaustedError

时间:2018-06-14 18:47:29

标签: python-3.x tensorflow generative-adversarial-network

我尝试在我的本地cpu上训练我的tensorflow模型并使用AWS(gpu),两次都遇到了ResourceExhaustedError。每个变量似乎占用2GB的空间,这对我来说似乎过多。优化行

会抛出错误
trainer_dis = tf.train.AdamOptimizer(learning_rate = 1e-4).minimize(dis_loss, var_list = d_vars)

我不确定我的代码是否有问题,或者我用来训练的两台机器的计算能力是否足够。如果需要更多信息,我将我的代码链接到下面的github中。

我的代码功能类似于典型的GAN。具体而言,生成器接收文本输入,该输入用于使用GLoVe嵌入来创建固定长度的向量。然后我将其连接到相同大小的噪声矢量,然后使用传统的完全连接的神经网络进行上采样,然后使用一些反卷积层来创建图像。鉴别器将图像和文本嵌入作为输入,并使用完全连接的层对文本嵌入进行上采样。在将图像和文本嵌入连接到通道(第3个)维度之后,它通过典型的卷积网络运行连接张量,以返回概率分数,该分数对应于图像是真实的还是假的。

当上采样完全连接层传递到最小化操作时,将抛出OOM错误。

W0 = tf.get_variable('dis11', shape = (init_embedding.shape[1], 76800), dtype = tf.float32, initializer = tf.truncated_normal_initializer)
        B0 = tf.get_variable('dis2', shape = (76800,), dtype = tf.float32, initializer = tf.constant_initializer(0.0))
        Y0 = tf.nn.relu(tf.add(tf.matmul(init_embedding, W0), B0))

我在网上看到这个错误可能是由批量大的问题引起的。然而,由于我使用随机梯度下降(批量大小= 1图像/字幕对)进行训练,因此对我来说情况并非如此。

非常感谢任何帮助!

完整代码:https://github.com/vdopp234/Text2Image

0 个答案:

没有答案
相关问题