tensorflow使用tf.data.TextLineDataset()返回AttributeError:'list'对象没有属性'get_shape'

时间:2018-03-10 03:03:14

标签: python-3.x tensorflow

import tensorflow as tf
feature_names = ['length_1',
             'width_1',
             'length_2',
             'width_2']
FILE_TRAIN = 'iris_training.csv'
FILE_TEST = 'iris_test.csv'
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
    def decode_csv(line):
        parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
        label = parsed_line[-1:]  # Last element is the label
        del parsed_line[-1]  # Delete last element
        features = parsed_line
        d = dict(zip(feature_names, features))
        return d, label

    dataset = tf.data.TextLineDataset(file_path).skip(1).map(decode_csv)  # Transform each elem by applying decode_csv fn
    if perform_shuffle:
        # Randomizes input using a window of 256 elements (read into memory)
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(repeat_count)  # Repeats dataset this # times
    dataset = dataset.batch(32)  # Batch size to use
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
next_batch = my_input_fn(FILE_TRAIN, True)  # Will return 32 random elements
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,  # The input features to our model
    hidden_units=[10, 10],  # Two layers, each with 10 neurons
    n_classes=3,
    model_dir='iris_model')  # Path to where checkpoints etc are stored
classifier.train(input_fn=lambda: my_input_fn(FILE_TRAIN, True, 8))
evaluate_result = classifier.evaluate(input_fn=lambda: my_input_fn(FILE_TEST, False, 4))
print("Evaluation results")

for key in evaluate_result:
    print("   {}, was: {}".format(key, evaluate_result[key]))

This is my code from google's example

This is wrong message

代码和谷歌的例子一样,我找不到错误的地方,谢谢!

2 个答案:

答案 0 :(得分:2)

在此功能def decode_csv(line)中,将其编辑如下,请将:放在此行label = parsed_line[-1]

def decode_csv(line):
        parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
        label = parsed_line[-1]  # Last element is the label
        del parsed_line[-1]  # Delete last element
        features = parsed_line  # Everything but last elements are the features
        d = dict(zip(feature_names, features))
        return d, label

答案 1 :(得分:0)

我也遵循谷歌的例子并遇到了同样的问题。我注意到标签实际上是一个列表,所以返回d,label [0]对我有效。