无法获得正确批大小的张量

时间:2018-11-05 11:55:26

标签: python tensorflow tensorflow-datasets

以下代码仅从指定批处理大小的mnist中检索特定数字。但是返回的张量只是一个样本。有人可以看一下这个问题吗?谢谢。

import numpy as np
import tensorflow as tf
import sonnet as snt

class Input(snt.AbstractModule):
    def __init__(self, batch_size, name = "input"):
        super(Input, self).__init__(name = name)

        mnist = tf.keras.datasets.mnist

        (X_train, Y_train), (X_test, Y_test) = mnist.load_data()

        train_filter = np.where((Y_train == 0 ) | (Y_train == 1))
        test_filter = np.where((Y_test == 0) | (Y_test == 1))

        X_train, Y_train = X_train[train_filter], Y_train[train_filter]
        X_test, Y_test = X_test[test_filter], Y_test[test_filter]

        print(X_train.shape)
        print(Y_train.shape)

        with self._enter_variable_scope():
            self._db_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
            self._db_test = tf.data.Dataset.from_tensor_slices((X_test, Y_test))

            self._db_train.repeat(-1)
            self._db_test.repeat(-1)

            self._db_train.batch(batch_size)
            self._db_test.batch(batch_size)

            # self._db_train.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
            # self._db_test.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

            self._it_train = self._db_train.make_one_shot_iterator()
            self._it_test = self._db_test.make_one_shot_iterator()


    def _build(self, is_training = True):

        if is_training:
            inputs, labels = self._it_train.get_next()
        else:
            inputs, labels = self._it_test.get_next()

        return inputs, labels
def test():
    input_ = Input(32)

    inputs, labels = input_()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        inputs_val, labels_val = sess.run([inputs, labels])


        print(inputs_val.shape)
        print(labels_val.shape)


if __name__ == "__main__":
    test()

以上代码段的输出如下:

(12665, 28, 28)
(12665,)
(28, 28)
()

请注意,我删除了不相关的内容。

0 个答案:

没有答案