如何使用自定义损失函数处理批量大小错误?

时间:2021-01-06 04:48:36

标签: python tensorflow keras

由于大小错误,当 batch_size > 1 时,我无法训练自定义损失函数。

class API_Network(object):
    def __init__(self):
        self.model = self.build_model()
    
    def build_model(self):
        # AutoEncoder Model
        input_img = keras.Input((len(current_timestep_matrix), len(current_timestep_matrix[0]), len(current_timestep_matrix[0][0])))
    
        encoded = layers.Dense(8, activation='tanh')(input_img)
        encoded = layers.AveragePooling2D(pool_size=(2, 2))(encoded)
        encoded = layers.Dense(16, activation='tanh')(encoded)
        encoded = layers.AveragePooling2D(pool_size=(2, 2))(encoded)
        encoded = layers.Dense(1, activation='tanh')(encoded)
        encoded = layers.Flatten()(encoded)
    
        decoded = layers.Dense(len(output_train[0]), activation='linear')(encoded)

        model = keras.Model(input_img, decoded)
    
        model.compile(loss=my_forward_div_loss_fn(input_img), optimizer=keras.optimizers.Adam())
    
        model.summary()
        return model

def my_forward_div_loss_fn(input_img):
  def loss(y_true, y_pred):
        batch_size, width, height, channels = input_img.get_shape().as_list()

        u_field = K.abs(y_true - y_pred)
        v_field = input_img[:, :, :, 0]
        x_field = input_img[:, :, :, 1]
        y_field = input_img[:, :, :, 2]

        u_field = K.reshape(u_field, shape=[width, height])
        u_field = K.expand_dims(u_field, 0)

        forward_dudx = (u_field[batch_size, 1:-1, 2:] - u_field[batch_size, 1:-1, 1:-1]) / (x_field[batch_size, 1:-1, 2:] - x_field[batch_size, 1:-1, 1:-1])
        forward_dvdy = (v_field[batch_size, 1:-1, 1:-1] - v_field[batch_size, 2:, 1:-1]) / (y_field[batch_size, :-2, 1:-1] - y_field[batch_size, 2:, 1:-1])

        forward_divergence = forward_dudx + forward_dvdy  # COMPUTES FORWARD DIVERGENCE
        forward_divergence = tf.where(tf.math.is_nan(forward_divergence), tf.zeros_like(forward_divergence), forward_divergence)
        return K.square(forward_divergence)
    return loss

在训练时更改批大小参数(到batch_size=2)时,出现以下错误:

InvalidArgumentError: Input to reshape is a tensor with 55778 values, but the requested shape has 27889 [[{{node loss/dense_3_loss/Reshape}}]]

模型符合要求,但在训练网络时如何解决这个问题?

1 个答案:

答案 0 :(得分:0)

更改这一行:

u_field = K.reshape(u_field, shape=[batch_size, width, height]) # add batch dimension