Tensorflow模型占位符恢复

时间:2018-01-23 11:27:23

标签: python tensorflow

我想在Tensorflow中恢复,修改和重用(相当)复杂的模型,但是在使用占位符时,如何正确地传递feed_dict有一些困难。代码如下:

input_dir = "parallel_win_10_40_conv_3l_rnn"
input_file = "parallel_win_10_40_conv_3l_rnn"
saver = tf.train.import_meta_graph("./result/cnn_rnn_parallel/tune_rnn_layer/"+input_dir+"/model_"+input_file+".meta")

# # Method 1
# all_placeholders = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
# cnn_in, rnn_in, Y = all_placeholders[0], all_placeholders[1], all_placeholders[2]
# keep_prob, phase_train = all_placeholders[3], all_placeholders[4]

# Method 2
cnn_in = tf.placeholder(tf.float32, shape=[None, input_height, input_width, input_channel_num], name='cnn_in')
rnn_in = tf.placeholder(tf.float32, shape=[None, n_time_step, n_input_ele], name='rnn_in')
Y = tf.placeholder(tf.float32, shape=[None, n_labels], name = 'Y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
phase_train = tf.placeholder(tf.bool, name='phase_train')

with tf.Session() as session:
    saver.restore(session, "./result/cnn_rnn_parallel/tune_rnn_layer/"+input_dir+"/model_"+input_file)

    test_cnn_batch = np.zeros(shape=[accuracy_batch_size], dtype=float)
    test_rnn_batch = np.zeros(shape=[accuracy_batch_size], dtype=float)

    offset = (accuracy_batch_size) % (test_y.shape[0] - accuracy_batch_size)
    test_cnn_batch = cnn_test_x[offset:(offset + accuracy_batch_size), :, :, :, :]
    test_cnn_batch = test_cnn_batch.reshape(len(test_cnn_batch) * window_size, input_height, input_width, 1)
    test_rnn_batch = rnn_test_x[offset:(offset + accuracy_batch_size), :, :]
    test_batch_y = test_y[offset:(offset + accuracy_batch_size), :]

    print(session.run('fin_m:0', feed_dict={cnn_in: test_cnn_batch, rnn_in: test_rnn_batch,
                                        Y: test_batch_y, keep_prob: 1.0, phase_train: False}))

当我使用方法1时,我收到一个错误:

  

TypeError:无法将feed_dict键解释为Tensor:无法将操作转换为Tensor。

当我使用方法2时,我得到一个不同的错误:

  

InvalidArgumentError(请参阅上面的回溯):您必须使用dtype float

为占位符张量'cnn_in'提供值

这两个错误都让我感到困惑,因为在保存模型之前,占位符的定义完全相同,所以它们不应该具有相同的类型(Operation或Tensor)?对于第二种方法,test_cnn_batch是一个带有浮点值的ndarray。我认为这可能是因为模型中的cnn_in是在saver = tf.train.import_meta_graph行中定义的(根据错误信息)。我认为重新定义之后可能有所帮助,但没有骰子。

这里发生了什么?这样做的正确方法是什么?我已经阅读了许多相关的问题,但它们没有直接解决这些问题。

感谢任何帮助。

1 个答案:

答案 0 :(得分:0)

您的方法1是错误的,因为您需要获取张量,而不是定义占位符的操作。由于您正在循环get_operations()的结果,因此您获得的是操作,而不是张量。

我们的方法2也是错误的,因为您没有在图表中获取占位符,而是定义了未与计算图表其余部分相关联的新占位符。

您必须做的是找到占位符的名称,然后按名称从图表中获取它们。 错误代码已经显示了您需要的名称之一:

  

InvalidArgumentError(请参阅上面的回溯):您必须提供值   对于占位符张量' cnn_in '用dtype float

然后你可以这样做:

cnn_in = tf.get_default_graph().get_tensor_by_name('cnn_in')

注意:您可能需要在张量名称附加:0

对图中所需的所有占位符重复相同的过程。