Tensorflow-获取像素邻域

时间:2018-08-09 12:33:45

标签: python tensorflow keras loss-function

我正在尝试实现Tensorflow / Keras中Levin等(2004)的经典图像着色论文的损失功能:

main_loss

这是权重方程(强度之间的相关性):

weights

y是3x3窗口中x的每个相邻像素,而w是这些像素中每个像素的权重。

权重要求计算每个像素邻域的均值和方差。

我找不到允许我以符号方式编写此损失函数的函数,并且我认为应该在为每个窗口计算w的循环中编写它。 / p>

如何在Tensorflow中以符号方式或循环方式编写此Loss函数?

非常感谢。

编辑:这是我在Numpy中计算权重的代码:

import cv2
import numpy as np

im = cv2.resize(cv2.imread('./Image.jpg', 0), (256, 256)) / np.float32(255.0)

M = 3
N = 3

# Split the image into 3x3 windows
windows = [im[x:x + M, y:y + N] for x in range(0, im.shape[0], M) for y in range(0, im.shape[1], N)]

# Calculate the correlation for each window
weights = [1 + np.corrcoef(tile) for tile in windows]

2 个答案:

答案 0 :(得分:2)

我认为这段代码可以计算您公式中的值:

import tensorflow as tf
from itertools import product

SIGMA = 1.0

dtype = tf.float32
# Input images batch
img = tf.placeholder(dtype, [None, None, None])
img_shape = tf.shape(img)
img_height = img_shape[1]
img_width = img_shape[2]
# Compute 3 x 3 block means
mean_filter = tf.ones((3, 3), dtype) / 9
img_mean = tf.nn.conv2d(img[:, :, :, tf.newaxis],
                        mean_filter[:, :, tf.newaxis, tf.newaxis],
                        [1, 1, 1, 1], 'VALID')[:, :, :, 0]
# Remove 1px border
img_clip = img[:, 1:-1, 1:-1]
# Difference between pixel intensity and its block mean
x_diff = img_clip - img_mean
# Compute neighboring pixel loss contributions
contributions = []
for i, j in product((-1, 0, 1), repeat=2):
    if i == j == 0: continue
    # Take "shifted" image
    displaced_img = img[:, 1 + i:img_width - 1 + i, 1 + j:img_height - 1 + j]
    # Compute difference with mean of corresponding pixel block
    y_diff = displaced_img - img_mean
    # Weights formula
    weight = 1 + x_diff * y_diff / (SIGMA ** 2)
    # Contribution of this displaced image to the loss of each pixel
    contribution = weight * displaced_img
    contributions.append(contribution)
contributions = tf.add_n(contributions)
# Compute loss value
loss = tf.reduce_sum(tf.squared_difference(img_clip, contributions))

由于原则上没有在公式中很好地定义,因此未计算沿图像边界的像素损失,尽管您可以进行一些更改以将它们考虑在内(将卷积更改为“'SAME' “,请在必要的地方垫上,等等。)

答案 1 :(得分:0)

这是3 x 3窗口的均方误差。对? 听起来像是用于纹理分析的GLCM矩阵,您是否要对图像中的每3x3窗口应用此损失函数?

我认为最好在Numpy中使用Random weight来构建进行此计算的函数,因此在尝试使用TF进行构建以尝试优化之后。