RNN prediction and training with tf.nn.raw_rnn under eager execution

Date: 2019-04-04 14:17:26

Tags: tensorflow recurrent-neural-network eager-execution

I am currently trying to implement a variant of a sequential variational autoencoder with eager execution enabled. Roughly speaking, I want the latent variables to "help" when modeling some sequential data. Since the model is non-standard, I wrote my own RNN cell as well as a loop_fn function, and passed both to tf.nn.raw_rnn, i.e. tf.nn.raw_rnn(RNN_cell, loop_fn).

The problem is that I don't know how to go any further from here: I can neither make predictions nor train the network.

What has me stuck is that tf.nn.raw_rnn(RNN_cell, loop_fn) produces numerical values, not a model (for example, an instance of the tf.keras.Model class). So what should I do with those numbers? Put differently, how can I treat tf.nn.raw_rnn(RNN_cell, loop_fn) as a model that consumes new input data and produces outputs?

I have looked at several tutorial blog posts about RNNs, but none of them use tf.nn.raw_rnn with eager execution at all. Does anyone have any hints?

import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()

from tensorflow.keras.models import Model

# class bundling the sub-networks of the model
class ModuleBox(tf.keras.Model):
    def __init__(self, latent_dim, intermediate_dim):
        super(ModuleBox, self).__init__()
        self.latent_dim = latent_dim

        self.inference_net = Model(...)

        self.generative_net = Model(...)

class PreSSM(tf.contrib.rnn.RNNCell):
    def __init__(self, latent_dim = 4, intermediate_dim = 50):
        super(PreSSM, self).__init__()
        self.latent_dim = latent_dim
        self.input_dim = latent_dim + 4  # note: for the toy problem

        module = ModuleBox(latent_dim, intermediate_dim)
        self.inference_net = module.inference_net
        self.generative_net = module.generative_net

    @property
    def state_size(self):
        return self.latent_dim

    @property
    def output_size(self):
        return 2 #(x,y) coordinate

    def __call__(self, inputs, state):
        # one step: infer the next latent state, then generate the output
        next_state = self.inference_net(inputs)[-1]
        output = self.generative_net(next_state)
        return output, next_state

# the loop_fn needed by tf.nn.raw_rnn; init_state, output_dim, seq_length,
# and inputs_ta (a TensorArray holding the inputs) must be defined in the
# enclosing scope before tf.nn.raw_rnn is called
def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:  # time == 0: set initial state, emit dummy output
        next_cell_state = init_state
        emit_output = tf.zeros([output_dim])
    else:
        next_cell_state = cell_state
        emit_output = cell_output

    elements_finished = (time >= seq_length)
    finished = tf.reduce_all(elements_finished)

    if finished:
        next_input = tf.zeros(shape=(output_dim,), dtype=tf.float32)
    else:
        next_input = tf.concat([inputs_ta.read(time), next_cell_state], -1)

    next_loop_state = None
    return (elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state)

# instantiation of the RNN cell
cell = PreSSM()

# run the rollout; raw_rnn returns a TensorArray of emitted outputs,
# the final state, and the final loop state
outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()
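For what it's worth, the general pattern the question is after can be sketched without TensorFlow at all: tf.nn.raw_rnn returns the *results* of one rollout, not a model object, so the usual way to treat it "as a model" is to wrap the rollout in a plain function that closes over the trainable parameters. Prediction is then calling that function on new inputs, and training is minimizing a loss of its outputs with respect to those parameters (with tf.GradientTape in eager mode). The scalar "cell" and finite-difference gradients below are hypothetical stand-ins for the real cell and for TensorFlow's autodiff, just to make the pattern concrete:

```python
def rnn_rollout(params, inputs, init_state=0.0):
    """One rollout: the analogue of calling tf.nn.raw_rnn(cell, loop_fn)."""
    w_in, w_rec, w_out = params
    state, outputs = init_state, []
    for x in inputs:                      # the loop that loop_fn drives
        state = w_in * x + w_rec * state  # "cell": next state from input+state
        outputs.append(w_out * state)     # "emit_output" for this step
    return outputs

def loss_fn(params, inputs, targets):
    """Mean squared error between the rollout's outputs and the targets."""
    preds = rnn_rollout(params, inputs)
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

def grad(f, params, eps=1e-5):
    """Finite-difference gradient, standing in for tf.GradientTape."""
    g = []
    for i in range(len(params)):
        bumped = list(params)
        bumped[i] += eps
        g.append((f(bumped) - f(params)) / eps)
    return g

inputs  = [0.5, -0.2, 0.8, 0.1]   # toy sequence
targets = [0.4, 0.1, 0.6, 0.3]
params  = [0.1, 0.1, 0.1]         # w_in, w_rec, w_out

initial_loss = loss_fn(params, inputs, targets)
for _ in range(200):              # training loop: rollout, gradient, update
    g = grad(lambda p: loss_fn(p, inputs, targets), params)
    params = [p - 0.1 * gi for p, gi in zip(params, g)]
final_loss = loss_fn(params, inputs, targets)

# prediction on new data is just another call to the rollout function
new_preds = rnn_rollout(params, [0.3, 0.3])
```

In TensorFlow terms, the analogous move would be to put the raw_rnn call (and the TensorArray setup for inputs_ta) inside a function, and have the training step compute a loss from `outputs_ta.stack()` and apply gradients to the cell's variables.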

0 Answers:

There are no answers.