如何从npy文件中读取数据作为TensorFlow数据集

时间:2018-10-29 12:03:24

标签: numpy tensorflow

我有npy包含两个ndarray,一个是一个热点,另一个是标签。如何输入数据?

如何将这些数据转换为张量x_train?我只对MNIST数据集有经验,而不必处理输入数据。

1 个答案:

答案 0 :(得分:0)

.npy 数组而言,

MNIST数据集的加载方式类似。

这行代码

(X_train, y_train), (X_test, y_test) = 

 keras.datasets.mnist.load_data(path="C:/Users/476458/.keras/datasets/mnist.npz")

依次使用以下代码。

path = get_file(path,
                origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                file_hash='8a61469f7ea1b51cbae51d4f78837e45')
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
f.close()
return (x_train, y_train), (x_test, y_test)

np.load 加载 .npz 文件,该文件是一个压缩的归档文件,其中包含4个 .npy 文件(x_train.npy,y_train。 npy,x_test.npy,y_test)。

但是您应该可以直接使用 np.load 加载 .npy 文件。

我希望这会给您一些想法。