我有两个这样的张量:
template = tf.convert_to_tensor([[1, 0, 0.5, 0.5, 0.3, 0.3],
[1, 0, 0.75, 0.5, 0.3, 0.3],
[1, 0, 0.5, 0.75, 0.3, 0.3],
[1, 0, 0.75, 0.75, 0.3, 0.3]])
patch = tf.convert_to_tensor([[0, 1, 0.43, 0.17, 0.4, 0.4],
[0, 1, 0.18, 0.22, 0.53, 0.6]])
现在,我想用template
行来更新patch
的第二行和最后一行,以得到像这样的值:
[[1. 0. 0.5 0.5 0.3 0.3 ]
[0. 1. 0.43 0.17 0.4 0.4 ]
[1. 0. 0.5 0.75 0.3 0.3 ]
[0. 1. 0.18 0.22 0.53 0.6 ]]
使用tf.scatter_update
很容易:
var_template = tf.Variable(template)
var_template = tf.scatter_update(var_template, [1, 3], patch)
但是,它需要创建一个变量。有没有一种方法可以仅使用张量操作来获取值?
我当时在考虑tf.where
,但是随后我可能不得不将每个补丁行广播到模板大小中,并为每一行调用tf.where
。
答案 0 :(得分:1)
这应该有效。有点扭曲,但未使用任何变量。
function insertElem(numberOfElems, elemTag, elemId, elemClass, parentSelector, elemSrc){
/*
* numberOfElements:- Pass in a whole integer.
* elemTag:- Pass in the element tag type (inside "" or '').
* elemId:- Pass in a name for the element id (inside "" or ''),
an integer is appended to the id name by the for loop.
* elemClass:- Pass in a name for element class (inside "" or '').
* parentSelector:- Pass in the identifier of the parent element (inside "" or '')
* querySelector prefixes: # = id
. = class
none = tag
*elemSrc:- Pass in the source media url (inside "" or '').
*/
if (numberOfElems > 1) {
for (i = 0; i < numberOfElems; i++) {
var elem = this[elemId + i];
elem = document.createElement(elemTag);
elem.id = elemId + '_' + i;
if (elemClass) {
elem.className = elemClass;
}
parentEl = document.querySelector(parentSelector);
parentEl.appendChild(elem);
if(elem instanceof HTMLMediaElement) {
elem.src = elemSrc;
}
}
} else {
var elem = this[elemId];
elem = document.createElement(elemTag);
elem.id = elemId;
if (elemClass) {
elem.className = elemClass;
}
parentEl = document.querySelector(parentSelector);
parentEl.appendChild(elem);
if(elem instanceof HTMLMediaElement) {
elem.src = elemSrc;
}
}
}
答案 1 :(得分:0)
我还将在此处添加我的解决方案。该实用程序功能与scatter_update
几乎相同,但不使用变量:
def scatter_update_tensor(x, indices, updates):
'''
Utility function similar to `tf.scatter_update`, but performing on Tensor
'''
x_shape = tf.shape(x)
patch = tf.scatter_nd(indices, updates, x_shape)
mask = tf.greater(tf.scatter_nd(indices, tf.ones_like(updates), x_shape), 0)
return tf.where(mask, patch, x)