我正在尝试将单热张量转换为8二进制张量数组。我正在使用tf.keras.backend.argmax查找要转换为二进制表示形式的十进制索引。
在使用tf.keras.backend.constant列表时,以下代码可以完美运行。
import tensorflow as tf
from keras import backend as K
sess = tf.Session()
with sess.as_default():
a_dec = tf.constant([0,1,2], dtype=tf.int32)
a_bin = tf.mod(tf.bitwise.right_shift(tf.expand_dims(a_dec,1), tf.range(8)), 2)
out = sess.run(a_bin)
print(out)
输出:
[[0 0 0 0 0 0 0 0]
[1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0]]
但是,当我尝试以下操作时
sess = tf.Session()
with sess.as_default():
a_dec = K.constant([0,1,2], dtype=tf.int32)
index = K.constant([K.argmax(a_dec)], dtype=tf.int32)
a_bin = tf.mod(tf.bitwise.right_shift(tf.expand_dims(index,1), tf.range(8)), 2)
out = sess.run(a_bin)
print(out)
我收到以下错误:
TypeError: Expected float32, got list containing Tensors of type '_Message' instead.
那么,如何使 argmax 张量像恒定张量一样工作?