我正在使用RNN(GRUCell)进行Tensorflow估算器。 我使用zero_state初始化第一个状态,它需要固定的大小。 我的问题是我希望能够使用估计量对单个样本(batchsize = 1)进行预测。 加载序列化的估算器时,它抱怨我用于预测的批次大小与训练的批次大小不匹配。
如果我以不同的批处理大小重建估算器,则无法加载已序列化的估算器。
在估算器中是否有一种优雅的方法来使用zero_state? 我看到一些使用变量存储批次大小的解决方案,但使用feed_dict方法。我找不到如何在估计器的上下文中使用它。
这是估算器中我的简单测试RNN的核心:
cells = [ tf.nn.rnn_cell.GRUCell(self.getNSize()) for _ in range(self.getNLayers())]
multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
H_init = tf.Variable( multicell.zero_state( batchsize, dtype=tf.float32 ), trainable=False)
H = tf.Variable( H_init )
Yr, state = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=H)
有人对此有线索吗?
编辑:
好吧,我尝试了很多有关此问题的事情。 现在,我尝试过滤从检查点加载的变量以删除“ H”,该“ H”用作循环单元的内部状态。为了进行预测,我可以将其全部保留为0。
到目前为止,我做到了: 首先,我定义一个钩子:
class RestoreHook(tf.train.SessionRunHook):
def __init__(self, init_fn):
self.init_fn = init_fn
def after_create_session(self, session, coord=None):
print("--------------->After create session.")
self.init_fn(session)
然后在我的model_fn中:
if mode == tf.estimator.ModeKeys.PREDICT:
logits = tf.nn.softmax(logits)
# Do not restore H as it's batch size might be different.
vlist = tf.contrib.framework.get_variables_to_restore()
vlist = [ x for x in vlist if x.name.split(':')[0] != 'architecture/H']
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(tf.train.latest_checkpoint(self.modelDir), vlist, ignore_missing_vars=True)
spec = tf.estimator.EstimatorSpec(mode=mode,
predictions = {
'logits': logits,
},
export_outputs={
'prediction': tf.estimator.export.PredictOutput( logits )
},
prediction_hooks=[RestoreHook(init_fn)])
我从https://github.com/tensorflow/tensorflow/issues/14713那里获得了这段代码
但是它还不起作用。看来它仍在尝试从文件中加载H ...我检查它不在vlist中。 我仍在寻找解决方案。
答案 0 :(得分:0)
您可以从其他张量example中获取批量大小
decoder_initial_state = cell.zero_state(array_ops.shape(attention_states)[0],
dtypes.float32).clone(cell_state=encoder_state)
答案 1 :(得分:0)
我找到了解决方法:
它起作用,因为这两个变量都将由估算器代码序列化并还原,因此不会产生任何抱怨。 缺点是在训练时(创建两个变量时)会知道查询批处理大小(在我的情况下为1)。