I am currently trying to implement a sequential variational autoencoder with eager execution enabled. Roughly speaking, I want latent variables to "help" when modeling some sequential data. Since the model is not standard, I wrote my own RNN cell and a loop_fn function, and I pass both to tf.nn.raw_rnn, i.e. I call tf.nn.raw_rnn(RNN_cell, loop_fn).

The problem is that I don't know how to make any further move from here, neither prediction nor training of the network.

What blocks me is that tf.nn.raw_rnn(RNN_cell, loop_fn) produces plain tensors of numbers, not a model (e.g. an instance of the tf.keras.Model class). So what should I do with these numbers? In other words, how can I treat tf.nn.raw_rnn(RNN_cell, loop_fn) as a model that can consume new input data and produce outputs?

I have read several tutorial blog posts about RNNs, but none of them uses tf.nn.raw_rnn (with eager execution). Does anyone have a hint?
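For context on what "treating it as a model" means: raw_rnn itself is just the driver loop that steps a cell over time and stacks the per-step outputs, so a "model" here is nothing more than a function that runs that loop on new inputs. Below is a minimal, framework-agnostic sketch of that idea in pure NumPy; `ToyCell` and `predict` are hypothetical stand-ins I made up for illustration, not TensorFlow API.

```python
import numpy as np

class ToyCell:
    """A stand-in RNN cell: next state and output are simple linear maps.
    (Hypothetical; plays the role of PreSSM in the question.)"""
    def __init__(self, input_dim, state_dim, output_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.W_s = rng.standard_normal((input_dim + state_dim, state_dim)) * 0.1
        self.W_o = rng.standard_normal((state_dim, output_dim)) * 0.1

    def __call__(self, inputs, state):
        next_state = np.tanh(np.concatenate([inputs, state], -1) @ self.W_s)
        output = next_state @ self.W_o
        return output, next_state

def predict(cell, sequence, init_state):
    """Consume a [time, input_dim] sequence, return stacked [time, output_dim]
    outputs -- the same loop tf.nn.raw_rnn automates via loop_fn."""
    state, outputs = init_state, []
    for x_t in sequence:
        out, state = cell(x_t, state)
        outputs.append(out)
    return np.stack(outputs), state

cell = ToyCell(input_dim=6, state_dim=4, output_dim=2)
outputs, final_state = predict(cell, np.zeros((5, 6)), np.zeros(4))
print(outputs.shape)  # (5, 2)
```

The point of the sketch: once you have such a `predict` function closing over the cell's weights, it already behaves like a model; a tf.keras.Model instance is just this function plus bookkeeping for its variables.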
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
from tensorflow.keras.models import Model
# class containing each sub-network of the model
class ModuleBox(tf.keras.Model):
    def __init__(self, latent_dim, intermediate_dim):
        super(ModuleBox, self).__init__()
        self.latent_dim = latent_dim
        self.inference_net = Model(...)   # encoder; definition elided
        self.generative_net = Model(...)  # decoder; definition elided
class PreSSM(tf.contrib.rnn.RNNCell):
    def __init__(self, latent_dim=4, intermediate_dim=50):
        super(PreSSM, self).__init__()
        self.latent_dim = latent_dim
        self.input_dim = latent_dim + 4  # note: for the toy problem
        module = ModuleBox(latent_dim, intermediate_dim)
        self.inference_net = module.inference_net
        self.generative_net = module.generative_net

    @property
    def state_size(self):
        return self.latent_dim

    @property
    def output_size(self):
        return 2  # (x, y) coordinate

    def __call__(self, inputs, state):
        # the state is already concatenated into `inputs` by loop_fn below
        next_state = self.inference_net(inputs)[-1]
        output = self.generative_net(next_state)
        return output, next_state
# the loop_fn callback required by tf.nn.raw_rnn;
# init_state, output_dim, seq_length and inputs_ta must exist
# in the enclosing scope
def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:  # time == 0: emit a dummy output, seed the state
        emit_output = tf.zeros([output_dim])
        next_cell_state = init_state
    else:
        emit_output = cell_output
        next_cell_state = cell_state
    elements_finished = (time >= seq_length)
    finished = tf.reduce_all(elements_finished)
    # raw_rnn builds a graph, so the branch must be a tf.cond,
    # not a Python `if` on a tensor
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([output_dim], dtype=tf.float32),
        lambda: tf.concat([inputs_ta.read(time), next_cell_state], -1))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state)
# instantiation of the RNN cell
cell = PreSSM()
# run the recurrence; outputs_ta is a TensorArray of per-step emissions
outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()  # shape [max_time, batch, output_size]
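After `outputs_ta.stack()`, the result is an ordinary tensor of shape [max_time, batch, output_size], so training reduces to computing a loss over it and differentiating; in eager TensorFlow that loss would be computed inside a tf.GradientTape block so that `tape.gradient(loss, cell.variables)` yields the gradients for an optimizer. A framework-agnostic sketch of just the loss step, with hypothetical shapes:

```python
import numpy as np

# Hypothetical shapes: 5 timesteps, batch of 3, (x, y) outputs.
max_time, batch, output_size = 5, 3, 2
outputs = np.zeros((max_time, batch, output_size))  # stand-in for outputs_ta.stack()
targets = np.ones((max_time, batch, output_size))   # ground-truth (x, y) trajectory

# mean-squared error over every timestep and batch element;
# this is the quantity a GradientTape would differentiate
loss = np.mean((outputs - targets) ** 2)
print(loss)  # 1.0
```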