以下代码仅从指定批处理大小的mnist中检索特定数字。但是返回的张量只是一个样本。有人可以看一下这个问题吗?谢谢。
import numpy as np
import tensorflow as tf
import sonnet as snt
class Input(snt.AbstractModule):
def __init__(self, batch_size, name = "input"):
super(Input, self).__init__(name = name)
mnist = tf.keras.datasets.mnist
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
train_filter = np.where((Y_train == 0 ) | (Y_train == 1))
test_filter = np.where((Y_test == 0) | (Y_test == 1))
X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]
print(X_train.shape)
print(Y_train.shape)
with self._enter_variable_scope():
self._db_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
self._db_test = tf.data.Dataset.from_tensor_slices((X_test, Y_test))
self._db_train.repeat(-1)
self._db_test.repeat(-1)
self._db_train.batch(batch_size)
self._db_test.batch(batch_size)
# self._db_train.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
# self._db_test.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
self._it_train = self._db_train.make_one_shot_iterator()
self._it_test = self._db_test.make_one_shot_iterator()
def _build(self, is_training = True):
if is_training:
inputs, labels = self._it_train.get_next()
else:
inputs, labels = self._it_test.get_next()
return inputs, labels
def test():
input_ = Input(32)
inputs, labels = input_()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
inputs_val, labels_val = sess.run([inputs, labels])
print(inputs_val.shape)
print(labels_val.shape)
if __name__ == "__main__":
test()
以上代码段的输出如下:
(12665, 28, 28)
(12665,)
(28, 28)
()
请注意,我删除了不相关的内容。