TensorFlow: incompatible input shape

Time: 2020-09-14 18:15:12

Tags: python tensorflow

I'm trying to build a TensorFlow model where my data has 70 features. This is the setup of my first layer:

tf.keras.layers.Dense(units=50, activation='relu', input_shape=(None,70)),

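Setting the input shape to (None, 70) seemed best to me, since I'm using a feed-forward neural network in which each "row" of data is unique. I'm using a batch size of 10 (for now). Should my input shape be changed to (10, 70)?

I tried the original (None, 70) and got this error: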

WARNING:tensorflow:Model was constructed with shape (None, None, 70) for input Tensor("dense_33_input:0", shape=(None, None, 70), dtype=float32), but it was called on an input with incompatible shape (10, 70).

TypeError: Input 'y' of 'Mul' Op has type float64 that does not match type float32 of argument 'x'.

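I'm rather confused about what exactly is going wrong with input_shape, since (None, 70) seems the most appropriate. Any help is much appreciated.

Edit: I wanted to add a reproducible example for more context. Sorry for the length. It is an adaptation of [this example][1] to better fit my current (non-image) data.

Variational Autoencoder Model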

import time

import numpy as np
import tensorflow as tf
from IPython import display


class VAE(tf.keras.Model):
    
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(units=50, activation='relu', input_shape=(70,)),
            tf.keras.layers.Dense(latent_dim + latent_dim), #No activation
        ])
        
        self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(units=50, activation='relu', input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=70),
        ])
        
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        # tf.shape(mean) also works when the batch dimension is unknown at trace time
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits


  [1]: https://www.tensorflow.org/tutorials/generative/cvae

Optimizer and Loss Functions

optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi), axis=raxis)

def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1])
    logpz = log_normal_pdf(z, 0, 0)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
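
As a quick sanity check on the formula inside log_normal_pdf, the per-element term can be compared with the log of a Gaussian density whose variance is exp(logvar). This is only a scalar sketch in plain Python: log_normal_pdf_scalar and the sample values are hypothetical names and numbers chosen for illustration, not part of the original post.

```python
import math
from statistics import NormalDist


def log_normal_pdf_scalar(sample, mean, logvar):
    """Scalar version of the per-element term that log_normal_pdf sums."""
    log2pi = math.log(2.0 * math.pi)
    return -0.5 * ((sample - mean) ** 2 * math.exp(-logvar) + logvar + log2pi)


# Illustrative values (assumptions, not taken from the post).
sample, mean, logvar = 0.3, -0.1, 0.5

# Reference: log-density of a normal distribution with variance exp(logvar),
# i.e. standard deviation exp(0.5 * logvar).
reference = math.log(NormalDist(mu=mean, sigma=math.exp(0.5 * logvar)).pdf(sample))

difference = abs(log_normal_pdf_scalar(sample, mean, logvar) - reference)
print(difference < 1e-12)
```

The two values agree up to floating-point rounding, which suggests the helper computes an element-wise Gaussian log-density parameterized by the log-variance.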

Training

X = tf.random.uniform((100,70))
y = tf.random.uniform((100,))

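ds_train = tf.data.Dataset.from_tensor_slices((X, y))

# ds_test is used in the loop below but was never defined in the post.
# Assumption: a batched held-out set built the same way as the training data.
ds_test = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((20, 70)), tf.random.uniform((20,)))).batch(10)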

tf.random.set_seed(1)

train = ds_train.shuffle(buffer_size=len(X))
train = train.batch(batch_size=10, drop_remainder=False)

epochs = 5
latent_dim = 2

model = VAE(latent_dim)

for epoch in range(1, epochs+1):
    start_time = time.time()
    for i, (train_x, train_y) in enumerate(train):
        train_step(model, train_x, optimizer)
    end_time = time.time()
    
    loss = tf.keras.metrics.Mean()
    for i, (test_x, test_y) in enumerate(ds_test):
        loss(compute_loss(model, test_x))
    elbo = -loss.result()
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'
         .format(epoch, elbo, end_time - start_time))

1 answer:

Answer 0 (score: 1)

input_shape should not include the batch dimension. Use input_shape=(70,):

tf.keras.layers.Dense(units=50, activation='relu', input_shape=(70,))
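
To illustrate that the batch dimension is left out of the declared shape, here is a minimal sketch. It uses tf.keras.Input(shape=(70,)) in place of the input_shape argument (a substitution made here because it should behave the same way for this purpose), and the zero-filled batches are placeholder data.

```python
import tensorflow as tf

# The declared shape excludes the batch dimension; Keras adds it as None.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(70,)),
    tf.keras.layers.Dense(units=50, activation='relu'),
])

# The same model accepts batches of different sizes.
out_10 = model(tf.zeros((10, 70)))
out_3 = model(tf.zeros((3, 70)))
print(out_10.shape, out_3.shape)
```

Because the batch size is not fixed in the model, it can be chosen at training time, for example through Dataset.batch as in the question.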

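You can set the batch size when calling model.fit(..., batch_size=10). See the documentation for tf.keras.Model.fit.

The original post also contains another error, caused by passing int32 values to tf.math.exp. The line should read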

logpz = log_normal_pdf(z, 0., 0.)

to resolve that error. Note the 0. values, which evaluate to floats rather than integers.
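
The difference between 0 and 0. can be seen by inspecting the dtypes TensorFlow infers for Python literals. This is a small sketch. The explicit tf.cast at the end is an alternative added here for illustration and is not part of the answer. A similar cast may also help if input data arrives as float64, which is one possible source of the float64/float32 mismatch in the TypeError above.

```python
import tensorflow as tf

# A bare Python int becomes an int32 tensor; a Python float becomes float32.
int_dtype = tf.constant(0).dtype
float_dtype = tf.constant(0.).dtype
print(int_dtype, float_dtype)

# Casting explicitly is another way to make dtypes agree before arithmetic.
logvar = tf.cast(tf.constant(0), tf.float32)
result = tf.exp(-logvar)
print(result.numpy())
```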