计算张量中两个值的出现次数

时间:2017-07-26 15:24:28

标签: tensorflow

我想计算张量中两个值的出现次数。除了张量中不存在一个或两个值的情况外,以下代码有效。在这种情况下,它会崩溃(预期)错误:InvalidArgumentError: Expected begin and size arguments to be 1-D tensors of size 1, but got shapes [0] and [1] instead.

如何修改此代码(不使用条件),因此它只为丢失的值提供0计数而不是崩溃。

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])

sess = tf.Session()
sess.run(tf.global_variables_initializer())

def count_occurrences(t, val1, val2):
    y, idx, count = tf.unique_with_counts(tf.reshape(t, [-1]))
    idx_val1 = tf.reshape(tf.where(tf.equal(y, val1)), [-1])
    idx_val2 = tf.reshape(tf.where(tf.equal(y, val2)), [-1])
    return tf.slice(count, idx_val1, [1]) + tf.slice(count, idx_val2, [1])

print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))

2 个答案:

答案 0 :(得分:2)

你可以这样做:

05

但请注意,一般来说,comparing floating point numbers for equality is not the best option。具有一定容忍度的可能替代方案可以是:

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])

sess = tf.Session()
sess.run(tf.global_variables_initializer())

def count_occurrences(t, val1, val2):
    eq = tf.logical_or(tf.equal(t, val1), tf.equal(t, val2))
    return tf.count_nonzero(eq)

print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))

答案 1 :(得分:0)

我认为你可以做这样的事情

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])

sess = tf.Session()
sess.run(tf.global_variables_initializer())

def count_occurrences(t, val1, val2):
    y, idx, count = tf.unique_with_counts(tf.reshape(t, [-1]))
    idx_val1 = tf.reshape(tf.where(tf.equal(y, val1)), [-1])
    idx_val2 = tf.reshape(tf.where(tf.equal(y, val2)), [-1])
    temp = tf.cond(tf.greater(tf.shape(idx_val1)[0], 0), 
               lambda: tf.slice(count, idx_val1, [1]), 
               lambda: [0]) 
    temp = temp + tf.cond(tf.greater(tf.shape(idx_val2)[0], 0), 
                 lambda: tf.slice(count, idx_val2, [1]), 
                 lambda: [0])
    return temp

print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))