我正在尝试设计一个仅适用于输入张量子集的Tensorflow函数。在内部,我认为实现这一目标的最好方法是创建一个bool掩码,指示张量的哪些条目是"活跃的"哪个应该被忽略。
我已经有一个看起来像这样的numpy实现
def func( data, maxIterations ):
ndata = len( data )
ind = numpy.arange( ndata )
mask = numpy.ones( (ndata,), dtype=tf.bool )
for iteration in range( maxIterations ):
# Determine which part of the input tensor is "active"
# Note that the "for" loop will be vectorized in practice
mask.fill(False)
for point in data:
if condition(point):
mask[point] = True
# Now do something with the active part of the tensor
doSomething( data, mask )
现在我想将其转换为Tensorflow实现。问题似乎是我无法弄清楚如何创建一个不会被优化例程更新的可变掩码。例如,tf.ones
函数似乎在创建一个不可变的数组。我怎么能让它变得可变?
def func_tf( data, maxIterations ):
ndata = tf.shape(data)[0]
ind = tf.range( ndata )
mask = tf.ones( (ndata,), dtype=bool )
for iteration in range( maxIterations ):
# Determine which part of the input tensor is "active"
# Note that the "for" loop will be vectorized in practice
tf.assign(mask, False)
for point in data:
if condition(point):
mask[point] = True
# Now do something with the active part of the tensor
doSomething( data, mask )
我应该补充说,将数组创建为numpy数组然后用tf.Variable
包装似乎非常低效,因为它必须将内容从CPU传输到GPU。
或者,如果你有一个替代的设计模式,你通常用于这些类型的问题,我很乐意在这里!