Implementing a sequential variational autoencoder (state-space model) on TensorFlow

Date: 2019-04-10 14:31:28

Tags: tensorflow keras recurrent-neural-network state-space tf.keras

I am currently trying to implement a version of the variational autoencoder in a sequential setting. I work on TensorFlow with eager execution mode.

As the problem setup, I have two sequences of variables: actions (2D) and observations (2D). The actions are assumed to influence the observations. The goal is to model (recover) the observation sequence given the actions. To do this, we "seek help" from latent (hidden) variables, similar to a VAE but in a sequential fashion.

Below is the generative flow of the problem. o, a, and s denote the observation, action, and latent variables, respectively.

[Figure: the generative flow of the modeling problem]
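To make the flow concrete, here is a minimal illustrative sketch of the generative process the figure depicts (my own pseudo-implementation; f_trans and f_gen are hypothetical stand-ins for the transition and generative nets defined later):

import numpy as np

def generate_sequence(actions, f_trans, f_gen, latent_dim=4):
    #s_t depends on the previous latent state and the current action,
    #o_t depends only on the current latent state
    s = np.zeros(latent_dim)
    observations = []
    for a in actions:
        mean_s, logvar_s = f_trans(s, a)  #parameters of p(s_t | s_{t-1}, a_t)
        s = mean_s + np.exp(0.5 * logvar_s) * np.random.randn(latent_dim)
        mean_o, logvar_o = f_gen(s)       #parameters of p(o_t | s_t)
        observations.append(mean_o + np.exp(0.5 * logvar_o) * np.random.randn(len(mean_o)))
    return observations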

Just like in a VAE, neural networks are used to parameterize the distributions of the variables involved. Here I assume that all variables follow a multivariate normal with a fully diagonal covariance matrix (no cross-covariances).

Three neural networks are involved: an inference net, a transition net, and a generative net. Each of them emits the mean and log-variance vectors of its corresponding variable. Below is a picture of how the three nets relate:

[Figure: relation between the three nets]

Legend of the picture:

- A: mean of the latent variable from the inference net
- B: log-variance of the latent variable from the inference net
- C: mean of the latent variable from the transition net
- D: log-variance of the latent variable from the transition net
- E: mean of the predicted observation
- F: log-variance of the predicted observation

The loss we want to minimize is the negative ELBO. The ELBO itself is the log-likelihood of the true observations under the predicted distribution given by E and F, minus the Kullback-Leibler divergence between the distributions (A, B) and (C, D).
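In symbols, with q = N(A, B) the inference distribution, p = N(C, D) the transition prior, and the observation likelihood parameterized by (E, F), the objective (as I understand the setup) is

ELBO = E_q[ log p(o | s; E, F) ] - KL( N(A, B) || N(C, D) ),

and the training loss is -ELBO.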

Since the problem leads to both a non-standard RNN cell and a non-standard input-output flow, I created my own RNN cell and then passed it to the tf.nn.raw_rnn API.

Below is my code implementation:

from __future__ import absolute_import, division, print_function
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()

import os
import numpy as np
import math

from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model

#training data
inputs #shape (time_step, batch_size, input_depth) = (20,1000,4)

#global configuration variables
max_time = 20
batch_size = 1000
latent_dim = 4

#initial state
init_state = tf.zeros([batch_size, latent_dim])

#sampling and reparameterizing function
def sampling(args):
    mean, logvar = args
    batch = batch_size
    dim = latent_dim
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = tf.random_normal(shape=(batch, dim))
    return mean + tf.exp(0.5 * logvar) * epsilon
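
#quick illustrative check of the reparameterization trick (not part of the model):
#with zero mean and zero logvar the sample is pure standard-normal noise
demo_sample = sampling([tf.zeros([batch_size, latent_dim]),
                        tf.zeros([batch_size, latent_dim])]) #shape (1000, 4)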

#class of the model, in fact it is also an RNN cell
class SSM(tf.keras.Model):
    def __init__(self, latent_dim = 4, observation_dim = 2):
        super(SSM, self).__init__()
        self.latent_dim = latent_dim
        self.observation_dim = observation_dim
        self.input_dim = (self.latent_dim + self.observation_dim + 2) + (self.latent_dim + 2) # input of inference net and transition net


        #inference net
        inference_input = Input(shape=(self.latent_dim + self.observation_dim + 2,), name='inference_input')
        layer_1 = Dense(30, activation='tanh')(inference_input)
        layer_2 = Dense(30, activation='tanh')(layer_1)
        inference_mean = Dense(latent_dim, name='inference_mean')(layer_2)
        inference_logvar = Dense(latent_dim, name='inference_logvar')(layer_2)        
        s = Lambda(sampling, output_shape=(latent_dim,), name='s')([inference_mean, inference_logvar])
        self.inference_net = Model(inference_input, [inference_mean, inference_logvar, s], name='inference_net')

        #transition net
        trans_input = Input(shape=(self.latent_dim + 2,), name='transition_input')
        layer_1a = Dense(20, activation='tanh')(trans_input)
        layer_2a = Dense(20, activation='tanh')(layer_1a)
        trans_mean = Dense(latent_dim, name='trans_mean')(layer_2a)
        trans_logvar = Dense(latent_dim, name='trans_logvar')(layer_2a)
        self.transition_net = Model(trans_input, [trans_mean, trans_logvar], name='transition_net')

        #generative net
        latent_inputs = Input(shape=(self.latent_dim,), name='s_sampling')
        layer_3 = Dense(10, activation='tanh')(latent_inputs)
        layer_4 = Dense(10, activation='tanh')(layer_3)
        obs_mean = Dense(observation_dim, name='observation_mean')(layer_4)
        obs_logvar = Dense(observation_dim, name='observation_logvar')(layer_4)
        self.generative_net = Model(latent_inputs, [obs_mean, obs_logvar], name='generative_net')

    @property
    def state_size(self):
        return self.latent_dim

    @property
    def output_size(self):
        return (2 * self.latent_dim) + (2 * self.latent_dim) + (2 * self.observation_dim) #mean&logvar of latent in infer, trans & observation in generative

    @property
    def zero_state(self):
        return init_state #global variable we have defined

    def __call__(self, inputs, state):
        #the input of the RNN cell is 14-dimensional: the first 8 = latent_dim +
        #observation_dim + 2 feed the inference net, and the remaining 6
        #(without the observation) feed the transition net
        split = self.latent_dim + self.observation_dim + 2

        #mean and logvar of the latent variables from the inference net; the next
        #state is the latent sample drawn from it (each net is called only once)
        infer_mean, infer_logvar, next_state = self.inference_net(inputs[:, :split])

        #mean and logvar of the latent variables from the transition net
        trans_mean, trans_logvar = self.transition_net(inputs[:, split:])

        #mean and logvar of the observation
        obs_mean, obs_logvar = self.generative_net(next_state)

        #the output of the RNN cell is the concatenation of all six vectors
        output = tf.concat([infer_mean, infer_logvar, trans_mean, trans_logvar, obs_mean, obs_logvar], -1)
        return output, next_state
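
#illustrative single-step check of the cell (not part of training): feeding a dummy
#14-dim input and the zero state should emit a 20-dim output per example
demo_cell = SSM()
demo_output, demo_next_state = demo_cell(tf.zeros([batch_size, 14]), init_state)
#demo_output has shape (1000, 20), demo_next_state has shape (1000, 4)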

#define a class with instant having method called loop_fn (needed for tf.nn.raw_rnn)
class SuperLoop:
    def __init__(self, inputs, output_dim = 20): # 20 = 4*latent_dim + 2*observation_dim
        inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
        inputs_ta = inputs_ta.unstack(inputs) #unstack the input data along the time axis
        self.inputs_ta = inputs_ta
        self.output_dim = output_dim
        self.output_ta = tf.TensorArray(dtype=tf.float32, size=max_time) #for saving the states

    def loop_fn(self, time, cell_output, cell_state, loop_state):
        emit_output = cell_output # == None for time == 0
        if cell_output is None: # when time == 0
            next_cell_state = init_state
            emit_output = tf.zeros([self.output_dim])
            next_loop_state = self.output_ta

        else:
            emit_output = cell_output
            next_cell_state = cell_state
            #saving the sampled latent variables
            next_loop_state = loop_state.write(time - 1, next_cell_state)

        elements_finished = (time >= max_time)
        finished = tf.reduce_all(elements_finished)

        if finished:
            #dummy input once the sequence is exhausted; it must match the 14-dim
            #cell input (with batch dimension), not the cell's output size
            next_input = tf.zeros(shape=(batch_size, 14), dtype=tf.float32)
        else:
            #cell's next input: action+observation plus state for the inference net,
            #then action plus state again for the transition net (4+4+2+4 = 14 dims)
            next_input = tf.concat([self.inputs_ta.read(time), next_cell_state,
                                    self.inputs_ta.read(time)[:,:2], next_cell_state], -1)

        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)


def SSM_model(inputs, RNN_cell = SSM(), output_dim = 20):
    superloop = SuperLoop(inputs, output_dim)
    outputs_ta, final_state, final_loop_state = tf.nn.raw_rnn(RNN_cell, superloop.loop_fn)
    #outputs_ta is still a TensorArray, hence it needs to be stacked
    obs = outputs_ta.stack()
    obs = tf.where(tf.is_nan(obs), tf.zeros_like(obs), obs)
    #final loop state contains the sampled latent variables
    latent = final_loop_state.stack()
    latent = tf.where(tf.is_nan(latent), tf.zeros_like(latent), latent)
    observation_latent = [obs, latent]
    return observation_latent
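
#(shape note, assuming the configuration above: obs stacks to
#(max_time, batch_size, output_dim) = (20, 1000, 20) and latent to (20, 1000, 4))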

#cell == model instant
model = SSM()

#Define the loss: negative of the ELBO, ELBO = log p(o|s) - KL(infer||trans)
def KL(infer_mean, infer_logvar, trans_mean, trans_logvar, latent_dim = 4):
    var_gamma = tf.exp(trans_logvar) #diagonal variances of the transition distribution
    var_phi = tf.exp(infer_logvar)   #diagonal variances of the inference distribution

    eps = 10e-5 #to ensure nonsingularity

    #analytic KL divergence between two multivariate normals,
    #KL(N(infer_mean, var_phi) || N(trans_mean, var_gamma)); since both covariances
    #are diagonal, the trace, Mahalanobis and log-determinant terms reduce to sums
    term_1 = tf.reduce_sum(var_phi / (var_gamma + eps), axis=-1)                            #tr(Sgm_gamma^-1 Sgm_phi)
    term_2 = tf.reduce_sum(tf.square(trans_mean - infer_mean) / (var_gamma + eps), axis=-1) #Mahalanobis term
    term_3 = tf.cast(latent_dim, tf.float32)                                                #dimensionality k
    term_4 = tf.reduce_sum(trans_logvar - infer_logvar, axis=-1)                            #ln(det Sgm_gamma / det Sgm_phi)

    KL_term = 0.5 * (term_1 + term_2 - term_3 + term_4)

    return KL_term
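
#illustrative check (not part of training): the KL of a distribution with itself
#should be approximately zero (up to the eps jitter)
demo_m, demo_lv = tf.zeros([5, 4]), tf.zeros([5, 4])
print(KL(demo_m, demo_lv, demo_m, demo_lv)) #~0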

#log of probability p(o|s)
def log_prob(value, obs_mean, obs_logvar, observation_dim = 2):
    var_theta = tf.exp(obs_logvar) #diagonal variances of the observation distribution

    eps = 10e-5

    #log density of value under Multivariate Normal(obs_mean, diag(var_theta));
    #with a diagonal covariance the log density factorizes over dimensions:
    #log N(x; mu, var) = -0.5 * sum(log(2*pi) + logvar + (x - mu)^2 / var)
    logprob = -0.5 * tf.reduce_sum(
        math.log(2.0 * math.pi) + obs_logvar + tf.square(value - obs_mean) / (var_theta + eps),
        axis=-1)

    return logprob
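
#illustrative check: a standard normal evaluated at its mean has log density
#-0.5 * observation_dim * log(2*pi), i.e. about -1.8379 for observation_dim = 2
demo_x = tf.zeros([5, 2])
print(log_prob(demo_x, tf.zeros([5, 2]), tf.zeros([5, 2]))) #~ -1.8379 per row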

def loss(model, inputs, latent_dim = 4, observation_dim = 2):
    outputs = SSM_model(inputs, model)[0] #only need the output of net to compute loss
    infer_mean = outputs[:,:,:latent_dim]
    infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
    trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)]
    trans_logvar = outputs[:,:, (3 * latent_dim):(4 * latent_dim)]
    obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + observation_dim)]
    obs_logvar = outputs[:,:,((4 * latent_dim) + observation_dim):]

    #logprob term
    value = inputs[:,:,2:4] #observation location in inputs
    logprob = log_prob(value, obs_mean, obs_logvar, observation_dim)
    logprob = tf.reduce_mean(logprob)

    #KL term
    KL_term = KL(infer_mean, infer_logvar, trans_mean, trans_logvar, latent_dim)
    KL_term = tf.reduce_mean(KL_term)

    return KL_term - logprob

#computing gradient function
def compute_gradients(model, x):
    with tf.GradientTape() as tape:
        loss_value = loss(model, x)
    return tape.gradient(loss_value, model.trainable_variables), loss_value

compute_gradients(model, inputs)

Executing the last line somehow produces zero gradients for the transition net and the generative net, so I cannot proceed with training. Does anyone have a clue why the gradients of the transition and generative nets are zero? I suspect the code where I build the model is still wrong, but I don't know how to improve it.
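For reference, one way to see which variables actually receive zero gradients is to print the gradient norm per trainable variable (a minimal inspection sketch using the model and inputs defined above):

grads, loss_value = compute_gradients(model, inputs)
for grad, var in zip(grads, model.trainable_variables):
    print(var.name, None if grad is None else float(tf.norm(grad)))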

0 Answers:

No answers yet.