Dataloader batch has size 0: ValueError: Cannot feed value of shape (0,) for Tensor 'Placeholder:0', which has shape '(64, 64, 64, 3)'

Date: 2018-11-06 17:02:01

Tags: python tensorflow deep-learning generative-adversarial-network

Do you know why real_batch ends up with size 0?

%matplotlib inline 
from matplotlib import pyplot as plt

tf.reset_default_graph()
LOGDIR = "logs"

def train(args):
    data_loader = Dataset(args.data_path, args.num_images, args.image_size) 
    print(args.dim_z)
    print(args.image_size)
    X = tf.placeholder(tf.float32, shape=[args.batch_size, args.image_size , args.image_size, 3])
    Z = tf.placeholder(tf.float32, shape=[args.batch_size, 1, 1, args.dim_z])
    G_sample, _ = generator(Z, args)
    print("size G_sample: ", G_sample)
    D_real, D_real_logits = discriminator(X, args, reuse=False)
    D_fake, D_fake_logits = discriminator(G_sample, args, reuse=True)
    d_loss, g_loss = get_losses(D_real_logits, D_fake_logits)
    z_sum = tf.summary.histogram('z', Z)
    d_sum = tf.summary.histogram('d', D_real)
    G_sum = tf.summary.histogram('g', G_sample)
    d_loss_sum = tf.summary.scalar('d_loss', d_loss)
    g_loss_sum = tf.summary.scalar('g_loss', g_loss)
    d_sum = tf.summary.merge([z_sum, d_sum, d_loss_sum])
    g_sum = tf.summary.merge([z_sum, G_sum, g_loss_sum])
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter('log', sess.graph)
        print(dir(data_loader))
        print(data_loader.num_imgs)
        for epoch in range(args.n_epoch):
            for itr, real_batch in enumerate(data_loader.get_nextbatch(args.batch_size)):
                sample = sess.run(G_sample, feed_dict={Z:sample_z(args.dim_z, args.batch_size)})
                print(type(real_batch))
                print(real_batch.shape)
                print(real_batch.size)
                print(sample.shape)
#plt.imshow(sample.reshape(64, 64), interpolation='nearest')
                #plt.show()
                #D_real = sess.run(D_real, feed_dict={X:X, args:args})
                D_real, D_real_logits = sess.run(discriminator, feed_dict={X:real_batch, args:args})
                D_fake, D_fake_logits = sess.run(discriminator, feed_dict={x:G_sample, args:args})
                d_loss, g_loss = sess.run(get_losses, feed_dict={d_real_logits:D_real_logits, d_fake_logits:D_fake_logits})
                d_optimizer, g_optimizer = get_optimizers(args.lr, args.beta1, args.beta2)
                d_step, g_step = optimize(d_optimizer, g_optimizer, d_loss, g_loss)
                writer = tf.summary.FileWriter(LOGDIR)   
                merged_summary = tf.summary.merge_all()
                writer.add_summary(merged_summary, itr)
                d_loss_summary = tf.summary.scalar("Discriminator_Total_Loss", d_loss)
                g_loss_summary = tf.summary.scalar("Generator_Total_Loss", g_loss)
                merged_summary = tf.summary.merge_all()
                writer.add_graph(sess.graph)
                saver.save(sess, save_path='./gan.ckpt')


train(args)   
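
As an aside for readers, the loop above passes the Python functions discriminator and get_losses to sess.run, and uses args as a feed_dict key; in TF 1.x, sess.run fetches must be tensors/ops and feed_dict keys must be placeholders such as X and Z. A minimal sketch of the usual structure (assuming get_optimizers and optimize return the D and G train ops, as in the code above; the graph is built once, before the session):

d_optimizer, g_optimizer = get_optimizers(args.lr, args.beta1, args.beta2)
d_step, g_step = optimize(d_optimizer, g_optimizer, d_loss, g_loss)  # built once, before tf.Session()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(args.n_epoch):
        for itr, real_batch in enumerate(data_loader.get_nextbatch(args.batch_size)):
            z_batch = sample_z(args.dim_z, args.batch_size)
            # Feed only the placeholders; fetch the already-built ops and loss tensors.
            _, d_loss_val = sess.run([d_step, d_loss], feed_dict={X: real_batch, Z: z_batch})
            _, g_loss_val = sess.run([g_step, g_loss], feed_dict={Z: z_batch})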

The error is:

100
64
size G_sample:  Tensor("generator/last_layer/Tanh:0", shape=(64, 64, 64, 3), dtype=float32, device=/device:GPU:0)
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'data_path', 'denormalize_np_image', 'get_imagelist', 'get_input', 'get_nextbatch', 'load_and_preprocess_image', 'normalize_np_image', 'num_imgs', 'preprocess_and_save_images', 'show_image', 'target_imgsize']
202590
(6400,)
sample shape:  (64, 1, 1, 100)
<class 'numpy.ndarray'>
(0,)
0
(64, 64, 64, 3)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-54-6fe85c255384> in <module>()
     55 
     56 
---> 57 train(args)

<ipython-input-54-6fe85c255384> in train(args)
     40                 #plt.show()
     41                 #D_real = sess.run(D_real, feed_dict={X:X, args:args})
---> 42                 D_real, D_real_logits = sess.run(discriminator, feed_dict={X:real_batch, args:args})
     43                 D_fake, D_fake_logits = sess.run(discriminator, feed_dict={x:G_sample, args:args})
     44                 d_loss, g_loss = sess.run(get_losses, feed_dict={d_real_logits:D_real_logits, d_fake_logits:D_fake_logits})

/share/pkg/tensorflow/r1.10/install/py3-gpu/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    875     try:
    876       result = self._run(None, fetches, feed_dict, options_ptr,
--> 877                          run_metadata_ptr)
    878       if run_metadata:
    879         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/share/pkg/tensorflow/r1.10/install/py3-gpu/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1074                              'which has shape %r' %
   1075                              (np_val.shape, subfeed_t.name,
-> 1076                               str(subfeed_t.get_shape())))
   1077           if not self.graph.is_feedable(subfeed_t):
   1078             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (0,) for Tensor 'Placeholder:0', which has shape '(64, 64, 64, 3)'
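
The message itself shows that the value fed for X is an empty array: real_batch has shape (0,). A quick check is to see whether the image list is empty before training (a sketch, assuming args.data_path is the folder being globbed):

image_list = data_loader.get_imagelist(args.data_path)
print(len(image_list))  # 0 here means glob.glob matched no *.jpg files under data_path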

The data loader is:

import os
import glob
import numpy as np
from PIL import Image

class Dataset(object):
    def __init__(self, data_path, num_imgs, target_imgsize):
        self.data_path = data_path
        self.num_imgs = num_imgs 
        self.target_imgsize = target_imgsize 

    def normalize_np_image(self, image):
        return (image / 255.0 - 0.5) / 0.5

    def denormalize_np_image(self, image):
        return (image * 0.5 + 0.5) * 255

    def get_input(self, image_path):
        image = np.array(Image.open(image_path)).astype(np.float32)
        return self.normalize_np_image(image)

    def get_imagelist(self, data_path, celebA=False): 
        if celebA == True:
            imgs_path = os.path.join(data_path, 'img_align_celeba/*.jpg')
        else:
            imgs_path = os.path.join(data_path, '*.jpg') 
        all_namelist = glob.glob(imgs_path, recursive=True)
        return all_namelist[:self.num_imgs]

    def load_and_preprocess_image(self, image_path): 
        image = Image.open(image_path)
        j = (image.size[0] - 100) // 2
        i = (image.size[1] - 100) // 2
        image = image.crop([j, i, j + 100, i + 100])    
        image = image.resize([self.target_imgsize, self.target_imgsize], Image.BILINEAR)
        image = np.array(image.convert('RGB')).astype(np.float32)
        image = self.normalize_np_image(image)
        return image    

    #reads data, preprocesses and saves to another folder with the given path. 
    def preprocess_and_save_images(self, dir_name, save_path=''): 
        preproc_folder_path = os.path.join(save_path, dir_name)
        if not os.path.exists(preproc_folder_path):
            os.makedirs(preproc_folder_path)   
            imgs_path = os.path.join(self.data_path, 'img_align_celeba/*.jpg')
            print('Saving and preprocessing images ...')
            for num, imgname in enumerate(glob.iglob(imgs_path, recursive=True)):
                cur_image = self.load_and_preprocess_image(imgname)
                cur_image = Image.fromarray(np.uint8(self.denormalize_np_image(cur_image)))
                cur_image.save(preproc_folder_path + '/preprocessed_image_%d.jpg' %(num)) 
        self.data_path= preproc_folder_path

    def get_nextbatch(self, batch_size): 
        assert (batch_size > 0),"Give a valid batch size"
        cur_idx = 0
        image_namelist = self.get_imagelist(self.data_path)
        while cur_idx + batch_size <= self.num_imgs:
            cur_namelist = image_namelist[cur_idx:cur_idx + batch_size]
            cur_batch = [self.get_input(image_path) for image_path in cur_namelist]
            cur_batch = np.array(cur_batch).astype(np.float32)
            cur_idx += batch_size
            yield cur_batch

    def show_image(self, image, normalized=True):
        if not type(image).__module__ == np.__name__:
            image = image.numpy()
        if normalized:
            npimg = (image * 0.5) + 0.5 
        npimg.astype(np.uint8)
        plt.imshow(npimg, interpolation='nearest')
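
Note that get_nextbatch bounds its loop with self.num_imgs rather than with the number of files glob actually returned, so when get_imagelist comes back empty it still slices an empty list and yields np.array([]) with shape (0,), which matches the error above. A sketch of the same method, bounded instead by the files actually found:

def get_nextbatch(self, batch_size):
    assert batch_size > 0, "Give a valid batch size"
    image_namelist = self.get_imagelist(self.data_path)
    num_available = len(image_namelist)  # bound by the files actually on disk, not self.num_imgs
    cur_idx = 0
    while cur_idx + batch_size <= num_available:
        cur_namelist = image_namelist[cur_idx:cur_idx + batch_size]
        cur_batch = np.array([self.get_input(p) for p in cur_namelist], dtype=np.float32)
        cur_idx += batch_size
        yield cur_batch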

0 Answers:

No answers