我正在归类tf.keras.model
。我需要覆盖compute_output_shape
,否则,我将得到here中的NotImplementedError
。
class Custom(tf.keras.Model):
...
def compute_output_shape(self, input_shape):
# input_shape = (None, ...)
batch_size = ???
return (batch_size, ...)
compute_output_shape
以input_shape
作为输入。但是,这并没有太大帮助,因为批量大小以某种方式在TensorFlow中丢失了。
如果我尝试以与None
相同的方式返回以input_shape
开头的形状,则会得到TypeError: 'str' object cannot be interpreted as an integer
。只是省略批处理大小也不起作用。
批处理大小是可变的,所以我不能仅仅对其进行硬编码。