十进制到二进制张量转换

时间:2019-05-25 12:07:20

标签: python-3.x tensorflow keras

我正在尝试将单热张量转换为8二进制张量数组。我正在使用tf.keras.backend.argmax查找要转换为二进制表示形式的十进制索引。

在使用tf.keras.backend.constant列表时,以下代码可以完美运行。

import tensorflow as tf
from keras import backend as K
sess = tf.Session()
with sess.as_default():
    a_dec = tf.constant([0,1,2], dtype=tf.int32)
    a_bin = tf.mod(tf.bitwise.right_shift(tf.expand_dims(a_dec,1), tf.range(8)), 2)
    out = sess.run(a_bin)
print(out)

输出:

[[0 0 0 0 0 0 0 0]
 [1 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0]]

但是,当我尝试以下操作时

sess = tf.Session()
with sess.as_default():
    a_dec = K.constant([0,1,2], dtype=tf.int32)
    index = K.constant([K.argmax(a_dec)], dtype=tf.int32)
    a_bin = tf.mod(tf.bitwise.right_shift(tf.expand_dims(index,1), tf.range(8)), 2)
    out = sess.run(a_bin)
print(out)

我收到以下错误:

TypeError: Expected float32, got list containing Tensors of type '_Message' instead.

那么,如何使 argmax 张量像恒定张量一样工作?

0 个答案:

没有答案