Warning when saving variables in tensorflow 1.1.0

Posted: 2017-05-15 18:06:22

Tags: python tensorflow

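I'm upgrading from tensorflow version 0.12 to version 1.1, and when I save my trained variables I get a warning that I don't understand. I'm running Python 3.5.2 on Windows. TensorFlow was installed via pip, and I'm running on the CPU only.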

I can reproduce the warning with the following code:

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

batch_size = 1
timesteps = 1
rnn_size = 4
n_channels = 1

with tf.Graph().as_default():


    input_data = tf.placeholder(tf.float32, [batch_size, timesteps, n_channels])

    with tf.name_scope('connected_input'):
        input_w = tf.get_variable('connected_w', [n_channels, rnn_size], initializer=tf.contrib.layers.xavier_initializer(seed=1), dtype=tf.float32)
        input_b = tf.get_variable('connected_b', [rnn_size], initializer=tf.constant_initializer(0.0), dtype=tf.float32)

    inputs = tf.nn.relu(tf.einsum('ijk,kl->ijl', input_data, input_w) + input_b)
    inputs = tf.unstack(inputs, num=timesteps, axis=1)

    lstm_cell = rnn.LSTMCell(4, state_is_tuple=True, initializer=tf.contrib.layers.xavier_initializer(seed=1))
    cell = rnn.MultiRNNCell([lstm_cell] * 1, state_is_tuple=True)

    initial_state = cell.zero_state(batch_size, tf.float32)

    outputs, state = rnn.static_rnn(cell, inputs, initial_state=initial_state, scope='RNN')


    with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=4)) as sess:
        tf.global_variables_initializer().run()

        outputs_val, state_val = sess.run([outputs, state], {input_data: np.reshape(np.array(0), [1, 1, 1])})

        # Save only the trainable variables; this is the call that produces the warning
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, save_path='c:/tf_vars.dat')

This produces the following warning:

WARNING:tensorflow:Error encountered when serializing LAYER_NAME_UIDS. Type is unsupported, or the types of the items don't match field type in CollectionDef. 'dict' object has no attribute 'name'
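The warning text itself says that a `dict` has no attribute `name`. A minimal sketch, assuming (not confirmed for TensorFlow 1.1) that the serializer picks a field from the type of a collection's first item and treats anything that is not a string or number as a graph object whose `.name` it reads, shows how a dict-valued collection such as LAYER_NAME_UIDS would fail with exactly this message. This is plain Python, not TensorFlow's code; `kind_for`, `serialize_collection`, `FakeVariable`, and the dict contents are hypothetical.

```python
import logging


def kind_for(item):
    """Pick a field name from the item's Python type (assumed rule, not TF's code)."""
    if isinstance(item, (str, bytes)):
        return "bytes_list"
    if isinstance(item, int):
        return "int64_list"
    if isinstance(item, float):
        return "float_list"
    # Everything else is assumed to be a graph object with a `.name` attribute.
    return "node_list"


def serialize_collection(key, items):
    """Return (kind, values), or None if the collection is empty or cannot be serialized."""
    if not items:
        return None
    kind = kind_for(items[0])
    try:
        if kind == "node_list":
            # A dict lands here and raises AttributeError on `.name`.
            values = [x.name for x in items]
        else:
            values = list(items)
    except AttributeError as e:
        # The error is only logged as a warning; the caller carries on.
        logging.warning(
            "Error encountered when serializing %s. Type is unsupported, or the "
            "types of the items don't match field type in CollectionDef. %s",
            key, e)
        return None
    return kind, values


class FakeVariable:
    """Hypothetical stand-in for a graph object that has a `.name`."""

    def __init__(self, name):
        self.name = name


print(serialize_collection("trainable_variables", [FakeVariable("connected_w:0")]))
print(serialize_collection("LAYER_NAME_UIDS", [{"dense": 1}]))  # logs the warning
```

In this model the `AttributeError` is caught and only logged, so the caller continues; whether TensorFlow 1.1 handles the LAYER_NAME_UIDS collection the same way is an assumption that would need checking against its source.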

0 answers
