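I'm trying to train a wide and deep neural network (https://www.tensorflow.org/tutorials/wide_and_deep) on my own dataset. My dataset's features are numeric. It has many columns (280) and is balanced across classes. It is a binary classification problem with two labels, 0 and 1. All of the code is the same as wide_deep.py (https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py), except for some changes for the dataset: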
_CSV_COLUMNS = [ names of my columns... ]
_NUM_EXAMPLES = {
    'train': 40000,
    'validation': 10000,
}
...
def build_model_columns():
    """Builds a set of wide and deep feature columns."""
    columns = []
    for a in range(0, len(_CSV_COLUMNS)):
        # One float default per CSV column (label included), used later by tf.decode_csv.
        _CSV_COLUMN_DEFAULTS.append([0.0])
        # Every column except the label becomes a numeric feature.
        if (_CSV_COLUMNS[a] != "category"):
            columns.append(tf.feature_column.numeric_column(_CSV_COLUMNS[a]))
    # The same numeric columns feed both the wide (linear) and the deep (DNN) part.
    wide_columns = columns
    deep_columns = columns
    return wide_columns, deep_columns
...
def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Generate an input function for the Estimator."""
    assert tf.gfile.Exists(data_file), (
        '%s not found.' % data_file)

    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('category')
        # Rows whose label is 0 map to True, i.e. label 0 is treated as the positive class.
        return features, tf.equal(labels, 0)

    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)
    ...

results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, FLAGS.batch_size))
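The column bookkeeping and `parse_csv` above can be mirrored in plain Python to check what each CSV line turns into. This is a hedged sketch, not TensorFlow code: the column names `f1` to `f3` are hypothetical stand-ins for the 280 real names, the helpers `build_defaults`, `build_feature_names` and `parse_line` are made up for illustration, and the standard `csv` module stands in for `tf.decode_csv`. It also builds the defaults without a side effect, whereas `build_model_columns` above appends to `_CSV_COLUMN_DEFAULTS` every time it runs.

```python
import csv
import io

# Hypothetical stand-ins for the 280 feature names plus the label column.
CSV_COLUMNS = ["f1", "f2", "f3", "category"]
LABEL_COLUMN = "category"


def build_defaults(csv_columns):
    """One [0.0] default per CSV column, label included, built without side effects."""
    return [[0.0] for _ in csv_columns]


def build_feature_names(csv_columns, label_column=LABEL_COLUMN):
    """Every column except the label becomes a numeric feature."""
    return [name for name in csv_columns if name != label_column]


CSV_COLUMN_DEFAULTS = build_defaults(CSV_COLUMNS)


def parse_line(line, columns=CSV_COLUMNS, defaults=CSV_COLUMN_DEFAULTS,
               label_column=LABEL_COLUMN):
    """Plain-Python analogue of parse_csv: one CSV line -> (features, label)."""
    row = next(csv.reader(io.StringIO(line)))
    # Empty fields fall back to the column default, as record_defaults does.
    values = [float(v) if v != "" else d[0] for v, d in zip(row, defaults)]
    features = dict(zip(columns, values))
    label = features.pop(label_column)
    # Mirrors tf.equal(labels, 0): a row labelled 0 yields True.
    return features, label == 0


if __name__ == "__main__":
    print(build_feature_names(CSV_COLUMNS))  # ['f1', 'f2', 'f3']
    print(parse_line("1.5,2,,0"))   # ({'f1': 1.5, 'f2': 2.0, 'f3': 0.0}, True)
    print(parse_line("1.5,2,3,1"))  # ({'f1': 1.5, 'f2': 2.0, 'f3': 3.0}, False)
```

With this mapping, the rows labelled 0 are the ones the classifier sees as the positive (True) class.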
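The problem is that the loss decreases slightly at the start and then begins to fluctuate.
What I have tried:
Lowering the learning rate, like this: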
return tf.estimator.DNNLinearCombinedClassifier(
    model_dir=model_dir,
    linear_feature_columns=wide_columns,
    dnn_feature_columns=deep_columns,
    dnn_hidden_units=hidden_units,
    dnn_optimizer=tf.train.AdagradOptimizer(learning_rate=0.000001),
    config=run_config)
No matter how low I set the learning rate, the loss does not decrease.
The output looks like this:
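To get a feel for how far a learning rate of 0.000001 moves the weights, here is a plain-Python sketch of one scalar weight under the Adagrad update rule (accumulate squared gradients, then step by `learning_rate * g / sqrt(accumulator)`). The constant gradient of 1.0 and the 1000 steps are made-up inputs, and the starting accumulator of 0.1 follows the default `initial_accumulator_value` of `tf.train.AdagradOptimizer`; it illustrates the update rule only and is not TensorFlow's implementation.

```python
import math


def adagrad_trajectory(grads, learning_rate, initial_accumulator=0.1, w0=0.0):
    """Apply the Adagrad update to one scalar weight; return the weight after each step."""
    w = w0
    accum = initial_accumulator
    history = []
    for g in grads:
        accum += g * g                             # running sum of squared gradients
        w -= learning_rate * g / math.sqrt(accum)  # per-parameter scaled step
        history.append(w)
    return history


if __name__ == "__main__":
    grads = [1.0] * 1000  # hypothetical: a constant gradient of 1.0 for 1000 steps
    for lr in (0.05, 1e-6):
        final = adagrad_trajectory(grads, lr)[-1]
        print(f"learning_rate={lr}: weight moved by {abs(final):.6f}")
```

Because the accumulator does not depend on the learning rate, the total displacement scales linearly with `learning_rate`. Also note that `dnn_optimizer` only applies to the deep part; the linear part of `DNNLinearCombinedClassifier` is trained by its own `linear_optimizer`.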
INFO:tensorflow:loss = 64.00549, step = 1
INFO:tensorflow:loss = 41.65817, step = 101 (10.582 sec)
INFO:tensorflow:loss = 35.131912, step = 201 (8.954 sec)
INFO:tensorflow:loss = 41.854298, step = 301 (8.012 sec)
INFO:tensorflow:loss = 24.876358, step = 401 (7.573 sec)
INFO:tensorflow:loss = 44.788383, step = 501 (9.059 sec)
INFO:tensorflow:loss = 33.58178, step = 601 (7.904 sec)
INFO:tensorflow:loss = 35.278023, step = 701 (7.751 sec)
INFO:tensorflow:loss = 28.880415, step = 801 (7.787 sec)
INFO:tensorflow:loss = 33.041504, step = 901 (8.178 sec)
...
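Each logged value is the loss on the single mini-batch processed at that step, so step-to-step jumps are expected; a trailing moving average over the logged points makes the overall trend easier to read. A minimal plain-Python sketch using only the ten values shown in the log above (the window of 3 is an arbitrary choice):

```python
# Loss values copied from the training log above (logged every 100 steps).
LOGGED_LOSS = [
    (1, 64.00549), (101, 41.65817), (201, 35.131912), (301, 41.854298),
    (401, 24.876358), (501, 44.788383), (601, 33.58178), (701, 35.278023),
    (801, 28.880415), (901, 33.041504),
]


def moving_average(values, window):
    """Trailing moving average; the first window - 1 points have no full window."""
    return [
        sum(values[i - window + 1 : i + 1]) / window
        for i in range(window - 1, len(values))
    ]


if __name__ == "__main__":
    steps = [step for step, _ in LOGGED_LOSS]
    losses = [loss for _, loss in LOGGED_LOSS]
    print(f"min logged loss: {min(losses)}")
    for step, avg in zip(steps[2:], moving_average(losses, 3)):
        print(f"step {step}: 3-point average = {avg:.2f}")
```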
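It never goes back up to 64, but it fluctuates and never drops below 24. It does seem to be learning something, since accuracy is 0.81-0.82 (it never gets better than that, or worse). What does this mean? Is there something wrong with my dataset? Does it mean the problem can't be solved?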