我应该如何/在哪里为这个 TensorFlow 随机森林教程提供训练数据?

时间:2017-07-09 03:26:38

标签: python machine-learning tensorflow random-forest training-data

我正在按照 TensorFlow 的示例代码操作,该代码可以让您在 MNIST 数据集上搭建随机森林。

我有来自github的以下简短代码,应该训练一个随机森林:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
        import metric_spec
from tensorflow.contrib.learn.python.learn.estimators\
        import estimator
from tensorflow.contrib.tensor_forest.client\
        import eval_metrics
from tensorflow.contrib.tensor_forest.client\
        import random_forest
from tensorflow.contrib.tensor_forest.python\
        import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.platform import app

FLAGS = None


def build_estimator(model_dir):
  """Create the tensor-forest estimator configured from the parsed FLAGS.

  Args:
    model_dir: Directory handed to the estimator as its `model_dir`.

  Returns:
    A `TensorForestEstimator` wrapped in `estimator.SKCompat`.
  """
  hparams = tensor_forest.ForestHParams(
      num_classes=10,
      num_features=784,
      num_trees=FLAGS.num_trees,
      max_nodes=FLAGS.max_nodes)
  # Pick the graph builder according to the --use_training_loss flag.
  builder_cls = (tensor_forest.TrainingLossForest
                 if FLAGS.use_training_loss
                 else tensor_forest.RandomForestGraphs)
  forest = random_forest.TensorForestEstimator(
      hparams, graph_builder_class=builder_cls, model_dir=model_dir)
  # SKCompat gives a convenient way to feed in-memory data such as MNIST
  # to the estimator in batches.
  return estimator.SKCompat(forest)


def train_and_eval():
  """Fit the forest on the MNIST training split and print test-set metrics."""
  # Fall back to a fresh temporary directory when --model_dir is empty.
  model_dir = FLAGS.model_dir or tempfile.mkdtemp()
  print('model directory = %s' % model_dir)

  forest = build_estimator(model_dir)
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)

  forest.fit(
      x=mnist.train.images,
      y=mnist.train.labels,
      batch_size=FLAGS.batch_size)

  # Evaluate a single metric, accuracy, on the test split.
  metric_name = 'accuracy'
  accuracy_spec = metric_spec.MetricSpec(
      eval_metrics.get_metric(metric_name),
      prediction_key=eval_metrics.get_prediction_key(metric_name))

  scores = forest.score(
      x=mnist.test.images,
      y=mnist.test.labels,
      batch_size=FLAGS.batch_size,
      metrics={metric_name: accuracy_spec})

  # Print the results in a stable (sorted-by-name) order.
  for name in sorted(scores):
    print('%s: %s' % (name, scores[name]))


def main(_):
  """Entry point passed to `app.run`; ignores its argument and runs the job."""
  train_and_eval()


def str2bool(value):
  """Parse a command-line boolean such as 'true' / 'false' into a bool.

  argparse's `type=bool` simply calls `bool()` on the raw string, so every
  non-empty value (including 'False') would be treated as True. This
  converter interprets the text instead.

  Args:
    value: Raw string from the command line, or an existing bool.

  Returns:
    The parsed boolean.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a recognised boolean string.
  """
  if isinstance(value, bool):
    return value
  lowered = value.strip().lower()
  if lowered in ('true', 't', 'yes', 'y', '1'):
    return True
  # The empty string stays False, matching what bool('') returned before.
  if lowered in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % value)


if __name__ == '__main__':
  # Define the command-line flags; the parsed values are stored in the
  # module-level FLAGS and read by build_estimator / train_and_eval.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model_dir',
      type=str,
      default='',
      help='Base directory for output models.'
  )
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/data/',
      help='Directory for storing data'
  )
  parser.add_argument(
      '--train_steps',
      type=int,
      default=1000,
      help='Number of training steps.'
  )
  # Must be an int: with type=str a value given on the command line would
  # reach fit()/score() as a string instead of a number.
  parser.add_argument(
      '--batch_size',
      type=int,
      default=1000,
      help='Number of examples in a training batch.'
  )
  parser.add_argument(
      '--num_trees',
      type=int,
      default=100,
      help='Number of trees in the forest.'
  )
  parser.add_argument(
      '--max_nodes',
      type=int,
      default=1000,
      help='Max total nodes in a single tree.'
  )
  # str2bool (not bool) so that '--use_training_loss=False' really is False.
  parser.add_argument(
      '--use_training_loss',
      type=str2bool,
      default=False,
      help='If true, use training loss as termination criteria.'
  )
  # Unrecognised arguments are forwarded untouched to app.run.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)

我的问题是,当我运行它时,程序执行到下面这一行:

# The statement from train_and_eval at which the run fails.
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)

然后崩溃并出现以下错误:

IOError: [Errno socket error] EOF occurred in violation of protocol (_ssl.c:590)

看起来程序无法获取 MNIST 数据。我尝试在运行 python 文件的位置添加 /tmp/data/ 目录,然后到 http://yann.lecun.com/exdb/mnist/ 下载名为 t10k-images-idx3-ubyte、t10k-labels-idx1-ubyte、train-images-idx3-ubyte 和 train-labels-idx1-ubyte 的 MNIST 数据文件,并将它们放入 /tmp/data/ 目录,但仍然出现完全相同的错误。我一定是遗漏了某些非常明显的东西,但我在 tensorflow 的 github 上找不到说明如何获取/设置这些数据以及如何从脚本访问它们的地方。有什么想法吗?

1 个答案:

答案 0 :(得分:2)

试试这个:

# Passes the relative directory 'MNIST_data' and one_hot=True, whereas the
# script above passes FLAGS.data_dir and one_hot=False.
# NOTE(review): the script builds the forest with num_classes=10 and feeds
# mnist.train.labels directly -- confirm one-hot labels are accepted there.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)