我在tensorflow中创建了一个暹罗网络。我正在使用以下代码计算两个张量之间的距离:
distance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(question1_predictions, question2_predictions)), reduction_indices=1))
我能够毫无错误地训练模型。在推理部分,我正在检索distance
张量,如下所示:
test_state, distance = sess.run([question1_final_state, distance], feed_dict=feed)
Tensorflow抛出错误:
Fetch参数数组([....],dtype = float32)具有无效类型,必须是字符串或Tensor。 (不能转换一个 ndarray进入Tensor或Operation。)
当我在训练部分的distance
之前和之后打印session.run
张量时,它显示为<class 'tensorflow.python.framework.ops.Tensor'>
。因此,在distance
推理部分中发生了使用numpy distance
替换张量session.run
。遵循推理部分的代码:
with graph.as_default():
saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer(), feed_dict={embedding_placeholder: embedding_matrix})
saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
test_state = sess.run(initial_state)
for ii, (x1, x2, batch_test_ids) in enumerate(get_test_batches(test_question1_features, test_question2_features, test_ids, batch_size), 1):
feed = {question1_inputs: x1,
question2_inputs: x2,
keep_prob: 1,
initial_state: test_state
}
test_state, distance = sess.run([question1_final_state, distance], feed_dict=feed)
答案 0 :(得分:9)
看起来你用一个numpy数组distance = tf.sqrt(...)
覆盖Tensor distance = sess.run(distance)
。
你的循环是罪魁祸首。将t_state, distance = sess.run([question1_final_state, distance]
更改为t_state, distance_other = sess.run([question1_final_state, distance]