不能将Tensorflow RNNCell子类化为自定义循环体系结构

时间:2017-09-22 09:48:16

标签: python tensorflow rnn

我正在尝试将RNNCell子类化为自定义循环模型,但是我收到以下错误

File "python3.6/site-packages/tensorflow/python/ops/rnn.py", line 508, in dynamic_rnn
raise TypeError("cell must be an instance of RNNCell")

子类具有以下结构

import tensorflow as tf

class CSCell(tf.nn.rnn_cell.RNNCell):
def __init__(self, mem_size=64, word_size=4, batch_size=16, hidden_size=16, output_size=4):
    super(CSCell, self).__init__()
@property
def output_size(self):
    return self._output_size
@property
def state_size(self):
    return self._total_state_vector_size

def zero_state(self, batch_size, dtype):
    #zero vectors here 
    return zero_tensor

#(output, new_state) = self.__call__(inputs,state)
def __call__(self, inputs, state):
#implementation
    return (output, new_state)

在另一个文件中,我使用它如下:

cscell_tf_core = cscell_tf.CSCell(FLAGS.mem_size,
FLAGS.word_size,
FLAGS.batch_size,
FLAGS.hidden_size,
output_size=64)

output_sequence, _ = tf.nn.dynamic_rnn(
cell=cscell_tf_core,
inputs=dataset_tensors.observations,
time_major=True)

我的猜测是,dynamic_rnn无法识别CSCell子类,但我无法理解其原因。我使用Tensorflow 1.2版。我被卡住,任何方向都非常感谢。

0 个答案:

没有答案