我正在编写一个自定义的 Keras 层,该层会展平输入中除最后一个维度(以及批次维度)以外的所有维度。但是,将该层的输出传给下一层时会发生错误,因为该层输出的形状在所有维度上均为 None。
class FlattenLayers( Layer ):
    """
    Flatten the middle dimensions of an nD tensor, keeping the first and last ones.

        [ n, x, y, z ] -> [ n, x*y, z ]

    A rank-2 input [ n, z ] becomes [ n, 1, z ].
    """

    def __init__( self, **kwargs ):
        super( FlattenLayers, self ).__init__( **kwargs )

    def build( self, input_shape ):
        super( FlattenLayers, self ).build( input_shape )

    def call( self, inputs ):
        """Reshape ``inputs`` to [ -1, prod(middle dims), last dim ].

        Statically known sizes are passed to ``tf.reshape`` as Python ints so
        TensorFlow can infer a static output shape. Building the target shape
        only from ``tf.shape()`` (a runtime tensor) leaves every static
        dimension as None. That breaks downstream layers such as ``Dense``,
        which require a defined last dimension.
        """
        static_shape = K.int_shape( inputs )
        dynamic_shape = tf.shape( inputs )

        middle = static_shape[ 1 : -1 ]
        if all( dim is not None for dim in middle ):
            # int() matters: np.prod of an empty tuple is the float 1.0.
            flat_size = int( np.prod( middle ) )
        else:
            # Fall back to the runtime size when any middle dim is unknown.
            flat_size = K.prod( dynamic_shape[ 1 : -1 ] )

        last_dim = static_shape[ -1 ]
        if last_dim is None:
            last_dim = dynamic_shape[ -1 ]

        # A Python list mixing ints and scalar tensors is auto-packed; known
        # entries stay static in the inferred output shape.
        return tf.reshape( inputs, [ -1, flat_size, last_dim ] )

    def compute_output_shape( self, input_shape ):
        """Return the static output shape ( n, prod(middle dims), last dim ).

        Raises:
            ValueError: if any non-batch dimension of ``input_shape`` is unknown.
        """
        # Normalise tuples/lists/TensorShape (TF1 Dimension objects included)
        # to a plain list of ints / None, so the None checks below are reliable.
        dims = tf.TensorShape( input_shape ).as_list()
        if any( dim is None for dim in dims[ 1: ] ):
            raise ValueError( 'The shape of the input to "Flatten" '
                              'is not fully defined '
                              '(got ' + str( input_shape[ 1: ] ) + '). '
                              'Make sure to pass a complete "input_shape" '
                              'or "batch_input_shape" argument to the first '
                              'layer in your model.' )
        output_shape = (
            dims[ 0 ],
            int( np.prod( dims[ 1 : -1 ] ) ),
            dims[ -1 ]
        )
        return output_shape
例如,当后面紧跟一个 Dense 层时,我会收到错误 `ValueError: The last dimension of the inputs to Dense should be defined. Found None.`
答案 0 :(得分:1)
为什么要用 `tf.stack()` 来构造新的形状?您想要展平除最后一个维度外的所有维度,可以这样做:
import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np
class FlattenLayer(Layer):
    """Flatten every dimension except the last one (the batch dim included).

        [n, x, y, z] -> [n*x*y, z]

    Works for inputs of any rank >= 1.
    """

    def __init__(self, **kwargs):
        super(FlattenLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        super(FlattenLayer, self).build(input_shape)

    def call(self, inputs):
        """Reshape ``inputs`` to [-1, last_dim].

        The last dimension is passed as a Python int when it is statically
        known, so the output keeps a defined static last dimension. Using
        ``tf.shape()`` alone would leave it as None and break a following
        ``Dense`` layer.
        """
        static_shape = tf.keras.backend.int_shape(inputs)  # None if rank unknown
        last_dim = static_shape[-1] if static_shape else None
        if last_dim is None:
            last_dim = tf.shape(inputs)[-1]
        return tf.reshape(inputs, [-1, last_dim])

    def compute_output_shape(self, input_shape):
        """Return (product of all leading dims, last dim).

        The leading product is None when any leading dimension (e.g. the batch
        dimension) is unknown.
        """
        dims = tf.TensorShape(input_shape).as_list()
        leading = dims[:-1]
        if any(dim is None for dim in leading):
            flat_size = None
        else:
            # int() matters: np.prod of an empty list is the float 1.0.
            flat_size = int(np.prod(leading))
        return (flat_size, dims[-1])
使用单个数据点进行测试(`tf.__version__ == '1.13.1'`):
# Tiny graph: 10x10x1 image -> 2x2 conv with 3 filters -> FlattenLayer.
inputs = tf.keras.layers.Input(shape=(10, 10, 1))
conv = tf.keras.layers.Conv2D(filters=3, kernel_size=2)
res = FlattenLayer()(conv(inputs))
model = tf.keras.models.Model(inputs=inputs, outputs=res)

# One random sample, evaluated through a TF1-style session.
x_data = np.random.normal(size=(1, 10, 10, 1))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {model.inputs[0]: x_data}
    evaled = model.outputs[0].eval(feed)
    print(evaled.shape)  # (81, 3)