自定义损失函数访问张量通道

时间:2020-03-07 21:40:54

标签: tensorflow keras

我有一个形状为(64、64、2)的2通道numpy数组作为输入到CNN。 我想按照https://www.tensorflow.org/guide/keras/train_and_evaluate中的描述构建自定义的损失函数:

def basic_loss_function(y_true, y_pred):
    return tf.math.reduce_mean(tf.abs(y_true - y_pred))

model.compile(optimizer=keras.optimizers.Adam(),
              loss=basic_loss_function)

model.fit(x_train, y_train, batch_size=64, epochs=3)

但是我想要比这个基本的东西更复杂的东西。我需要做一个逆DFT(ifft2d),并且我的y_pred和y_true的形状分别为(64,64,2),两个通道是fft2的实部和虚部。如何正确访问y_pred和y_true通道(我猜是某种keras /张量层?)以 RealPart + 1j * ImagPart 的形式重建复数(用numpy表示) y_pred [:,:,0]和y_pred [:,:,1])吗?

->总之,有人知道y_pred和y_true是哪种对象以及如何访问其通道/元素吗? (这不容易调试,因为它需要在已编译的CNN中运行,因此请事先了解一下)

1 个答案:

答案 0 :(得分:1)

y_truey_pred是形状为(batchsize, ...[output shape]...)的张量。您的输入的形状为(64,64,2),但我不确定您的输出是什么样子,如果您的输出确实为(64,64,2),则y_predy_true的形状为(64,64,64,2)给您您的batchsize=64

使用张量进行处理非常类似于numpy的语法,因此您可以使用带有张量的切片符号,例如y_true[:,:,:,0](请注意添加的批处理维度)。

Tensorflow具有用于计算DFT,FFT等的功能。参见tf.signaltf.signal.rfft2d

如果损失函数不仅涉及输出y_truey_pred的输入,还可以使用model.add_loss代替model.compile(loss= basic_loss_function),如下所示

x = Input(shape=(64,64,2))
y_true = Input(shape=...))
# your CNN layers
y_pred = Dense(128)(net)

model = Model(input=[x, y_true], output=y_pred)
model.add_loss(basic_loss_function(x, y_true, y_pred))

请注意,标签(又名y_true)现在是模型的输入。