Tensor Flow CNN MNIST示例:模型中批量大小的工作方式

时间:2018-06-19 16:33:31

标签: tensorflow batch-processing mnist convolutional-neural-network tensorflow-estimator

在CNN MNIST tensorflow示例中,我不了解批处理大小的工作原理,当他们调用模型时,他们将bach的大小指定为100:

train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
y=train_labels,
batch_size=100,
num_epochs=None,shuffle=True)
mnist_classifier.train(input_fn=train_input_fn,steps=20000,hooks=[logging_hook])

但是在调用模型时:

def cnn_model_fn(features, labels, mode):
  # Input Layer
  # Reshape X to 4-D tensor: [batch_size, width, height, channels]
  # MNIST images are 28x28 pixels, and have one color channel
  input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

他们将-1放入批处理大小中,我在tensorflow教程中阅读了-1,当他们告诉计算机推断该尺寸时,他们使用了-1。我不明白的是在我们放入100之前,现在因为-1无法理解如何为模型输入批次大小,您能帮我解释一下吗?谢谢。

1 个答案:

答案 0 :(得分:0)

tl; dr

batch_size方法中的属性tf.reshape()batch_size函数中的tf.estimator.inputs.numpy_input_fn属性完全不同。

input_fn中的批量大小

方法batch_size的属性tf.estimator.inputs.numpy_input_fn控制在特定纪元(或时间实例)将训练或评估数据集中的观测值(或行或记录)的数量。因此,在提供的示例中,batch_size = 100表示将在每个时期通过学习算法训练数据集中的100行(在这种情况下为图像)。

重塑张量

方法tf.reshape用于更改张量的形状。方法tf.reshape具有属性(tensor, shape)。根据文档,shape属性具有一个特殊值-1,该值推断该特定轴的尺寸以保留总尺寸。因此,从提供的示例中,[-1, 28, 28, 1]转换为[batch_size, row, column, channel]batch_size为-1表示TensorFlow在将图像重塑为所有图像的784个输入要素(即28 * 28)的一维数组时,将保持图像的大小。