我想计算张量中两个值的出现次数。除了张量中不存在一个或两个值的情况外,以下代码有效。在这种情况下,它会崩溃(预期)错误:InvalidArgumentError: Expected begin and size arguments to be 1-D tensors of size 1, but got shapes [0] and [1] instead.
如何修改此代码(不使用条件),因此它只为丢失的值提供0计数而不是崩溃。
wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2):
y, idx, count = tf.unique_with_counts(tf.reshape(t, [-1]))
idx_val1 = tf.reshape(tf.where(tf.equal(y, val1)), [-1])
idx_val2 = tf.reshape(tf.where(tf.equal(y, val2)), [-1])
return tf.slice(count, idx_val1, [1]) + tf.slice(count, idx_val2, [1])
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))
答案 0 :(得分:2)
你可以这样做:
05
但请注意,一般来说,comparing floating point numbers for equality is not the best option。具有一定容忍度的可能替代方案可以是:
wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2):
eq = tf.logical_or(tf.equal(t, val1), tf.equal(t, val2))
return tf.count_nonzero(eq)
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))
答案 1 :(得分:0)
我认为你可以做这样的事情
wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2):
y, idx, count = tf.unique_with_counts(tf.reshape(t, [-1]))
idx_val1 = tf.reshape(tf.where(tf.equal(y, val1)), [-1])
idx_val2 = tf.reshape(tf.where(tf.equal(y, val2)), [-1])
temp = tf.cond(tf.greater(tf.shape(idx_val1)[0], 0),
lambda: tf.slice(count, idx_val1, [1]),
lambda: [0])
temp = temp + tf.cond(tf.greater(tf.shape(idx_val2)[0], 0),
lambda: tf.slice(count, idx_val2, [1]),
lambda: [0])
return temp
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))