张量流,检查张量中哪些值是整数

时间:2018-10-23 01:20:53

标签: python tensorflow

我有一个张量:

numbers = tf.constant([[4.00, 3.33], [2.34, 7.00]])

我想做的是得到一个张量,其张量与“数字”相同,但在数字为整数的索引处为1,而在不是数字的索引处为0,如下所示:

ans = [[1, 0],[0, 1]]

我猜我可能必须使用tf.where()?我真的不确定如何使用张量流来做这样的事情。谢谢

1 个答案:

答案 0 :(得分:0)

import tensorflow as tf
import numpy as np

tf.reset_default_graph()
with tf.Session() as sess:
    fake_data = np.asarray([[1,2.4, 3.5], [3.4, 2.00, 10.001], [105.1, 100, 10]])
    a = tf.constant(data)

    # find where floor == actual value (thus, is a whole number)
    mask = tf.equal(tf.floor(data), data)

    # Get the indices
    idx = tf.where(mask)


    print(sess.run(a))
    print(sess.run(idx))

这应该可以解决问题:)我尝试对其进行评论,这样我所做的事情就很清楚了,我认为这很容易理解。我写的内容基于this comment