import math
import numpy as np
import tensorflow as tf
myx=np.array([2,4,5])
myy=np.array([10,3,7,8,6,4,11,18,1])
Xxx=np.transpose(np.repeat(myx[:, np.newaxis], myy.size , axis=1))
Yyy=np.repeat(myy[:, np.newaxis], myx.size , axis=1)
X = tf.placeholder(tf.float64, shape=(myy.size,myx.size))
Y = tf.placeholder(tf.float64, shape=(myy.size,myx.size))
calp=tf.constant(1)
with tf.device('/cpu:0'):
#minCord=tf.argmin(tfslic,0)
dist = tf.abs(tf.subtract(X,Y))
i = tf.placeholder(dtype='int32')
def condition(i):
return i < 2
def b(i):
dist = tf.abs(tf.subtract(X,Y))
tfslic=tf.slice(dist,[0,i],[myy.size,1])
minVal=tf.reduce_min(tfslic,0)
y = tf.cond(tf.less_equal(minVal, 1), lambda: tf.argmin(tfslic,0), lambda: 99999)
return i+1, y
i, r = tf.while_loop(condition, b, [i])
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
dmat=sess.run(i, feed_dict={X:Xxx, Y: Yyy, i:0})
sess.close()
print(dmat)
我一直收到错误:
ValueError: Shape must be rank 0 but is rank 1 for 'while_50/cond/Switch'
(op: 'Switch') with input shapes: [1], [1].
有人可以帮我解决这个错误吗?我试图让这个张量流“循环”循环工作。
基本上我尝试用张量流框架做一个贪婪的1对1匹配数组“myx”和“myy”。
答案 0 :(得分:0)
tf.cond(pred, true_fn, false_fn)
函数要求pred
是标量(“rank 0”)张量。在你的程序中,它是长度为1的向量(“等级1”)张量。
有很多方法可以解决这个问题。例如,您可以使用tf.reduce_min()
而不指定轴来计算全局最小值tfslic
作为标量:
minVal = tf.reduce_min(tfslic)
...或者您可以明确使用tf.reshape()
将参数设为tf.cond()
标量:
y = tf.cond(tf.less_equal(tf.reshape(minVal, []), 1), ...)
我冒昧地稍微修改你的程序以获得一个有效的版本。按照评论查看必要的更改位置:
with tf.device('/cpu:0'):
dist = tf.abs(tf.subtract(X,Y))
# Use an explicit shape for `i`.
i = tf.placeholder(dtype='int32', shape=[])
# Add a second unused argument to `condition()`.
def condition(i, _):
return i < 2
# Add a second unused argument to `b()`.
def b(i, _):
dist = tf.abs(tf.subtract(X,Y))
# Could use `tfslic = dist[0:myy.size, i]` here to avoid later reshapes.
tfslic = tf.slice(dist, [0,i], [myy.size,1])
# Drop the `axis` argument from `tf.reduce_min()`
minVal=tf.reduce_min(tfslic)
y = tf.cond(
tf.less_equal(minVal, 1),
# Reshape the output of `tf.argmin()` to be a scalar.
lambda: tf.reshape(tf.argmin(tfslic, 0), []),
# Explicitly convert the false-branch value to `tf.int64`.
lambda: tf.constant(99999, dtype=tf.int64))
return i+1, y
# Add a dummy initial value for the second loop variable.
# Rename the first return value to `i_out` to avoid clashing with `i` above.
i_out, r = tf.while_loop(condition, b, [i, tf.constant(0, dtype=tf.int64)])
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Fetch the value of `i_out`.
dmat = sess.run(i_out, feed_dict={X:Xxx, Y: Yyy, i:0})