如何在Keras中为张量的2D子集分配另一个2D张量?

时间:2018-11-04 18:45:35

标签: python tensorflow keras

如果我有两个3D张量imggen。如何为img的2D子集分配gen的2D子集?由于tensorflow不允许直接分配张量,因此以下内容不起作用。

img[96:160 , 144:240 , :] = gen[96:160 , 144:240 , :]

编辑:

这是周围的代码。所以我使用了一个自定义的keras层。该层必须接收输入图像img和生成的图像x。它必须将img的一部分替换为x,并且必须返回修改后的img

def patcher(tensors):
    img = tensor[1]
    gen = tensor[0]
    #This is where the slicing must happen
    img[96:160 , 144:240 , :] = gen[96:160 , 144:240 , :]
    return [img]

img = Input( .. )
x = Conv( .. )(img)
out = Lambda(patcher,lambda a : [a[1]] )([x , img])
model = Model(img, out)

2 个答案:

答案 0 :(得分:1)

当前,您无法以简单的方式替换张量的切片。我实际上是opened an issue about it,因为这是人们一直在要求的东西。使用当前的API,您必须设法找出构建所需张量的最佳方法。在这种情况下,假设imggen都具有相同的形状,这是您可以这样做的一种方式:

import tensorflow as tf
import numpy as np

# Input
img = tf.placeholder(tf.float32, [None, None, None])
gen = tf.placeholder(tf.float32, [None, None, None])
row_start = tf.placeholder(tf.int32, [])
row_end = tf.placeholder(tf.int32, [])
col_start = tf.placeholder(tf.int32, [])
col_end = tf.placeholder(tf.int32, [])
# Masks rows and columns to be replaced
shape = tf.shape(img)
rows = shape[0]
cols = shape[1]
channels = shape[2]
i = tf.range(rows)
row_mask = (row_start <= i) & (i < row_end)
j = tf.range(cols)
col_mask = (col_start <= j) & (j < col_end)
# Full mask of replaced elements
mask = row_mask[:, tf.newaxis] & col_mask
# Select elements from flattened arrays
img_flat = tf.reshape(img, [-1, channels])
gen_flat = tf.reshape(gen, [-1, channels])
mask_flat = tf.reshape(mask, [-1])
result_flat = tf.where(mask_flat, gen_flat, img_flat)
# Reshape back
result = tf.reshape(result_flat, shape)

这是一个小测试:

with tf.Session() as sess:
    # img is positive and gen is negative
    img_val = np.arange(60).reshape((4, 5, 3))
    gen_val = -img_val
    # Do img[2:4, 0:3, :] = gen[2:4, 0:3, :]
    result_val = sess.run(result, feed_dict={
        img: img_val,
        gen: gen_val,
        row_start: 2,
        row_end: 4,
        col_start: 0,
        col_end: 3,
    })
    # Print one channel only for clarity
    print(result_val[:, :, 0])

输出:

[[  0.   3.   6.   9.  12.]
 [ 15.  18.  21.  24.  27.]
 [-30. -33. -36.  39.  42.]
 [-45. -48. -51.  54.  57.]]

编辑:

这是您发布的代码的可能实现。我在这里使用基于乘法的稍微不同的方法,当您有很多图像时,我认为这种方法会更好。

import tensorflow as tf

def replace_slices(img, gen, row_start, row_end, col_start, col_end):
    # Masks rows and columns to be replaced
    shape = tf.shape(img)
    rows = shape[1]
    cols = shape[2]
    i = tf.range(rows)
    row_mask = (row_start <= i) & (i < row_end)
    j = tf.range(cols)
    col_mask = (col_start <= j) & (j < col_end)
    # Full mask of replaced elements
    mask = row_mask[:, tf.newaxis] & col_mask
    # Add channel dimension to mask and cast
    mask = tf.cast(mask[:, :, tf.newaxis], img.dtype)
    # Compute result
    result = img * (1 - mask) + gen * mask
    return result

def patcher(tensors):
    img = tensor[1]
    gen = tensor[0]
    img = replace_slices(img, gen, 96, 160, 144, 240)
    return [img]

img = Input( .. )
x = Conv( .. )(img)
out = Lambda(patcher, ambda a: [a[1]])([x , img])
model = Model(img, out)

答案 1 :(得分:1)

我修改了最初的解决方案,该解决方案仅在基于@jdehesa的解决方案设置了批次大小时才起作用。这应该适用于所有后端(TensorFlow,Theano和CNTK)的Keras:

from keras import backend as K
import numpy as np

def replace_slices(ts, row_start, row_end, col_start, col_end):
    shape = K.int_shape(ts[0])[1:-1]
    np_mask = np.zeros(shape + (1,))
    np_mask[row_start:row_end, col_start:col_end] = 1.
    mask = K.variable(np_mask, dtype=K.dtype(ts[0]))
    # ts[0] is the img and ts[1] is the x tensor
    return ts[0] * (1 - mask) + ts[1] * mask

args = {'row_start': 96, 'row_end': 160, 'col_start': 144, 'col_end': 240}

img = Input(shape=(256,384,3))
x = Conv2D(3, (3,3), padding='same')(img) # this must have 3 filters since img has 3 channels
out = Lambda(replace_slices, arguments=args)([img, x])
model = Model(img, out)