使用大数据集训练TensorFlow BoostedTreesClassifier时发生DecodeError

时间:2020-06-21 18:04:42

标签: python tensorflow tensorflow2.0 decision-tree

在Tensorflow中训练DecodeError时,我一直在与一个奇怪的BoostedTreesClassifier战斗,它有一个大数据集(约3000万行)。我坚信这与训练集的大小有关,因为如果将其减少到1000万行,我没有任何问题。另外,我的数据集中没有NaN。

我正在尝试为输入函数和分类器操作批处理参数,但是我尚未正确理解一些参数和函数,即:n_batches_per_layer以及函数shuffle(buffer_size)和{{1 }}。

有人对这些参数/功能有任何建议或建议吗?

我的Tensorflow版本是2.2.0。


错误:

batch(batch_size)

输入功能

这是我当前的输入函数-我正在使用bach大小和shuffle值,但是没有任何变化。这些是示例:

File "/home/myuser/gbt/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3053, in _as_graph_def
        graph.ParseFromString(compat.as_bytes(data))
    google.protobuf.message.DecodeError: Error parsing message

分类器:

这就是我使用分类器的方式。再次,我一直在玩 def input_fn(): dataset = tf.data.Dataset.from_tensor_slices((dict(x), y)) if shuffle: dataset = dataset.shuffle(1000) dataset = dataset.repeat(None) dataset = dataset.batch(10000) return dataset

n_batches_per_layer

错误日志:

    params = {
        'n_trees': 10,
        'max_depth': 3,
        'n_batches_per_layer': 1,            
        'center_bias': False
    }

    params = dict(params)

    # classifier
    gbt_est = tf.estimator.BoostedTreesClassifier(tf_feature_columns, **params)

    # training
    gbt_est.train(train_input_fn)
    print(gbt_est.evaluate(eval_input_fn)) # NB: it crashes before printing this

    # evaluation
    results = gbt_est.evaluate(eval_input_fn)

0 个答案:

没有答案