您知道为什么real_batch的大小为0吗?
%matplotlib inline
from matplotlib import pyplot as plt
tf.reset_default_graph()
LOGDIR = "logs"
def train(args):
data_loader = Dataset(args.data_path, args.num_images, args.image_size)
print(args.dim_z)
print(args.image_size)
X = tf.placeholder(tf.float32, shape=[args.batch_size, args.image_size , args.image_size, 3])
Z = tf.placeholder(tf.float32, shape=[args.batch_size, 1, 1, args.dim_z])
G_sample, _ = generator(Z, args)
print("size G_sample: ", G_sample)
D_real, D_real_logits = discriminator(X, args, reuse=False)
D_fake, D_fake_logits = discriminator(G_sample, args, reuse=True)
d_loss, g_loss = get_losses(D_real_logits, D_fake_logits)
z_sum = tf.summary.histogram('z', Z)
d_sum = tf.summary.histogram('d', D_real)
G_sum = tf.summary.histogram('g', G_sample)
d_loss_sum = tf.summary.scalar('d_loss', d_loss)
g_loss_sum = tf.summary.scalar('g_loss', g_loss)
d_sum = tf.summary.merge([z_sum, d_sum, d_loss_sum])
g_sum = tf.summary.merge([z_sum, G_sum, g_loss_sum])
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('log', sess.graph)
print(dir(data_loader))
print(data_loader.num_imgs)
for epoch in range(args.n_epoch):
for itr, real_batch in enumerate(data_loader.get_nextbatch(args.batch_size)):
sample = sess.run(G_sample, feed_dict={Z:sample_z(args.dim_z, args.batch_size)})
print(type(real_batch))
print(real_batch.shape)
print(real_batch.size)
print(sample.shape)
#plt.imshow(sample.reshape(64,64, interpolation='nearest')
#plt.show()
#D_real = sess.run(D_real, feed_dict={X:X, args:args})
D_real, D_real_logits = sess.run(discriminator, feed_dict={X:real_batch, args:args})
D_fake, D_fake_logits = sess.run(discriminator, feed_dict={x:G_sample, args:args})
d_loss, g_loss = sess.run(get_losses, feed_dict={d_real_logits:D_real_logits, d_fake_logits:D_fake_logits})
d_optimizer, g_optimizer = get_optimizers(args.lr, args.beta1, args.beta2)
d_step, g_step = optimize(d_optimizer, g_optimizer, d_loss, g_loss)
writer = tf.summary.FileWriter(LOGDIR)
merged_summary = tf.summary.merge_all()
writer.add_summary(merged_summary, itr)
d_loss_summary = tf.summary.scalar("Discriminator_Total_Loss", d_loss)
g_loss_summary = tf.summary.scalar("Generator_Total_Loss", g_loss)
merged_summary = tf.summary.merge_all()
writer.add_graph(sess.graph)
saver.save(sess, save_path='./gan.ckpt')
train(args)
错误是:
100
64
size G_sample: Tensor("generator/last_layer/Tanh:0", shape=(64, 64, 64, 3), dtype=float32, device=/device:GPU:0)
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'data_path', 'denormalize_np_image', 'get_imagelist', 'get_input', 'get_nextbatch', 'load_and_preprocess_image', 'normalize_np_image', 'num_imgs', 'preprocess_and_save_images', 'show_image', 'target_imgsize']
202590
(6400,)
sample shape: (64, 1, 1, 100)
<class 'numpy.ndarray'>
(0,)
0
(64, 64, 64, 3)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-54-6fe85c255384> in <module>()
55
56
---> 57 train(args)
<ipython-input-54-6fe85c255384> in train(args)
40 #plt.show()
41 #D_real = sess.run(D_real, feed_dict={X:X, args:args})
---> 42 D_real, D_real_logits = sess.run(discriminator, feed_dict={X:real_batch, args:args})
43 D_fake, D_fake_logits = sess.run(discriminator, feed_dict={x:G_sample, args:args})
44 d_loss, g_loss = sess.run(get_losses, feed_dict={d_real_logits:D_real_logits, d_fake_logits:D_fake_logits})
/share/pkg/tensorflow/r1.10/install/py3-gpu/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
875 try:
876 result = self._run(None, fetches, feed_dict, options_ptr,
--> 877 run_metadata_ptr)
878 if run_metadata:
879 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/share/pkg/tensorflow/r1.10/install/py3-gpu/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1074 'which has shape %r' %
1075 (np_val.shape, subfeed_t.name,
-> 1076 str(subfeed_t.get_shape())))
1077 if not self.graph.is_feedable(subfeed_t):
1078 raise ValueError('Tensor %s may not be fed.' % subfeed_t)
ValueError: Cannot feed value of shape (0,) for Tensor 'Placeholder:0', which has shape '(64, 64, 64, 3)'
数据加载器为:
class Dataset(object):
def __init__(self, data_path, num_imgs, target_imgsize):
self.data_path = data_path
self.num_imgs = num_imgs
self.target_imgsize = target_imgsize
def normalize_np_image(self, image):
return (image / 255.0 - 0.5) / 0.5
def denormalize_np_image(self, image):
return (image * 0.5 + 0.5) * 255
def get_input(self, image_path):
image = np.array(Image.open(image_path)).astype(np.float32)
return self.normalize_np_image(image)
def get_imagelist(self, data_path, celebA=False):
if celebA == True:
imgs_path = os.path.join(data_path, 'img_align_celeba/*.jpg')
else:
imgs_path = os.path.join(data_path, '*.jpg')
all_namelist = glob.glob(imgs_path, recursive=True)
return all_namelist[:self.num_imgs]
def load_and_preprocess_image(self, image_path):
image = Image.open(image_path)
j = (image.size[0] - 100) // 2
i = (image.size[1] - 100) // 2
image = image.crop([j, i, j + 100, i + 100])
image = image.resize([self.target_imgsize, self.target_imgsize], Image.BILINEAR)
image = np.array(image.convert('RGB')).astype(np.float32)
image = self.normalize_np_image(image)
return image
#reads data, preprocesses and saves to another folder with the given path.
def preprocess_and_save_images(self, dir_name, save_path=''):
preproc_folder_path = os.path.join(save_path, dir_name)
if not os.path.exists(preproc_folder_path):
os.makedirs(preproc_folder_path)
imgs_path = os.path.join(self.data_path, 'img_align_celeba/*.jpg')
print('Saving and preprocessing images ...')
for num, imgname in enumerate(glob.iglob(imgs_path, recursive=True)):
cur_image = self.load_and_preprocess_image(imgname)
cur_image = Image.fromarray(np.uint8(self.denormalize_np_image(cur_image)))
cur_image.save(preproc_folder_path + '/preprocessed_image_%d.jpg' %(num))
self.data_path= preproc_folder_path
def get_nextbatch(self, batch_size):
assert (batch_size > 0),"Give a valid batch size"
cur_idx = 0
image_namelist = self.get_imagelist(self.data_path)
while cur_idx + batch_size <= self.num_imgs:
cur_namelist = image_namelist[cur_idx:cur_idx + batch_size]
cur_batch = [self.get_input(image_path) for image_path in cur_namelist]
cur_batch = np.array(cur_batch).astype(np.float32)
cur_idx += batch_size
yield cur_batch
def show_image(self, image, normalized=True):
if not type(image).__module__ == np.__name__:
image = image.numpy()
if normalized:
npimg = (image * 0.5) + 0.5
npimg.astype(np.uint8)
plt.imshow(npimg, interpolation='nearest')