TensorFlow LinearClassifier断言失败:[标签必须<= n_classes-1]

时间:2019-07-17 15:41:00

标签: python tensorflow tensorflow-datasets tensorflow-estimator

我收到一条错误消息,指出我拥有的标签数量大于n_classes-我正在使用的tf.estimator.LinearClassifier的数量。

我认为这与用于训练和测试的input_fn有关,后者确定功能和标签。我已经为此测试了不同的配置,但是找不到正确的答案。我正在使用的数据是包含4个int值的CSV,最后一个是标签。我正在Windows 3.6上的Python 3.6上运行。

def my_input_fn(data_file, num_epochs, batch_size):
    dataset = tf.data.experimental.make_csv_dataset(
        data_file,
        batch_size=batch_size,
        column_names=_CSV_COLUMNS, # ['int1', 'int2', 'int3', 'int4'] 
        label_name='int4',
        na_value="?",
        num_epochs=num_epochs,
        ignore_errors=True)
    return dataset

train_inpf = functools.partial(my_input_fn, train_file, num_epochs=2, shuffle=True, batch_size=32)
test_inpf = functools.partial(my_input_fn, test_file, num_epochs=1, shuffle=False, batch_size=1)

如果有用,这就是我设置分类器的方式:将用作要素的3个int列指定为分类数据。

col1 = tf.feature_column.categorical_column_with_vocabulary_list(
    'int1', column_uniques_lists['int1'], dtype=tf.int64)

col2 = tf.feature_column.categorical_column_with_vocabulary_list(
    'int2', column_uniques_lists['int2'], dtype=tf.int64)


col3 = tf.feature_column.categorical_column_with_vocabulary_list(
    'int3', column_uniques_lists['int3'], dtype=tf.int64)

my_categorical_columns = [col1,col2,col3]

classifier = tf.estimator.LinearClassifier(feature_columns=my_categorical_columns,                                           
n_classes=len(column_uniques_lists['int4']), model_dir='.\\SaveLC\\model_dir')

column_uniques_lists是一个字典,其中包含每一列中包含的所有唯一值。

int4列中有7个唯一值,每个值对应一个类,因此我希望基于[int1,int2,int3]输入对应于int4的预测来使模型运行。 >

0 个答案:

没有答案