TensorFlow:替代tf.scatter_update

时间:2019-02-28 17:12:51

标签: python tensorflow

我有两个这样的张量:

template = tf.convert_to_tensor([[1, 0, 0.5, 0.5, 0.3, 0.3],                                                                                                                                                                                                      
                                 [1, 0, 0.75, 0.5, 0.3, 0.3],                                                                                                                                                                                                     
                                 [1, 0, 0.5, 0.75, 0.3, 0.3],                                                                                                                                                                                                     
                                 [1, 0, 0.75, 0.75, 0.3, 0.3]])                                                                                                                                                                                                   

patch = tf.convert_to_tensor([[0, 1, 0.43, 0.17, 0.4, 0.4],                                                                                                                                                                                                       
                              [0, 1, 0.18, 0.22, 0.53, 0.6]])

现在,我想用template行来更新patch的第二行和最后一行,以得到像这样的值:

[[1.   0.   0.5  0.5  0.3  0.3 ]
 [0.   1.   0.43 0.17 0.4  0.4 ]
 [1.   0.   0.5  0.75 0.3  0.3 ]
 [0.   1.   0.18 0.22 0.53 0.6 ]]

使用tf.scatter_update很容易:

var_template = tf.Variable(template)                                                                                                                                                                                                                              
var_template = tf.scatter_update(var_template, [1, 3], patch) 

但是,它需要创建一个变量。有没有一种方法可以仅使用张量操作来获取值?

我当时在考虑tf.where,但是随后我可能不得不将每个补丁行广播到模板大小中,并为每一行调用tf.where

2 个答案:

答案 0 :(得分:1)

这应该有效。有点扭曲,但未使用任何变量。

function insertElem(numberOfElems, elemTag, elemId, elemClass, parentSelector, elemSrc){
    /*
    * numberOfElements:-    Pass in a whole integer.
    * elemTag:-             Pass in the element tag type (inside "" or '').              
    * elemId:-              Pass in a name for the element id (inside "" or ''),
                            an integer is appended to the id name by the for loop.
    * elemClass:-           Pass in a name for element class (inside "" or '').
    * parentSelector:-      Pass in the identifier of the parent element (inside "" or '')
                            * querySelector prefixes:    # = id
                                                         . = class
                                                         none = tag 
    *elemSrc:-              Pass in the source media url (inside "" or ''). 
    */      
      if (numberOfElems > 1) {
        for (i = 0; i < numberOfElems; i++) {
          var elem = this[elemId + i];
          elem = document.createElement(elemTag);
          elem.id = elemId + '_' + i;
             if (elemClass) {
                elem.className = elemClass;
             }
          parentEl = document.querySelector(parentSelector);
          parentEl.appendChild(elem);
            if(elem instanceof HTMLMediaElement) {
                elem.src = elemSrc;
            }
        }
      } else {
        var elem = this[elemId];
        elem = document.createElement(elemTag);
        elem.id = elemId;
            if (elemClass) {
                elem.className = elemClass;
            }
        parentEl = document.querySelector(parentSelector);
        parentEl.appendChild(elem);
           if(elem instanceof HTMLMediaElement) {
                elem.src = elemSrc;
          }
      }
    }

答案 1 :(得分:0)

我还将在此处添加我的解决方案。该实用程序功能与scatter_update几乎相同,但不使用变量:

def scatter_update_tensor(x, indices, updates):                                                                                                                                                                                                                       
    '''                                                                                                                                                                                                                                                               
    Utility function similar to `tf.scatter_update`, but performing on Tensor                                                                                                                                                                                         
    '''                                                                                                                                                                                                                                                               
    x_shape = tf.shape(x)                                                                                                                                                                                                                                             
    patch = tf.scatter_nd(indices, updates, x_shape)                                                                                                                                                                                                                  
    mask = tf.greater(tf.scatter_nd(indices, tf.ones_like(updates), x_shape), 0)                                                                                                                                                                                      
    return tf.where(mask, patch, x)