尝试还原 TensorFlow 模型时出错

时间:2019-01-03 13:17:31

标签: python tensorflow

代码如下：

    class Dense:
        """Fully connected output projection computing ``x @ w + b``.

        The weight matrix and bias vector are fetched with ``tf.get_variable``
        inside a variable scope named 'Dense' opened with ``tf.AUTO_REUSE``, so
        building a second Dense under the same enclosing scope returns the
        existing variables instead of raising.

        Args:
            hidden_dim: number of input features (rows of ``w``).
            output_dim: number of output features (columns of ``w``, length of ``b``).
        """

        def __init__(self, hidden_dim, output_dim):
            self.hidden_dim = hidden_dim
            self.output_dim = output_dim
            weight_shape = [hidden_dim, output_dim]
            with tf.variable_scope('Dense', reuse=tf.AUTO_REUSE):
                # The name contains 'weight'; Seq2seq.loss selects variables for
                # L2 regularisation by that substring.
                self.w = tf.get_variable(
                    'Dense_weight_matrix',
                    shape=weight_shape,
                    dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(),
                )
                self.b = tf.get_variable(
                    'Dense_bias_vector',
                    shape=[output_dim],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(0),
                )

        def predict(self, x):
            """Return the affine projection of ``x`` (creates no new variables)."""
            projected = tf.matmul(x, self.w)
            return projected + self.b

class Encoder:
    """Stacked-LSTM encoder unrolled over a fixed number of input steps.

    Builds one float32 placeholder per input step and runs them through a
    MultiRNNCell with ``tf.contrib.rnn.static_rnn``. Only the final RNN state
    is kept (as ``self.state``); the per-step outputs are discarded.

    Args:
        num_features: width of each step's input placeholder.
        seq_len: number of input steps (one placeholder each).
        hidden_dim: number of units in every LSTM layer.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, num_features, seq_len, hidden_dim, num_layers):
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # One placeholder per step, named Input_0 .. Input_<seq_len-1>
        # (inside whatever name scope is active).
        self.input = [
            tf.placeholder(tf.float32, shape=(None, num_features), name="Input_{}".format(step))
            for step in range(seq_len)
        ]

        def _make_layer(layer_idx):
            # NOTE(review): LSTMCell presumably creates its variables lazily on
            # first call (inside static_rnn below), so this scope may not
            # affect variable names -- confirm before removing it.
            with tf.variable_scope('RNN_{}'.format(layer_idx)):
                return tf.contrib.rnn.LSTMCell(hidden_dim)

        with tf.variable_scope('Encoder_LSTMCell'):
            self.cell = tf.contrib.rnn.MultiRNNCell([_make_layer(k) for k in range(num_layers)])
            _, self.state = tf.contrib.rnn.static_rnn(self.cell, self.input, dtype=tf.float32)

class Decoder:
    """Stacked-LSTM decoder unrolled over a fixed number of output steps.

    The step inputs are a zero "GO" tensor followed by the target placeholders
    shifted right by one step (teacher forcing). With ``feed_prev=True`` every
    step after the first instead receives the Dense projection of the previous
    step's cell output.

    Args:
        seq_len: number of output steps (one target placeholder each).
        hidden_dim: number of units in every LSTM layer.
        output_dim: width of each target / prediction.
        num_layers: number of stacked LSTM layers.
        scope: a context manager (e.g. a ``tf.variable_scope`` object) that
            ``predict`` enters on every call. The MultiRNNCell creates its
            variables inside it when first called, so the scope should resolve
            to the same absolute variable scope wherever ``predict`` is called
            from. A variable scope opened by a *string* name resolves relative
            to the caller's current scope, which makes the decoder's variable
            names (and therefore checkpoint keys) depend on the caller.
    """

    def __init__(self, seq_len, hidden_dim, output_dim, num_layers, scope):
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        # One target placeholder per output step: Target_0 .. Target_<seq_len-1>.
        self.target = [
            tf.placeholder(tf.float32, shape=(None, output_dim), name="Target_{}".format(t))
            for t in range(seq_len)
        ]
        # Decoder inputs: a zero "GO" step followed by the targets shifted right by one.
        self.input = [tf.zeros_like(self.target[0], dtype=tf.float32, name='GO')] + self.target[:-1]
        self.dense = Dense(hidden_dim, output_dim)
        self.scope = scope
        with tf.variable_scope('Decoder_LSTMCell'):
            cells = []
            for i in range(num_layers):
                # NOTE(review): the restore error's key
                # (Seq2seq/multi_rnn_cell/cell_0/...) has no Decoder_LSTMCell or
                # RNN_<i> component, so these scopes do not name the variables.
                with tf.variable_scope('RNN_{}'.format(i)):
                    cells.append(tf.contrib.rnn.LSTMCell(hidden_dim))
            self.cell = tf.contrib.rnn.MultiRNNCell(cells)

    def predict(self, init_state, feed_prev=False):
        """Unroll the decoder and return one projected output tensor per step.

        Args:
            init_state: initial state for ``self.cell`` (the encoder's final state).
            feed_prev: if True, feed each step's projected output back in as the
                next step's input instead of the shifted target.

        Returns:
            A list with one ``Dense`` projection per output step.
        """
        with self.scope:
            state = init_state
            outputs = []
            prev = None
            for i, inp in enumerate(self.input):
                if feed_prev and prev is not None:
                    # Dense.predict only does matmul + add, so this reuse scope
                    # creates no variables; it only groups the ops.
                    with tf.variable_scope("predict_func", reuse=True):
                        inp = self.dense.predict(prev)
                if i > 0:
                    # After step 0, switch the enclosing scope to reuse mode so
                    # later steps share step 0's cell variables.
                    tf.get_variable_scope().reuse_variables()
                output, state = self.cell(inp, state)
                outputs.append(output)
                if feed_prev:
                    prev = output
            return [self.dense.predict(o) for o in outputs]

class Seq2seq:
    """Encoder/decoder model for multi-step sequence regression.

    Constructing an instance resets the default graph and builds the encoder
    plus the decoder's placeholders and Dense weights under variable scope
    "Seq2seq". The decoder's LSTM variables are created only when
    ``predict()`` first runs (directly, or via ``loss()`` / ``optimize()``).
    """

    def __init__(self, input_seq_len, output_seq_len, num_features, hidden_dim, output_dim, num_layers):
        tf.reset_default_graph()
        with tf.variable_scope("Seq2seq") as seq2seq_scope:
            # The decoder re-enters self.scope on every predict() call. Build it
            # from the captured VariableScope *object*, which re-enters the
            # absolute scope "Seq2seq". A scope opened by the string "Seq2seq"
            # resolves relative to the caller's current scope. predict() called
            # from loss() (scope 'Loss') would then create the decoder cell
            # under 'Loss/Seq2seq/...' in the training graph, but under
            # 'Seq2seq/...' in an inference-only graph. Restoring the training
            # checkpoint would then fail with
            # "Key Seq2seq/multi_rnn_cell/... not found in checkpoint".
            self.scope = tf.variable_scope(seq2seq_scope)
            self.encoder = Encoder(num_features, input_seq_len, hidden_dim, num_layers)
            self.decoder = Decoder(output_seq_len, hidden_dim, output_dim, num_layers, self.scope)

    def loss(self, alpha=0.03):
        """Build the training loss tensor.

        The loss is the sum over output steps of the mean squared error between
        prediction and target. To that it adds ``alpha`` times the
        ``tf.nn.l2_loss`` of every trainable variable whose name contains
        'weight'. Each call builds a new teacher-forced decoder unroll via
        ``predict()``.

        NOTE(review): LSTM cell weights are presumably not named '...weight...'
        and so are not regularised; confirm this is intended.

        Args:
            alpha: L2 regularisation strength.

        Returns:
            A scalar loss tensor.
        """
        with tf.variable_scope('Loss'):
            output_loss = 0
            # predict() must run before tf.trainable_variables() so the decoder
            # cell's variables already exist when the list is taken.
            for y, t in zip(self.predict(), self.decoder.target):
                output_loss += tf.reduce_mean(tf.pow(y - t, 2))
            reg_loss = 0
            for var in tf.trainable_variables():
                if 'weight' in var.name:
                    reg_loss += tf.reduce_mean(tf.nn.l2_loss(var))
        return output_loss + alpha * reg_loss

    def optimize(self, learning_rate=0.01, alpha=0.03, gradient_clip=2.5):
        """Build an Adam train op that minimises ``self.loss(alpha)``.

        Every call creates a new non-trainable ``global_step`` variable
        (registered in the GLOBAL_STEP and GLOBAL_VARIABLES collections) and a
        fresh loss graph.

        Args:
            learning_rate: Adam learning rate.
            alpha: L2 regularisation strength forwarded to ``loss``.
            gradient_clip: value passed as ``clip_gradients`` to
                ``tf.contrib.layers.optimize_loss``.

        Returns:
            The train op returned by ``tf.contrib.layers.optimize_loss``.
        """
        global_step = tf.Variable(initial_value=0,
                                  name="global_step",
                                  trainable=False,
                                  collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

        with tf.variable_scope('Optimizer'):
            train_op = tf.contrib.layers.optimize_loss(loss=self.loss(alpha), learning_rate=learning_rate, global_step=global_step, optimizer='Adam', clip_gradients=gradient_clip)
        return train_op

    def predict(self, feed_prev=False):
        """Return the decoder's per-step output tensors.

        The decoder is seeded with the encoder's final state.

        Args:
            feed_prev: if True, feed each step's projected output back as the
                next input (inference); otherwise use the shifted targets.
        """
        return self.decoder.predict(self.encoder.state, feed_prev)

我正在尝试重构以下仓库中的代码：https://github.com/aaxwaz/Multivariate-Time-Series-forecast-using-seq2seq-in-TensorFlow/blob/master/build_model_basic.py

创建并训练模型都没有问题：

# Build the training graph: the model, its loss, and an Adam train op.
model = Seq2seq(input_seq_len=input_seq_len, output_seq_len=output_seq_len, num_features=num_features, hidden_dim=64, output_dim=output_dim, num_layers=2)
loss = model.loss()
# optimize() calls model.loss() again internally, so the decoder is unrolled a
# second time; `loss` above is only fetched for reporting.
train_op = model.optimize()

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for i in range(total_iterations):
        batch_input, batch_output = generate_train_samples(batch_size=batch_size)

        # One placeholder per time step: feed column t of each batch array.
        feed_dict = {model.encoder.input[t]: batch_input[:, t].reshape(-1, num_features) for t in range(input_seq_len)}
        feed_dict.update({model.decoder.target[t]: batch_output[:, t].reshape(-1, output_dim) for t in range(output_seq_len)})

        _, loss_t = sess.run([train_op, loss], feed_dict)
        print(loss_t)

    # The checkpoint stores each variable under its full graph name; a graph
    # that later restores it must create variables with the same names.
    saver = tf.train.Saver()
    save_path = saver.save(sess, 'refactored.ckpt')

尝试还原保存的模型时:

    # Rebuild the model for inference. Seq2seq.__init__ calls
    # tf.reset_default_graph(), so this starts from an empty graph.
    model = Seq2seq(input_seq_len=input_seq_len, output_seq_len=output_seq_len, num_features=num_features, hidden_dim=64, output_dim=output_dim, num_layers=2)
    # List of per-step prediction tensors; with feed_prev=True each step's
    # projected output is fed back as the next step's input.
    pred_op = model.predict(feed_prev=True)


    # NOTE(review): saver.restore() assigns every variable the Saver covers,
    # so running this initializer first appears redundant.
    init = tf.global_variables_initializer()


    with tf.Session() as sess:
        sess.run(init)
        # A default Saver covers every global variable in THIS graph and looks
        # each one up in the checkpoint by its full name. Any variable whose
        # name differs from the one in the training graph causes a
        # NotFoundError ("Key ... not found in checkpoint").
        saver = tf.train.Saver()  
        saver.restore(sess, save_path)

我收到以下错误:

NotFoundError：从检查点还原失败。这很可能是由于检查点中缺少某个变量名或其他计算图键。请确保您没有更改检查点所对应的计算图。原始错误：

在检查点中找不到键 Seq2seq/multi_rnn_cell/cell_0/lstm_cell/bias
[[{{node save/RestoreV2}} = RestoreV2[..., _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

0 个答案:

没有答案
相关问题