为什么我们在将数据输入tensorflow之前将其展平?

时间:2017-06-15 16:04:00

标签: tensorflow deep-learning mnist

我正在关注udacity MNIST tutorial,而MNIST数据最初是(784 = 28 * 28)矩阵。然而,在提供数据之前,他们将数据展平为1d数组,其中包含784列/var/log/httpd/access_log

例如, 原始训练集形状为(200000,28,28) 200000行(数据)。每个数据是28 * 28矩阵

他们将此转换为形状为(200000,784)

的训练集

有人可以解释为什么他们在输入tensorflow之前将数据压扁了吗?

2 个答案:

答案 0 :(得分:4)

因为当您添加完全连接的图层时,您总是希望您的数据是(1或)2维矩阵,其中每一行都是表示数据的向量。这样,完全连接的图层只是输入(大小为(batch_size, n_features))和权重(形状(n_features, n_outputs))(加上偏差和激活函数)之间的矩阵乘法,你会得到一个形状(batch_size, n_outputs)的输出。另外,你真的不需要完全连接层中的原始形状信息,所以可以丢失它。

在不重新整形的情况下获得相同结果会更复杂,效率更低,这就是为什么我们总是在完全连接的层之前完成它。对于卷积层,相反,您希望将数据保持原始格式(宽度,高度)。

答案 1 :(得分:2)

这是一个完全连接的层的约定。完全连接的层将前一层中的每个节点与连续层中的每个节点连接起来,因此对于这种类型的层,局部性不是问题。

此外,通过定义这样的图层,我们可以通过计算公式有效地计算下一步:f(Wx + b) = y。这对于多维输入来说并不是那么容易,并且重塑输入是低成本且易于实现的。