Tensorflow: ValueError: Feature (key: c1) cannot have rank 0

Date: 2018-03-03 20:15:52

Tags: python tensorflow

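I am training and evaluating a model on my dataset, but when I try to predict (that is, using a new input function), I get the error:

ValueError: Feature (key: c1) cannot have rank 0.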

I declare my feature columns in the model function like this:

categorical_column1 = tf.feature_column.categorical_column_with_hash_bucket(key="c1", hash_bucket_size=5)
numeric_feature_column1 = tf.feature_column.numeric_column(key="n2", dtype=tf.float32, normalizer_fn=lambda x: tf.subtract(x, tf.reduce_mean(x)))
numeric_feature_column2 = tf.feature_column.numeric_column(key="n3", dtype=tf.float32, normalizer_fn=lambda x: tf.subtract(x, tf.reduce_mean(x)))
feature_columns = [tf.feature_column.indicator_column(categorical_column1), numeric_feature_column1, numeric_feature_column2]
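A side note on the `normalizer_fn` lambdas above: `tf.reduce_mean(x)` is taken over whatever tensor the column receives, which is the current batch, not a mean over the whole training set. The plain-Python sketch below mirrors that computation (`center_batch` is a hypothetical helper written for illustration, not a TensorFlow API). It suggests that the same raw value can be centered differently depending on which other examples share its batch, and that a batch holding a single example always centers to 0.

```python
def center_batch(values):
    """Plain-Python analogue of lambda x: tf.subtract(x, tf.reduce_mean(x)).

    Subtracts the mean of *this batch* from every value in it.
    """
    mean = sum(values) / len(values)
    return [v - mean for v in values]


# 'n2' from the training data, treated as one batch of 7 values
print(center_batch([24, 20, 18, 26, 24, 30, 10]))

# The value 20 inside a different batch is centered differently
print(center_batch([20, 18]))  # [1.0, -1.0]

# A batch of one always centers to zero
print(center_batch([20]))  # [0.0]
```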

Here is a sample of my dataset and my input function:

features = {
    'c1': ["Sony", "Samsung", "Sony", "Sony", "Samsung", "Samsung", "Sony"], 
    'n2': [24,20,18,26,24,30,10],
    'n3': [1,0,0,1,1,0,1]
    }
features_test = {
    'c1': ["Samsung", "Sony"],
    'n2': [20,18],
    'n3': [0,0]
}

labels = [0,1,0,0,1,0,0]
#labels = tf.one_hot(labels, depth=2)
labels_test = [1,0]


def my_input_fn(features, labels, perform_shuffle=False, repeat_count=1):

    train_dataset = tf.data.Dataset.from_tensor_slices((features, labels))

    if perform_shuffle:
        # Randomizes input using a shuffle buffer of 256 elements (held in memory)
        train_dataset = train_dataset.shuffle(256)
    train_dataset = train_dataset.repeat(repeat_count)  # Repeats the dataset repeat_count times
    train_dataset = train_dataset.batch(BATCH_SIZE)  # Batch size to use
    train_dataset = train_dataset.prefetch(1)

    # Create an iterator over batches of the correct shape and type
    iterator = train_dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
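To make the shapes in `my_input_fn` concrete: `from_tensor_slices` splits every feature along its first dimension into one element per row, `repeat` replays that sequence, and `batch` groups consecutive elements so each feature becomes a list again (a leading batch dimension). The plain-Python model below is only an illustration under assumptions: `pipeline` is a hypothetical helper, it ignores shuffling, prefetching and dtypes, it requires an integer `repeat_count`, and `batch_size=3` is an arbitrary stand-in because `BATCH_SIZE` is not shown in the question.

```python
def pipeline(features, labels, repeat_count, batch_size):
    """Plain-Python model of from_tensor_slices -> repeat -> batch (no shuffle).

    Yields (batch_features, batch_labels); every feature value is a list,
    i.e. a rank-1 "tensor" whose length is the batch size.
    """
    n = len(labels)
    # from_tensor_slices: one (features, label) pair per row, scalars per key
    rows = [({k: v[i] for k, v in features.items()}, labels[i]) for i in range(n)]
    # repeat: the whole sequence of rows, repeat_count times
    stream = rows * repeat_count
    # batch: group consecutive rows; the last batch may be smaller
    for start in range(0, len(stream), batch_size):
        chunk = stream[start:start + batch_size]
        batch_features = {k: [f[k] for f, _ in chunk] for k in features}
        batch_labels = [lab for _, lab in chunk]
        yield batch_features, batch_labels


features = {
    'c1': ["Sony", "Samsung", "Sony", "Sony", "Samsung", "Samsung", "Sony"],
    'n2': [24, 20, 18, 26, 24, 30, 10],
    'n3': [1, 0, 0, 1, 1, 0, 1],
}
labels = [0, 1, 0, 0, 1, 0, 0]

batches = list(pipeline(features, labels, repeat_count=2, batch_size=3))
print(len(batches))         # 5
print(batches[2][0]['c1'])  # ['Sony', 'Sony', 'Samsung']
```

Because `repeat` comes before `batch`, the third batch here mixes the last row of the first pass with the first two rows of the second pass.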

Finally, this is how I feed data for prediction:

def new_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices(features_test).repeat(1)
    iterator2 = dataset.make_one_shot_iterator()
    next_feature_batch = iterator2.get_next()
    return next_feature_batch, None  # In prediction, we have no labels

# Predict all our prediction_input
predict_results = classifier.predict(input_fn=new_input_fn)

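The same classifier is used for training, evaluation and prediction; the only difference is the input function. What could be causing this error?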
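One plausible explanation, hedged since it is based on how `tf.data` and feature columns generally behave rather than verified against this exact TensorFlow version: `my_input_fn` ends with `.batch(BATCH_SIZE)`, but `new_input_fn` does not. Without batching, each element produced by `from_tensor_slices(features_test)` holds a single scalar per key, so `c1` reaches the feature columns as a rank-0 string tensor, while feature columns expect a leading batch dimension. The plain-Python model below contrasts the two cases; `rank`, `unbatched_elements` and `batched_elements` are hypothetical helpers for illustration, not TensorFlow APIs. In the TensorFlow code, the corresponding change would presumably be to batch the prediction dataset as well, e.g. `tf.data.Dataset.from_tensor_slices(features_test).batch(BATCH_SIZE)`.

```python
def rank(value):
    """Nesting depth of lists: 0 for a scalar, 1 for a flat list, and so on."""
    r = 0
    while isinstance(value, list):
        r += 1
        value = value[0] if value else None
    return r


def unbatched_elements(features):
    """Model of from_tensor_slices(features) with no .batch():
    each element maps every key to a single scalar."""
    n = len(next(iter(features.values())))
    return [{k: v[i] for k, v in features.items()} for i in range(n)]


def batched_elements(features, batch_size):
    """Model of from_tensor_slices(features).batch(batch_size):
    each element maps every key to a list (a leading batch dimension)."""
    rows = unbatched_elements(features)
    return [
        {k: [row[k] for row in rows[s:s + batch_size]] for k in features}
        for s in range(0, len(rows), batch_size)
    ]


features_test = {'c1': ["Samsung", "Sony"], 'n2': [20, 18], 'n3': [0, 0]}

first = unbatched_elements(features_test)[0]
print(first['c1'], rank(first['c1']))  # Samsung 0

first = batched_elements(features_test, 2)[0]
print(first['c1'], rank(first['c1']))  # ['Samsung', 'Sony'] 1
```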

0 answers