我在Tensorflow中使用while_loop以便遍历张量并提取给定维度上的特定切片。对于每一步,我需要使用解码器RNN生成输出符号序列。我正在使用 tf.contrib.seq2seq 中提供的代码,尤其是tf.contrib.seq2seq.dynamic_decode。该代码类似于以下内容:
def decoder_condition(i, data, source_seq_len, ta_outputs):
return tf.less(i, max_loop_len)
def decode_body(i, data, source_seq_len, ta_outputs):
curr_data = data[:, i, :]
curr_source_seq_len = source_seq_len[:, i, :]
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
2 * self.opt["encoder_rnn_h_size"],
curr_data,
memory_sequence_length=curr_source_seq_len
)
cell = GRUCell(num_units)
cell = AttentionWrapper(cell, attention_mechanism)
# ... other code that initialises all the variables required
# for the RNN decoder
outputs = tf.contrib.seq2seq.dynamic_decode(
decoder,
maximum_iterations=self.opt["max_sys_seq_len"],
swap_memory=True
)
with tf.control_dependencies([outputs)]:
ta_outputs = ta_outputs.write(i, outputs)
return i+1, data, ta_outputs
loop_index = tf.constant(0)
gen_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
outputs = tf.while_loop(
decoder_condition,
decoder_body,
loop_vars=[
loop_index,
data,
data_source_len,
ta_outputs
],
swap_memory=True,
back_prop=True,
parallel_iterations=1
)
如您所见,我创建了不同的对象,这些对象具体取决于当前步骤 i 的输入。我在当前变量作用域中使用tf.AUTO_REUSE
的方式是即使创建不同的对象也可以重用变量。不幸的是,我的解码器似乎训练不正确,因为它不断生成错误的值。我已经检查了解码器RNN的输入数据,一切都正确。我怀疑就TensorFlow如何管理TensorArray和while_loop而言,我做得不好。
所以我的主要问题是:
谢谢!
更新: 不知道为什么,但是似乎存在一个未解决的问题,该问题与在while循环中调用自定义操作的可能性有关,如下所述:https://github.com/tensorflow/tensorflow/issues/13616。不幸的是,我对TensorFlow的内部知识了解不足,无法判断它是否与此完全相关。
更新2: 我解决了使用PyTorch :)
答案 0 :(得分:0)
(1)是
(2)是的,只需使用循环索引对张量进行切片
(3)在普通情况下无需设置backprop = False
(4)使用ML模型进行的常规操作(玩具数据集,单独的测试零件等)
重新更新2,尝试使用急切执行或tf.contrib.autograph;两者都应该让您用纯python编写while循环。