我正在解析TFRecords,它们为我提供了一个数字值标签。但是我在读取原始记录时需要将该值转换为分类向量。我怎样才能做到这一点。这是用于读取原始记录的代码段:
def parse(example_proto):
features={'label':: tf.FixedLenFeature([], tf.int64), ...}
parsed_features = tf.parse_single_example(example_proto, features)
label = tf.cast(parsed_features['label'], tf.int32)
# at this point label is a Tensor which holds numerical value
# but I need to return a Tensor which holds categorical vector
# for instance, if my label is 1 and I have two classes
# I need to return a vector [1,0] which represents categorical values
我知道有tf.keras.utils.to_categorical
函数,但是它不使用张量作为输入。
答案 0 :(得分:2)
您只需要将标签转换为其一键式表示(即您描述的表示)即可:
label = tf.cast(parsed_features['label'], tf.int32)
num_classes = 2
label = tf.one_hot(label, num_classes)