我正在使用TensorFlow的python API来训练LSTM的变体。
为此,我使用tf.while_loop
函数迭代时间步骤。
在cpu上运行我的脚本时,它不会产生任何错误消息,但是由于以下原因导致gpu python崩溃:
...tensorflow/tensorflow/core/framework/tensor.cc:885] Check failed: nullptr != b.buf_ (nullptr vs. 00...)
我的代码中导致此失败的部分(当评论它时,它起作用)位于while循环的主体中:
...
h_gathered = h_ta.gather(tf.range(time))
h_gathered = tf.transpose(h_gathered, [1, 0, 2])
syn_t = self.syntactic_weights_ta.read(time)[:, :time]
syn_t = tf.expand_dims(syn_t, 1)
syn_state_t = tf.squeeze(tf.tanh(tf.matmul(syn_t, h_gathered)), 1)
...
其中time
为零并且在每个步骤后递增,h_ta
是TensorArray
h_ta = tf.TensorArray(
dtype=dtype,
size=max_seq_len,
clear_after_read=False,
element_shape=[batch_size, num_hidden],
tensor_array_name="fw_output")
和self.syntactic_weights_ta
也是TensorArray
self.syntactic_weights_ta = tf.TensorArray(
dtype=dtype,
size=max_seq_len,
tensor_array_name="fw_syntactic_weights")
self.syntactic_weights_ta = self.syntactic_weights_ta.unstack(syntactic_weights)
我在代码片段中尝试实现的内容基本上是过去输出的加权和,存储在h_ta
中。
最后,我使用tf.train.AdamOptimizer
训练网络。
我再次测试了脚本,但是这次将while循环中的swap_memory
参数设置为False
并且它也可以在GPU上运行,但我真的很想知道它为什么会这样做不适用于swap_memory=True
。
答案 0 :(得分:0)
这看起来像TensorArray的张量存储机制与swap_memory = True时由while_loop执行的分配魔法交互的方式中的错误。
你能在TF' github上打开一个问题吗?还请包括:
并在这里回复github问题的链接?