如何在简单的Tensorflow示例中打印混淆矩阵?

时间:2018-01-08 13:03:07

标签: python tensorflow

我正在尝试修改 Iris classification matrix example,以学习 Tensorflow 的一些基本机制,但我不知道如何打印混淆矩阵。下面是我目前的代码和运行结果。要么是我没有正确创建标签和/或预测值,要么是我没有正确处理混淆矩阵。如有任何帮助,将不胜感激!

代码

import os

import six.moves.urllib.request as request
import tensorflow as tf

# Root working directory; it is also handed to the estimator as model_dir.
PATH = "/tmp/tf_dataset_and_estimator_apis"

# Local locations of the Iris training and test CSV files, and the URLs
# they are fetched from.
PATH_DATASET = os.path.join(PATH, "dataset")
FILE_TRAIN = os.path.join(PATH_DATASET, "iris_training.csv")
FILE_TEST = os.path.join(PATH_DATASET, "iris_test.csv")
URL_TRAIN = "http://download.tensorflow.org/data/iris_training.csv"
URL_TEST = "http://download.tensorflow.org/data/iris_test.csv"


def downloadDataset(url, file):
    """Download ``url`` to the local path ``file`` unless it already exists.

    The ``PATH_DATASET`` directory is created first when it is missing.
    Nothing is downloaded when ``file`` is already present on disk.

    Args:
        url: Address of the resource to fetch.
        file: Destination path the downloaded bytes are written to.
    """
    if not os.path.exists(PATH_DATASET):
        os.makedirs(PATH_DATASET)
    if os.path.exists(file):
        return
    # Close the HTTP response explicitly instead of leaving the connection
    # open until the object happens to be garbage collected.
    response = request.urlopen(url)
    try:
        data = response.read()
    finally:
        response.close()
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(file, "wb") as f:
        f.write(data)
# Make sure both CSV files are available locally (training file first).
for dataset_url, dataset_file in ((URL_TRAIN, FILE_TRAIN), (URL_TEST, FILE_TEST)):
    downloadDataset(dataset_url, dataset_file)

tf.logging.set_verbosity(tf.logging.INFO)

# Names of the four feature columns in the training & test CSV data,
# in column order
feature_names = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]

# Create an input function reading a file using the Dataset API
# Then provide the results to the Estimator API


def my_input_fn(file_path, perform_shuffle=False, repeat_count=1, batch_size=32):
    """Build an Estimator input pipeline from an Iris CSV file.

    Args:
        file_path: Path of the CSV file to read; its first line is skipped
            as a header row.
        perform_shuffle: When true, shuffle rows using a 256-element buffer.
        repeat_count: Number of times the dataset is repeated.
        batch_size: Number of rows per batch. Defaults to 32, the value
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        A ``(batch_features, batch_labels)`` pair obtained from a one-shot
        iterator. ``batch_features`` is a dict keyed by ``feature_names``;
        ``batch_labels`` holds the matching labels.
    """
    def decode_csv(line):
        # Four float feature columns followed by one integer label column.
        parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
        features = parsed_line[:-1]  # Everything but the last element
        label = parsed_line[-1]  # Last element is the label
        return dict(zip(feature_names, features)), label

    dataset = tf.data.TextLineDataset(file_path)  # Read text file
    dataset = dataset.skip(1)  # Skip header row
    dataset = dataset.map(decode_csv)  # Parse each line into (features, label)
    if perform_shuffle:
        # Randomizes input using a window of 256 elements (read into memory)
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(repeat_count)  # Repeats dataset this # times
    dataset = dataset.batch(batch_size)  # Batch size to use
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

# Builds one shuffled training batch; the result is not referenced again in
# the code shown here.
next_batch = my_input_fn(FILE_TRAIN, True)  # (features, labels) for a batch of up to 32 shuffled rows

# Create the feature_columns, which specifies the input to our model
# All our input features are numeric, so use numeric_column for each one
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

# Create a deep neural network classifier (not a regressor)
# Use the DNNClassifier pre-made estimator
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,  # The input features to our model
    hidden_units=[10, 10],  # Two layers, each with 10 neurons
    n_classes=3,  # Number of target classes
    model_dir=PATH)  # Path to where checkpoints etc are stored

# Train on the shuffled training file, repeated 8 times
classifier.train(
    input_fn=lambda: my_input_fn(FILE_TRAIN, True, 8))

# Run the test file through the trained model once, unshuffled. Each element
# of `predictions` is a dict (see the raw output printed below).
predictions = list(classifier.predict(input_fn=lambda: my_input_fn(FILE_TEST, False, 1)))
print(
    "Test Samples, Raw Predictions:    {}\n"
    .format(predictions))

# Take the first entry of each prediction's 'class_ids' array as the
# predicted class index
predicted_classes = [p["class_ids"][0] for p in predictions]
print(
    "Test Samples, Class Predictions:    {}\n"
    .format(predicted_classes))

# Attempt to collect the true labels by decoding every line of the test file.
# NOTE(review): tf.decode_csv is called here without evaluating the result, so
# each `label` is a tf.Tensor object rather than an int -- this matches the
# "Truth Labels" output shown further down.
labels = []
for line in open(FILE_TEST):
    parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
    label = parsed_line[-1]  # Last element is the label
    labels.append(label)
labels = labels[1:]  # Drop the entry produced from the header row
print(
    "Test Samples, Truth Labels:    {}\n"
    .format(labels))

# NOTE(review): tf.confusion_matrix returns a Tensor, so the len() call in the
# loop below raises the TypeError shown in the traceback.
confusion_matrix = tf.confusion_matrix(labels, predicted_classes,3)
for i in range(len(confusion_matrix)):
    for j in range(len(confusion_matrix[i])):
        print(confusion_matrix[i][j], end=' ')
    print()

结果

Test Samples, Raw Predictions:    [{'logits': array([-3.94134641,  5.46653843, -1.10556901], dtype=float32), 'probabilities': array([  8.19530323e-05,   9.98521388e-01,   1.39677816e-03], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-8.64749146,  0.87616217,  4.89346647], dtype=float32), 'probabilities': array([  1.29266971e-06,   1.76830795e-02,   9.82315600e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([ 12.76192856,   3.94970369, -13.86392498], dtype=float32), 'probabilities': array([  9.99851108e-01,   1.48879364e-04,   2.73195594e-12], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-3.94899917,  5.2370801 , -0.87788975], dtype=float32), 'probabilities': array([  1.02219477e-04,   9.97693360e-01,   2.20444612e-03], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-4.18660784,  4.82310486, -0.66269088], dtype=float32), 'probabilities': array([  1.21697682e-04,   9.95750666e-01,   4.12761979e-03], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-3.49290824,  6.48037815, -2.15846062], dtype=float32), 'probabilities': array([  4.66186284e-05,   9.99776304e-01,   1.77052832e-04], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([ 18.31637955,   2.90756798, -18.32689095], dtype=float32), 'probabilities': array([  9.99999762e-01,   2.03253521e-07,   1.21907086e-16], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-6.53988791,  2.51767302,  2.30166817], dtype=float32), 'probabilities': array([  6.45163964e-05,   5.53756475e-01,   4.46179003e-01], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-2.99833608,  6.07995462, -2.14333487], dtype=float32), 
'probabilities': array([  1.14072849e-04,   9.99617696e-01,   2.68228352e-04], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-10.0059433 ,  -0.94183457,   6.65795517], dtype=float32), 'probabilities': array([  5.79086610e-08,   5.00306254e-04,   9.99499679e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([-10.03779984,  -0.58803403,   6.58065367], dtype=float32), 'probabilities': array([  6.05846608e-08,   7.69739854e-04,   9.99230146e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([ 14.65873909,   2.51680422, -14.8486042 ], dtype=float32), 'probabilities': array([  9.99994636e-01,   5.33116554e-06,   1.53151526e-13], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-8.05366135, -0.7239989 ,  5.43754053], dtype=float32), 'probabilities': array([  1.38016230e-06,   2.10456271e-03,   9.97894108e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([-4.74107504,  4.26416063,  0.01158825], dtype=float32), 'probabilities': array([  1.21028548e-04,   9.85852599e-01,   1.40263028e-02], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-3.1617856 ,  5.94625521, -2.06394577], dtype=float32), 'probabilities': array([  1.10722489e-04,   9.99557316e-01,   3.31911142e-04], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([ 16.06151962,   2.79136181, -16.13220215], dtype=float32), 'probabilities': array([  9.99998331e-01,   1.72521402e-06,   1.04338255e-14], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-0.1547718 ,  6.96881533, -4.85684061], dtype=float32), 'probabilities': array([  8.05216085e-04,   9.99187529e-01,   7.30852798e-06], dtype=float32), 'class_ids': 
array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([ 12.97141838,   4.03725767, -14.2029705 ], dtype=float32), 'probabilities': array([  9.99868155e-01,   1.31791152e-04,   1.57854014e-12], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([ 15.51482964,   2.89622927, -15.66319656], dtype=float32), 'probabilities': array([  9.99996662e-01,   3.30986154e-06,   2.88107042e-14], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-8.73975182, -1.18648708,  5.97232819], dtype=float32), 'probabilities': array([  4.07649509e-07,   7.77370529e-04,   9.99222159e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([ 15.64372635,   2.90897799, -15.6467886 ], dtype=float32), 'probabilities': array([  9.99997020e-01,   2.94691472e-06,   2.57454428e-14], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-4.81624699,  3.6293087 ,  0.63097322], dtype=float32), 'probabilities': array([  2.04605545e-04,   9.52304006e-01,   4.74914126e-02], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-8.80371475, -0.39039207,  5.63731527], dtype=float32), 'probabilities': array([  5.33696152e-07,   2.40521529e-03,   9.97594178e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([-6.91917133,  2.50265408,  2.4683063 ], dtype=float32), 'probabilities': array([  4.11623223e-05,   5.08565187e-01,   4.91393685e-01], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-2.83097792,  6.94565678, -2.89881945], dtype=float32), 'probabilities': array([  5.67562929e-05,   9.99890208e-01,   5.30335128e-05], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-4.42013788,  4.92947769, -0.53595734], 
dtype=float32), 'probabilities': array([  8.66248956e-05,   9.95701015e-01,   4.21231333e-03], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([ 17.43722343,   3.1321764 , -17.74347305], dtype=float32), 'probabilities': array([  9.99999404e-01,   6.12910128e-07,   5.26281687e-16], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object)}, {'logits': array([-4.5709672 ,  5.86139631, -1.07322729], dtype=float32), 'probabilities': array([  2.94338297e-05,   9.98998106e-01,   9.72513983e-04], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}, {'logits': array([-9.68447876, -1.25304651,  6.60897779], dtype=float32), 'probabilities': array([  8.38830658e-08,   3.84945248e-04,   9.99614954e-01], dtype=float32), 'class_ids': array([2]), 'classes': array([b'2'], dtype=object)}, {'logits': array([-3.40590096,  6.96505642, -2.54329181], dtype=float32), 'probabilities': array([  3.13259734e-05,   9.99894381e-01,   7.42216871e-05], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object)}]

Test Samples, Class Predictions:    [1, 2, 0, 1, 1, 1, 0, 1, 1, 2, 2, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 1, 2, 1, 1, 1, 0, 1, 2, 1]

Test Samples, Truth Labels:    [<tf.Tensor 'DecodeCSV_1:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_2:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_3:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_4:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_5:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_6:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_7:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_8:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_9:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_10:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_11:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_12:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_13:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_14:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_15:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_16:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_17:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_18:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_19:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_20:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_21:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_22:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_23:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_24:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_25:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_26:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_27:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_28:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_29:4' shape=() dtype=int32>, <tf.Tensor 'DecodeCSV_30:4' shape=() dtype=int32>]

Traceback (most recent call last):
  File "/Users/timmolter/workspaces/workspace_tf/HelloTensorFlow/src/iris_DNN_Classifier.py", line 147, in <module>
    for i in range(len(confusion_matrix)):
TypeError: object of type 'Tensor' has no len()

3 个答案:

答案 0 :(得分:1)

这是我想出的解决方案:

代码

import csv  # required by csv.reader below

# Overall accuracy on the test file, as reported by the estimator.
accuracy_score = classifier.evaluate(input_fn=lambda: my_input_fn(FILE_TEST, False, 1))["accuracy"]
print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

# Per-sample predictions for the test file (single unshuffled pass).
predictions = list(classifier.predict(input_fn=lambda: my_input_fn(FILE_TEST, False, 1)))

# Take the first entry of each prediction's 'class_ids' array.
predicted_classes = [p["class_ids"][0] for p in predictions]
print(
    "Test Samples, Class Predictions:    {}\n"
    .format(predicted_classes))

# Truth labels: parsed with plain Python so they are ints, not tensors.
with open(FILE_TEST,'r') as f:
    lines = f.readlines()[1:]  # skip the header row
    reader = csv.reader(lines, delimiter=',')
    truth_labels = [int(row[-1]) for row in reader]  # label is the last column
print(
    "Test Samples, Class Truth Labels:    {}\n"
    .format(truth_labels))

# tf.confusion_matrix only yields a Tensor; evaluate it in a session to get
# the actual matrix values before printing.
with tf.Session() as sess:
    confusion_matrix = tf.confusion_matrix(labels=truth_labels, predictions=predicted_classes, num_classes=3)
    confusion_matrix_to_Print = sess.run(confusion_matrix)
    print(confusion_matrix_to_Print)

输出

Test Accuracy: 0.966667

Test Samples, Class Predictions:    [1, 2, 0, 1, 1, 1, 0, 1, 1, 2, 2, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 1, 2, 1, 1, 1, 0, 1, 2, 1]

Test Samples, Class Truth Labels:    [1, 2, 0, 1, 1, 1, 0, 2, 1, 2, 2, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 1, 2, 1, 1, 1, 0, 1, 2, 1]

[[ 8  0  0]
 [ 0 14  0]
 [ 0  1  7]]

答案 1 :(得分:0)

尝试

for i in range(confusion_matrix.shape[0].value):

而不是

for i in range(len(confusion_matrix)):

(另见 How to get Tensorflow tensor dimensions (shape) as int values?)

答案 2 :(得分:0)

这是我针对 3 个标签类别的代码:

# Clear any previous Keras state before building a new model
tf.keras.backend.clear_session()

# Model ending in a 3-unit Dense layer; "..." stands for layers the author
# left out of the snippet.
# NOTE(review): Conv2D, Flatten and Dense are not imported in this snippet --
# presumably they come from tf.keras.layers; confirm.
model = tf.keras.models.Sequential([Conv2D(32, 3, activation='relu'), 
                                    Flatten(), 
                                    ..., 
                                    Dense(3)])

# from_logits=True tells the loss that the model outputs are raw logits
model.compile(loss=[tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)],optimizer='adam',metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=16, epochs=5, validation_split=0.2, validation_batch_size=16)

model.evaluate(x_test, y_test)

# Reduce the per-class outputs to one predicted class index per sample
predictions = model.predict(x_test)
predictions = tf.argmax(predictions, axis=1)

# Third argument is presumably num_classes (3 here) -- confirm against the
# tf.math.confusion_matrix signature
cm = tf.math.confusion_matrix(y_test, predictions, 3)

print(cm)