如何从张量流中的数据集类获取10K MNIST图像的子集?

时间:2019-02-12 19:40:58

标签: python tensorflow

我发现了以下方法来在tensorflow中获取mnist数据集:

def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):

  def _input_fn():
    images_batch, labels_batch = tf.train.shuffle_batch(
        tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        enqueue_many=True,
        num_threads=4)
    features_map = {'images': images_batch}
    return features_map, labels_batch

  return _input_fn

    data = tf.contrib.learn.datasets.mnist.load_mnist()

    train_input_fn = get_input_fn(data.train, batch_size=256)
    eval_input_fn = get_input_fn(data.validation, batch_size=5000)

data变量是Dataset对象。 这种方法对我来说还很不清楚,我无法弄清楚如何将60K数据集转换为10K数据集。

当我执行以下操作时:

data = tf.contrib.learn.datasets.mnist.load_mnist().take(10000)

我收到错误消息:

AttributeError: 'Datasets' object has no attribute 'take'

但是文档提供了以下方法: enter image description here

谢谢您的帮助!

1 个答案:

答案 0 :(得分:0)

不推荐使用contrib模块中的此功能。您可以使用tf.keras.datasets.mnist.load_data()。根据{{​​3}},它返回

Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 

因此,要对其应用任何功能,您需要将其加载到数据集对象中。

train, test = tf.keras.datasets.mnist.load_data(path='mnist.npz')
dataset_train = tf.data.Dataset.from_tensor_slices((train[0], train[1]))
dataset_test = tf.data.Dataset.from_tensor_slices((test[0], test[1]))

然后,您可以对dataset_traindataset_test对象应用洗牌,批处理,拍摄或任何地图功能