ssim作为自动编码器中的自定义损失函数(keras或/和张量流)

时间:2018-07-04 11:00:31

标签: python tensorflow keras autoencoder

我目前正在为图像压缩编写自动编码器。从previous post中,我现在最终确认我不能在Keras和tensorflow中都不能将纯Python函数用作损失函数。 (而且我正在慢慢开始理解为什么;-)

我想用ssim作为损失函数和指标进行一些实验。现在看来我可能很幸运。张量流中已经有一个实现,请参见:https://www.tensorflow.org/api_docs/python/tf/image/ssim

tf.image.ssim(     img1,     img2,     max_val )

此外, bsautermeister 在此为stackoverflow提供了一个实现:SSIM / MS-SSIM for TensorFlow

我现在的问题是:如何使用mnist数据集将其用作损失函数?该函数不接受张量,而仅接受两个图像。而且,梯度会自动计算吗?据我了解,如果该功能是在tensorflow或keras后端中实现的。

对于一个最低限度的工作示例(MWE),我将不胜感激,该示例涉及如何在keras或tensorflow中使用任何上述ssim实现作为损失函数。

也许我们可以将MWE用于我之前的问题提供的自动编码器: keras custom loss pure python (without keras backend)

如果无法将我的keras自动编码器与ssim实现粘合在一起,那么可以直接在tensorflow中实现的自动编码器吗?我也有,可以提供吗?

我正在使用python 3.5,keras(具有tensorflow后端),并在必要时直接使用tensorflow。 目前,我正在使用mnist dataset(带有数字的那个)。

感谢您的帮助!

(附言:似乎有些人正在从事类似的工作。对此帖子的回答可能对Keras - MS-SSIM as loss function也很有用)

1 个答案:

答案 0 :(得分:1)

我无法与Keras一起使用,但是在普通的TensorFlow中,您只需切换L2或类似SSIM结果的任何费用

SELECT PRIMARY_KEY, TEXT_FIELD, SERIES_ID
FROM MY_TABLE 
ORDER BY SERIES_ID
OFFSET 0 FETCH NEXT 3 ROWS ONLY WITH TIES;

要直接检查操作是否具有渐变,请执行以下操作:

import tensorflow as tf
import numpy as np


def fake_img_batch(*shape):
    i = np.random.randn(*shape).astype(np.float32)
    i[i < 0] = -i[i < 0]
    return tf.convert_to_tensor(np.clip(i * 255, 0, 255))


fake_img_a = tf.get_variable('a', initializer=fake_img_batch(2, 224, 224, 3))
fake_img_b = tf.get_variable('b', initializer=fake_img_batch(2, 224, 224, 3))

fake_img_a = tf.nn.sigmoid(fake_img_a)
fake_img_b = tf.nn.sigmoid(fake_img_b)

# costs = tf.losses.mean_squared_error(fake_img_a, fake_img_b, reduction=tf.losses.Reduction.MEAN)
costs = tf.image.ssim(fake_img_a, fake_img_b, 1.)
costs = tf.reduce_mean(costs)

train = tf.train.AdamOptimizer(0.01).minimize(costs)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(costs))
    for k in range(500):
        _, l = sess.run([train, costs])
        if k % 100 == 0:
            print('mean SSIM', l)