Tensorflow:如何在具有不同批次大小的估计量中使用RNN初始状态进行训练和测试?

时间:2018-09-10 08:39:23

标签: python tensorflow rnn

我正在使用RNN(GRUCell)进行Tensorflow估算器。 我使用zero_state初始化第一个状态,它需要固定的大小。 我的问题是我希望能够使用估计量对单个样本(batchsize = 1)进行预测。 加载序列化的估算器时,它抱怨我用于预测的批次大小与训练的批次大小不匹配。

如果我以不同的批处理大小重建估算器,则无法加载已序列化的估算器。

在估算器中是否有一种优雅的方法来使用zero_state? 我看到一些使用变量存储批次大小的解决方案,但使用feed_dict方法。我找不到如何在估计器的上下文中使用它。

这是估算器中我的简单测试RNN的核心:

cells = [  tf.nn.rnn_cell.GRUCell(self.getNSize()) for _ in range(self.getNLayers())]


multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
H_init = tf.Variable( multicell.zero_state( batchsize, dtype=tf.float32 ), trainable=False)
H = tf.Variable( H_init )

Yr, state = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=H)

有人对此有线索吗?

编辑:

好吧,我尝试了很多有关此问题的事情。 现在,我尝试过滤从检查点加载的变量以删除“ H”,该“ H”用作循环单元的内部状态。为了进行预测,我可以将其全部保留为0。

到目前为止,我做到了: 首先,我定义一个钩子:

class RestoreHook(tf.train.SessionRunHook):
    def __init__(self, init_fn):
        self.init_fn = init_fn

    def after_create_session(self, session, coord=None):
        print("--------------->After create session.")
        self.init_fn(session)

然后在我的model_fn中:

if mode == tf.estimator.ModeKeys.PREDICT:
        logits = tf.nn.softmax(logits)

        # Do not restore H as it's batch size might be different.
        vlist = tf.contrib.framework.get_variables_to_restore()
        vlist = [ x for x in vlist if x.name.split(':')[0] != 'architecture/H']
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(tf.train.latest_checkpoint(self.modelDir), vlist, ignore_missing_vars=True)
        spec = tf.estimator.EstimatorSpec(mode=mode,
                                          predictions = {
                                              'logits': logits,
                                          },
                                          export_outputs={
                                              'prediction': tf.estimator.export.PredictOutput( logits )
                                          },
                                          prediction_hooks=[RestoreHook(init_fn)])

我从https://github.com/tensorflow/tensorflow/issues/14713那里获得了这段代码

但是它还不起作用。看来它仍在尝试从文件中加载H ...我检查它不在vlist中。 我仍在寻找解决方案。

2 个答案:

答案 0 :(得分:0)

您可以从其他张量example中获取批量大小

decoder_initial_state = cell.zero_state(array_ops.shape(attention_states)[0], dtypes.float32).clone(cell_state=encoder_state)

答案 1 :(得分:0)

我找到了解决方法:

  • 我为batchsize = 64和batchsize = 1的初始状态创建变量。
  • 在训练中,我使用第一个初始化RNN。
  • 在预测时间,我使用第二个。

它起作用,因为这两个变量都将由估算器代码序列化并还原,因此不会产生任何抱怨。 缺点是在训练时(创建两个变量时)会知道查询批处理大小(在我的情况下为1)。