How to fix "Sliced assignment is only supported for variables" when working with tensors

Date: 2019-01-08 07:11:27

Tags: python tensorflow keras tensor

I am trying to define my own custom layer in Keras. In the layer's call method, where its logic lives, I am working with tensor objects.

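After finding the maximum value of a shredded slice of a tensor, I want to assign that value into another tensor, but I get the error

"Sliced assignment is only supported for variables"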

I tried calling Sess.eval() inside the class's call method, but that did not solve the problem.

mid_arr = x[i:spliti, j:splitj]  # shredded slice
num = tf.reduce_max(mid_arr)  # max value of the shredded slice
res_arr = res_arr.assign(tf.where(res_arr[m][n], num, res_arr))  # assign it
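The error arises because a plain tf.Tensor is immutable, while a tf.Variable holds mutable state. The sketch below (assuming TensorFlow 2.x with eager execution; the shapes and values are made up for illustration) shows that the same slice-then-assign call succeeds on a Variable and fails on a Tensor. Depending on the TensorFlow version and mode, the failure surfaces either as the ValueError from the question or as an AttributeError, so both are caught.

```python
import tensorflow as tf

# A tf.Variable is mutable: slicing it returns an object with a working .assign().
v = tf.Variable(tf.zeros((2, 4), dtype=tf.int32))
v[0, 1:3].assign([7, 8])  # writes 7 and 8 into row 0, columns 1 and 2, in place
print(v.numpy())

# A plain tf.Tensor is immutable: there is no in-place sliced assignment.
t = tf.zeros((2, 4), dtype=tf.int32)
try:
    t[0, 1:3].assign([7, 8])
    failed = False
except (ValueError, AttributeError) as err:
    # Graph mode (TF 1.x style) raises ValueError("Sliced assignment is only
    # supported for variables"); eager tensors may simply lack .assign.
    failed = True
    print(type(err).__name__, err)
```

Turning every intermediate result inside a Keras layer into a Variable is usually not what you want, which is why the answer below builds a new tensor instead.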

1 answer:

Answer 0 (score: 1)

Although the solution was given in the comments (thanks to jdehesa), I am posting it here in the answer section for the benefit of the community.
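The core trick of the workaround can be seen on a small 2-D case. The sketch below (the 3x4 input, the block of -1 values, and the corner [1, 1] are arbitrary choices for illustration) pads the replacement block with zeros up to the input's shape, pads a boolean mask the same way, and lets tf.where pick the replacement wherever the mask is True and the original value elsewhere. Nothing is modified in place; a new tensor is returned.

```python
import tensorflow as tf

x = tf.reshape(tf.range(12), (3, 4))  # [[0,1,2,3],[4,5,6,7],[8,9,10,11]]
replacement = tf.fill((2, 2), -1)     # hypothetical block to write
begin = tf.constant([1, 1])           # top-left corner of the block
size = tf.shape(replacement)          # [2, 2]

# Per-dimension [pad_before, pad_after] = [begin, input_shape - (begin + size)]
padding = tf.stack([begin, tf.shape(x) - (begin + size)], axis=1)  # [[1, 0], [1, 1]]

replacement_pad = tf.pad(replacement, padding)                    # block embedded in zeros, shape (3, 4)
mask = tf.pad(tf.ones_like(replacement, dtype=tf.bool), padding)  # True exactly where the block sits

# Take the replacement inside the block and the original values everywhere else.
result = tf.where(mask, replacement_pad, x)
print(result.numpy())
```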

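Below is complete, TF 2.x-compatible code (a workaround) that performs sliced assignment of a Tensor. Because tensors are immutable, it does not change the input; it pads the replacement values to the input's shape, builds a boolean mask that marks the slice, and uses tf.where to assemble a new tensor: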

import tensorflow as tf
    
def replace_slice(input_, replacement, begin, size=None):
    inp_shape = tf.shape(input_)
    if size is None:
        size = tf.shape(replacement)
    else:
        replacement = tf.broadcast_to(replacement, size)
    padding = tf.stack([begin, inp_shape - (begin + size)], axis=1)
    replacement_pad = tf.pad(replacement, padding)
    mask = tf.pad(tf.ones_like(replacement, dtype=tf.bool), padding)
    return tf.where(mask, replacement_pad, input_)

def replace_slice_in(tensor):
    return _SliceReplacer(tensor)

class _SliceReplacer:
    def __init__(self, tensor):
        self._tensor = tensor
    def __getitem__(self, slices):
        return _SliceReplacer._Inner(self._tensor, slices)
    def with_value(self, replacement):  # Just for convenience in case you skip the indexing
        return _SliceReplacer._Inner(self._tensor, (...,)).with_value(replacement)
    class _Inner:
        def __init__(self, tensor, slices):
            self._tensor = tensor
            self._slices = slices
        def with_value(self, replacement):
            begin, size = _make_slices_begin_size(self._tensor, self._slices)
            return replace_slice(self._tensor, replacement, begin, size)

# This computes begin and size values for a set of slices
def _make_slices_begin_size(input_, slices):
    if not isinstance(slices, (tuple, list)):
        slices = (slices,)
    inp_rank = tf.rank(input_)
    inp_shape = tf.shape(input_)
    # Have we already seen an ellipsis?
    before_ellipsis = True
    # Sliced dimensions
    dim_idx = []
    # Slice start points
    begins = []
    # Slice sizes
    sizes = []
    for i, s in enumerate(slices):
        if s is Ellipsis:
            if not before_ellipsis:
                raise ValueError('Cannot use more than one ellipsis in slice spec.')
            before_ellipsis = False
            continue
        if isinstance(s, slice):
            start = s.start
            stop = s.stop
            if s.step is not None:
                raise ValueError('Step value not supported.')
        else:  # Assumed to be a single integer value
            start = s
            stop = s + 1
        # Dimension this slice refers to
        i_dim = i if before_ellipsis else inp_rank - (len(slices) - i)
        dim_size = inp_shape[i_dim]
        # Default slice values
        start = start if start is not None else 0
        stop = stop if stop is not None else dim_size
        # Fix negative indices
        start = tf.cond(tf.convert_to_tensor(start >= 0), lambda: start, lambda: start + dim_size)
        stop = tf.cond(tf.convert_to_tensor(stop >= 0), lambda: stop, lambda: stop + dim_size)
        dim_idx.append([i_dim])
        begins.append(start)
        sizes.append(stop - start)
    # For empty slice specs like [...]
    if not dim_idx:
        return tf.zeros_like(inp_shape), inp_shape
    # Make full begin and size array (including omitted dimensions)
    begin_full = tf.scatter_nd(dim_idx, begins, [inp_rank])
    size_mask = tf.scatter_nd(dim_idx, tf.ones_like(sizes, dtype=tf.bool), [inp_rank])
    size_full = tf.where(size_mask,
                          tf.scatter_nd(dim_idx, sizes, [inp_rank]),
                          inp_shape)
    return begin_full, size_full

#with tf.Graph().as_default():
x = tf.reshape(tf.range(60), (4, 3, 5))
x2 = replace_slice_in(x)[:2, ..., -3:].with_value([100, 200, 300])

print('Tensor before Changing is \n', x)
print('\n')
print('Tensor after Changing is \n', x2)

The output of the above code is:

Tensor before Changing is 
 tf.Tensor(
[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]]

 [[15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]

 [[30 31 32 33 34]
  [35 36 37 38 39]
  [40 41 42 43 44]]

 [[45 46 47 48 49]
  [50 51 52 53 54]
  [55 56 57 58 59]]], shape=(4, 3, 5), dtype=int32)


Tensor after Changing is 
 tf.Tensor(
[[[  0   1 100 200 300]
  [  5   6 100 200 300]
  [ 10  11 100 200 300]]

 [[ 15  16 100 200 300]
  [ 20  21 100 200 300]
  [ 25  26 100 200 300]]

 [[ 30  31  32  33  34]
  [ 35  36  37  38  39]
  [ 40  41  42  43  44]]

 [[ 45  46  47  48  49]
  [ 50  51  52  53  54]
  [ 55  56  57  58  59]]], shape=(4, 3, 5), dtype=int32)
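For the question's specific case, writing a single value at position [m, n], a lighter option (not part of the original answer) is tf.tensor_scatter_nd_update, which also returns a new tensor rather than mutating the old one. In the sketch below, the input x and the values of i, spliti, j, splitj, m and n are hypothetical stand-ins for the question's loop variables, and TensorFlow 2.x is assumed.

```python
import tensorflow as tf

x = tf.reshape(tf.range(20, dtype=tf.float32), (4, 5))
res_arr = tf.zeros((2, 2), dtype=tf.float32)

# Hypothetical slice bounds and target cell, standing in for the question's variables.
i, spliti, j, splitj = 0, 2, 0, 3
m, n = 1, 0

mid_arr = x[i:spliti, j:splitj]  # shredded slice: [[0, 1, 2], [5, 6, 7]]
num = tf.reduce_max(mid_arr)     # max value of the slice

# Build a new tensor equal to res_arr except at index [m, n], which becomes num.
res_arr = tf.tensor_scatter_nd_update(res_arr, [[m, n]], [num])
print(res_arr.numpy())
```

Collecting all indices and values first and issuing one tensor_scatter_nd_update call is likely cheaper than calling it once per cell inside a Python loop.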