在Tensorflow中训练DecodeError
时,我一直在与一个奇怪的BoostedTreesClassifier
战斗,它有一个大数据集(约3000万行)。我坚信这与训练集的大小有关,因为如果将其减少到1000万行,我没有任何问题。另外,我的数据集中没有NaN。
我正在尝试为输入函数和分类器操作批处理参数,但是我尚未正确理解一些参数和函数,即:n_batches_per_layer
以及函数shuffle(buffer_size)
和{{1 }}。
有人对这些参数/功能有任何建议或建议吗?
我的Tensorflow版本是2.2.0。
错误:
batch(batch_size)
输入功能
这是我当前的输入函数-我正在使用bach大小和shuffle值,但是没有任何变化。这些是示例:
File "/home/myuser/gbt/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3053, in _as_graph_def
graph.ParseFromString(compat.as_bytes(data))
google.protobuf.message.DecodeError: Error parsing message
分类器:
这就是我使用分类器的方式。再次,我一直在玩 def input_fn():
dataset = tf.data.Dataset.from_tensor_slices((dict(x), y))
if shuffle:
dataset = dataset.shuffle(1000)
dataset = dataset.repeat(None)
dataset = dataset.batch(10000)
return dataset
。
n_batches_per_layer
错误日志:
params = {
'n_trees': 10,
'max_depth': 3,
'n_batches_per_layer': 1,
'center_bias': False
}
params = dict(params)
# classifier
gbt_est = tf.estimator.BoostedTreesClassifier(tf_feature_columns, **params)
# training
gbt_est.train(train_input_fn)
print(gbt_est.evaluate(eval_input_fn)) # NB: it crashes before printing this
# evaluation
results = gbt_est.evaluate(eval_input_fn)