在Tensorflow中实现扁平化层

时间:2020-06-26 03:11:28

标签: python tensorflow keras keras-layer

我正在尝试使用TensorFlow 2.2.0实现扁平化层。我正在按照Geron的书(第二版)中的说明进行操作。 至于平坦层,我首先尝试获取批输入形状并计算新形状。 但是我已经用张量维度解决了这个问题:TypeError: Dimension value must be integer or None or have an __index__ method

import tensorflow as tf
from tensorflow import keras
(X_train, y_train), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
input_shape = X_train.shape[1:]
assert input_shape == (28, 28)

class MyFlatten(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        super().build(batch_input_shape) 

    def call(self, X):
        X_shape = tf.shape(X)
        batch_size = X_shape[0]
        new_shape = tf.TensorShape([batch_size, X_shape[1]*X_shape[2]])
        return tf.reshape(X, new_shape)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

## works fine on this example
MyFlatten()(X_train[:10])

## fail when building a model
input_ = keras.layers.Input(shape=[28, 28])
fltten_ = MyFlatten()(input_)
hidden1 = keras.layers.Dense(300, activation="relu")(fltten_)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
output = keras.layers.Dense(10, activation="softmax")(hidden2)
model = keras.models.Model(inputs=[input_], outputs=[output])
model.summary()

1 个答案:

答案 0 :(得分:0)

不要尝试创建tf.TensorShape,它仅在张量的所有维度已知时才起作用,实际上,张量的所有维度仅在渴望模式下进行,因此模型编译将失败。只需这样重塑:

def call(self, X):
    X_shape = tf.shape(X)
    batch_size = X_shape[0]
    new_shape = [batch_size, X_shape[1] * X_shape[2]]
    return tf.reshape(X, new_shape)

或者,更一般而言,您可以这样做:

def call(self, X):
    X_shape = tf.shape(X)
    batch_size = X_shape[0]
    new_shape = [batch_size, tf.math.reduce_prod(X_shape[1:])]
    return tf.reshape(X, new_shape)

tf.reshape也将接受类似new_shape = [batch_size, -1]的内容,但我认为,视情况而定,可能会使扁平化尺寸的大小未知。另一方面,相反的new_shape = [-1, tf.math.reduce_prod(X_shape[1:])]也应该可以正常工作。

顺便说一句,我想您是在做练习,并且已经知道这一点,但是仅供参考,Keras中已经有一个Flatten层(您可以检查其source code)。