为ModelNet10设置3D-GAN时获得以下错误消息:
InvalidArgumentError:要重塑的输入是具有27000个值的张量,但请求的形状具有810000 [Op:Reshape]
我认为该批次未正确创建,因此张量的形状无效。尝试了不同的方法,但无法设置批次。 我非常感谢您有任何清理代码的提示! 预先感谢!
import time
import numpy as np
import tensorflow as tf
np.random.seed(1)
from tensorflow.keras import layers
from IPython import display
# Load the data
modelnet_path = '/modelnet10.npz'
data = np.load(modelnet_path)
X, Y = data['X_train'], data['y_train']
X_test, Y_test = data['X_test'], data['y_test']
X = X.reshape(X.shape[0], 30, 30, 30, 1).astype('float32')
#Hyperparameters
BUFFER_SIZE = 3991
BATCH_SIZE = 30
LEARNING_RATE = 4e-4
BETA_1 = 5e-1
EPOCHS = 100
#Random seed for image generation
n_examples = 16
noise_dim = 100
seed = tf.random.normal([n_examples, noise_dim])
train_dataset = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
# Build the network
def make_discriminator_model():
model = tf.keras.Sequential()
model.add(layers.Reshape((30, 30, 30, 1), input_shape=(30, 30, 30)))
model.add(layers.Conv3D(16, 6, strides=2, activation='relu'))
model.add(layers.Conv3D(64, 5, strides=2, activation='relu'))
model.add(layers.Conv3D(64, 5, strides=2, activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(10))
return model
discriminator = make_discriminator_model()
def make_generator_model():
model = tf.keras.Sequential()
model.add(layers.Dense(15*15*15*128, use_bias=False,input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(layers.Reshape((15,15,15,128)))
model.add(layers.Conv3DTranspose(64, (5,5,5), strides=(1,1,1), padding='valid', use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(layers.Conv3DTranspose(32, (5,5,5), strides=(2,2,2), padding='valid', use_bias=False, activation='tanh'))
return model
generator = make_generator_model()
#Optimizer & Loss function
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
optimizer = tf.keras.optimizers.Adam(lr=LEARNING_RATE, beta_1=BETA_1)
#Training
def train_step(shapes):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_shapes = generator(noise, training=True)
real_output = discriminator(shapes, training=True)
fake_output = discriminator(generated_shapes, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for shape_batch in dataset:
train_step(shape_batch)
display.clear_output(wait=True)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
display.clear_output(wait=True)
train(X_test, EPOCHS)
答案 0 :(得分:0)
X_test只是一个列表,因此在您的训练循环中,只有一个样本(30 * 30 * 30 = 27000)进入模型,但是模型本身要求30(批大小)* 30 * 30 * 30 = 810000。
modelnet_path = '/modelnet10.npz'
data = np.load(modelnet_path)
X, Y = data['X_train'], data['y_train']
X_test, Y_test = data['X_test'], data['y_test']
X = X.reshape(X.shape[0], 30, 30, 30, 1).astype('float32')
...
train_dataset = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
...
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for shape_batch in dataset:
train_step(shape_batch)
display.clear_output(wait=True)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
display.clear_output(wait=True)
train(X_test, EPOCHS)
考虑使用您创建的 train_dataset 进行训练,或将X_test生成为tf.dataset。
train(train_dataset , EPOCHS)