Tensorflow中的外部损失函数

时间:2018-11-09 06:35:02

标签: python tensorflow tensor

我有一个简单的2层密集NN,我想使用回归模型来计算图像的给定~ 700个特征的4个。不幸的是,我没有基本事实要素,因此我使用自定义损失函数。这是函数的来源:

def loss_function(logits, img, g, compare_img):

    final_img = img_pipeline(vga_8b=img, g=g%external color gamma function%)

    with tf.name_scope('Loss'):
        loss = score(gt_image=compare_img, curr_img=final_img)
        return loss

其中logits是当前评估的4个数字,g只是一个插值函数,用作图像的颜色伽玛,img是外部灰度图像,用于生成用于score的最终结果图像功能。 compare_img不是地面真实图像,而是在得分函数中使用的一些统计值(kept in python dict)用于评估当前产生的图像。 不幸的是,我无法提供gcompare_img,因为它们是无法转换为张量的python函数和python字典。

有没有办法以某种方式破解它并达到预期的结果?

提前谢谢!

1 个答案:

答案 0 :(得分:1)

您可以在tf.map中使用带有tensorflow的外部函数,但我要说的是,这些函数无法通过它计算梯度。但是损失函数在每种情况下都必须是可导出的。因此,您必须在tensorflow中编写函数。

对于dict值,您可以使用

创建一个查找表

table = tf.contrib.lookup.HashTable( tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)