具有元组输入的自定义TensorFlow RNN单元

时间:2017-11-27 21:32:22

标签: python tensorflow rnn

我正在尝试在TensorFlow中创建一个自定义RNN单元,它接受一个元组作为输入,但我遇到了父类 d1 d2 d3 d4 t1 x y z x t2 etc. t3 t4 ... 要求输入是二维的问题:

BasicLSTMCell

我如何解决这个限制?我无法添加逻辑来处理# Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) 方法中的元组,因为执行永远不会到达方法 - 维度检查会引发错误。

1 个答案:

答案 0 :(得分:1)

我实际上也发现了这个问题。 tensorflow平台中存在一个错误。您可以通过更改recurrent.py文件中的get_step_input_shape函数来解决。只需在此行的末尾添加[0]:nest.map_structure(get_input_spec,input_shape))