“无法将ndarray转换为Tensor或Operation。”尝试从tensorflow中的session.run获取值时出错

时间:2017-05-20 17:09:12

标签: python numpy tensorflow

我在tensorflow中创建了一个暹罗网络。我正在使用以下代码计算两个张量之间的距离:

distance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(question1_predictions, question2_predictions)), reduction_indices=1))

我能够毫无错误地训练模型。在推理部分,我正在检索distance张量,如下所示:

test_state, distance = sess.run([question1_final_state, distance], feed_dict=feed)

Tensorflow抛出错误:

  

Fetch参数数组([....],dtype = float32)具有无效类型,必须是字符串或Tensor。 (不能转换一个   ndarray进入Tensor或Operation。)

当我在训练部分的distance之前和之后打印session.run张量时,它显示为<class 'tensorflow.python.framework.ops.Tensor'>。因此,在distance推理部分中发生了使用numpy distance替换张量session.run。遵循推理部分的代码:

with graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer(), feed_dict={embedding_placeholder: embedding_matrix})
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))

    test_state = sess.run(initial_state)

    for ii, (x1, x2, batch_test_ids) in enumerate(get_test_batches(test_question1_features, test_question2_features, test_ids, batch_size), 1):
        feed = {question1_inputs: x1,
                question2_inputs: x2,
                keep_prob: 1,
                initial_state: test_state
               }
        test_state, distance = sess.run([question1_final_state, distance], feed_dict=feed)

1 个答案:

答案 0 :(得分:9)

看起来你用一个numpy数组distance = tf.sqrt(...)覆盖Tensor distance = sess.run(distance)

你的循环是罪魁祸首。将t_state, distance = sess.run([question1_final_state, distance]更改为t_state, distance_other = sess.run([question1_final_state, distance]