如何从TensorFlow预测中获取类标签

时间:2017-03-27 06:11:54

标签: machine-learning tensorflow neural-network

我在TF中有一个分类模型,可以获得下一个类(preds)的概率列表。现在我想选择最高元素(argmax)和显示其类标签

这可能看起来很愚蠢,但是如何才能获得与预测张量中的位置匹配的类标签?

        feed_dict={g['x']: current_char}
        preds, state = sess.run([g['preds'],g['final_state']], feed_dict)
        prediction = tf.argmax(preds, 1)

preds为我提供了每个班级的预测矢量。当然必须有一个简单的方法来输出最可能的类(标签)?

有关我的模特的一些信息:

x = tf.placeholder(tf.int32, [None, num_steps], name='input_placeholder')
y = tf.placeholder(tf.int32, [None, 1], name='labels_placeholder')
batch_size = batch_size = tf.shape(x)[0]  
x_one_hot = tf.one_hot(x, num_classes)
rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in
              tf.split(x_one_hot, num_steps, 1)] 

tmp = tf.stack(rnn_inputs)
print(tmp.get_shape())
tmp2 = tf.transpose(tmp, perm=[1, 0, 2])
print(tmp2.get_shape())
rnn_inputs = tmp2


with tf.variable_scope('softmax'):
    W = tf.get_variable('W', [state_size, num_classes])
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))


rnn_outputs = rnn_outputs[:, num_steps - 1, :]
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])
logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits)

2 个答案:

答案 0 :(得分:0)

您可以使用tf.reduce_max()。我会推荐你​​this answer。 让我知道它是否有效 - 如果没有,将进行编辑。

答案 1 :(得分:0)

请注意,有时有多种方法可以加载数据集。例如,对于 fashion MNISTtutorial 可以引导您使用 load_data(),然后创建您自己的结构来解释预测。但是,您也可以在安装 tensorflow_datasets.load(...) 后使用 tensorflow-datasets 来加载这些数据,例如 here,它可以让您访问某些 DatasetInfo。因此,例如,如果您的预测是 9,您可以通过以下方式判断它是引导:

import tensorflow_datasets as tfds
_, ds_info = tfds.load('fashion_mnist', with_info=True)
print(ds_info.features['label'].names[9])