张量流数据集tf.estimator.inputs.numpy_input_fn

时间:2018-02-23 21:22:34

标签: python tensorflow dataset

我在tensorflow中编写用于从光盘读取图像和标签的代码,然后尝试调用tf.estimator.inputs.numpy_input_fn。如何传递整个数据集而不是单个图像。我的代码如下:

filenames = tf.constant(filenames)
labels = tf.constant(labels)

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
dataset_batched = dataset.batch(10)
iterator = dataset_batched.make_one_shot_iterator()
features, labels = iterator.get_next()

with tf.Session() as sess:

  print(dataset_batched)
  print(np.shape(sess.run(features)))
  print(np.shape(sess.run(labels)))

  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
  train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(sess.run(features))},
                                                  y=np.array(sess.run(labels)),
                                                  batch_size=1,
                                                  num_epochs=None,
                                                  shuffle=False)
  mnist_classifier.train(input_fn=train_input_fn, steps=1)

我的问题是如何在此处传递数据集x={"x": np.array(sess.run(features))}

1 个答案:

答案 0 :(得分:6)

此处numpy_input_fn没有必要/使用。您应该将顶部的代码包装到返回my_input_fn的函数(例如iterator.get_next())中,然后将input_fn=my_input_fn传递给train调用。这会将完整数据集分批传递给训练代码。

numpy_input_fn适用于您已经拥有数组中的完整数据集并希望快速进行批处理/重排/重复等的方式。