在纯TensorFlow中使用有状态Keras模型

时间:2017-12-14 10:17:16

标签: machine-learning tensorflow keras

我有一个有状态的RNN模型,其中包含几个在Keras中创建的GRU图层。

我现在必须从Java运行这个模型,所以我将模型转换为protobuf,并且我从Java TensorFlow加载它。

此模型必须是有状态的,因为功能将一次一步地输入。

据我所知,为了在TensorFlow模型中实现有状态,每次执行会话运行时,我必须以某种方式提供最后一个状态,并且运行将在执行后返回状态。

  1. 有没有办法在Keras模型中输出状态?
  2. 是否有一种更简单的方法可以使用TensorFlow使有状态的Keras模型正常工作?
  3. 非常感谢

2 个答案:

答案 0 :(得分:2)

另一种解决方案是使用keras模型的model.state_updates属性,并将其添加到session.run调用中。

下面是一个完整的示例,用两个lstms来说明此解决方案:

import tensorflow as tf

class SimpleLstmModel(tf.keras.Model):
    """ Simple lstm model with two lstm """
    def __init__(self, units=10, stateful=True):
        super(SimpleLstmModel, self).__init__()
        self.lstm_0 = tf.keras.layers.LSTM(units=units, stateful=stateful, return_sequences=True)
        self.lstm_1 = tf.keras.layers.LSTM(units=units, stateful=stateful, return_sequences=True)

    def call(self, inputs):
        """
        :param inputs: [batch_size, seq_len, 1]
        :return: output tensor
        """
        x = self.lstm_0(inputs)
        x = self.lstm_1(x)
        return x

def main():
    model = SimpleLstmModel(units=1, stateful=True)
    x = tf.placeholder(shape=[1, 1, 1], dtype=tf.float32)
    output = model(x)
    sess = tf.Session()

    sess.run(tf.initialize_all_variables())

    res_at_step_1, _ = sess.run([output, model.state_updates], feed_dict={x: [[[0.1]]]})
    print(res_at_step_1)
    res_at_step_2, _ = sess.run([output, model.state_updates], feed_dict={x: [[[0.1]]]})
    print(res_at_step_2)




if __name__ == "__main__":
    main()

哪个会产生以下输出:

[[[0.00168626]]] [[[0.00434444]]]

并显示批处理之间保留了lstm状态。 如果将有状态设置为False,则输出将变为:

[[[0.00033928]]] [[[0.00033928]]]

显示状态未被重用。

答案 1 :(得分:1)

好的,所以我设法解决了这个问题!

对我来说,有效的是不仅为标准的输出而且为州的张量创建输出。

在Keras模型中,可以通过以下方式找到状态张量:

model.updates

其中包含以下内容:

[(<tf.Variable 'gru_1_1/Variable:0' shape=(1, 70) dtype=float32_ref>,
  <tf.Tensor 'gru_1_1/while/Exit_2:0' shape=(1, 70) dtype=float32>),
 (<tf.Variable 'gru_2_1/Variable:0' shape=(1, 70) dtype=float32_ref>,
  <tf.Tensor 'gru_2_1/while/Exit_2:0' shape=(1, 70) dtype=float32>),
 (<tf.Variable 'gru_3_1/Variable:0' shape=(1, 4) dtype=float32_ref>,
  <tf.Tensor 'gru_3_1/while/Exit_2:0' shape=(1, 4) dtype=float32>)]

&#39;变量&#39;用于输入状态,以及退出&#39;对于新州的输出。 所以我从退出&#39;退出&#39;创建了tf.identity。张量。我给了他们有意义的名字,例如:

tf.identity(state_variables[j], name='state'+str(j))

state_variables只包含&#39;退出&#39;张量

然后使用输入变量(例如gru_1_1/Variable:0)从TensorFlow提供模型状态,并使用&#39;退出&#39;在每个时间步长喂食模型后,使用张量提取新状态