Tensorflow MNIST数据加载器提供互连的numpy数组

时间:2019-04-18 11:28:30

标签: tensorflow mnist

我正在使用内置的Tensorflow数据集模块读取一批MNIST数据。这给出了一个numpy数组作为批处理。但是,如果我将数组复制到另一个变量并对该第二个变量进行更改,则原始批处理数组也会更改。 我对为什么原始数组和复制的数组之间没有任何联系感到怀疑。

您可以在此CoLab链接上进行测试:

https://colab.research.google.com/drive/1DN4n5_YCO33LozxtidM7STqEAUWypNOv

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def test_reconstruction(mnist, h=28, w=28, batch_size=100):
    # Test the trained model: reconstruction
    batch = mnist.test.next_batch(batch_size)
    batch_clean = batch[0]

    print('before damage:', np.mean(batch_clean))
    batch_damaged = np.reshape(batch_clean, (batch_size, 28, 28))
    tmp = batch_damaged
    tmp[:, 10:20, 10:20] = 0
    print('after damage:', np.mean(batch_clean))

test_reconstruction(mnist)

预期:两个打印语句应返回相同的平均值

实际:两个打印语句的平均值不同

1 个答案:

答案 0 :(得分:0)

在您的行var all = ( from i in ctx.Image join c in ctx.Camera on i.CameraId equals c.CameraId select new { i, c } ).ToArray(); 中,复制batch_clean的引用而不是其值。您应该使用batch_damaged = np.reshape(batch_clean, (batch_size, 28, 28))返回数组的副本。 numpy.copy