元素采样与map_fn慢

时间:2016-07-01 02:16:21

标签: numpy matrix tensorflow

假设我想要采样矩阵,其中每个条目都是从另一个矩阵中的条目定义的分布中采样的。我展开我的矩阵并将map_fn应用于每个元素。使用相对较小的矩阵(128 x 128),以下内容为我提供了几个PoolAllocator警告(GTX TITAN Black),并且不会在任何合理的时间内进行训练。

def sample(x):
   samples = tf.map_fn(lambda z:
                      tf.random_normal([1], mean=z,
                                       stddev=tf.sqrt(z * (1 - z))),
                      tf.reshape(x, [-1]))    # apply to each element

   return tf.cond(is_training, lambda: tf.reshape(samples, shape=tf.shape(x)),
                  lambda: tf.tanh(x))

是否有更好的方法来应用这样的元素运算?

2 个答案:

答案 0 :(得分:2)

如果您可以使用Tensor-at-a-time操作而不是像tf.map_fn这样的元素操作,那么您的代码将运行得更快。

这里看起来你想从每个元素的正态分布中进行采样,其中分布的参数对于输入张量中的每个值是不同的。尝试这样的事情:

def sample(x):
  samples = tf.random_normal(shape=[128, 128]) * tf.sqrt(x * (1 - x)) + x

tf.random_normal()默认生成平均值为0.0且标准差为1.0的正态分布。您可以使用逐点张量操作来确定每个元素的标准偏差(乘以)和平均值(通过相加)。事实上,如果你看一下tf.random_normal()是如何实现的,那正是它在内部的作用。

(您可能也会使用Python条件来区分训练和测试时间。)

如果你计划做很多这样的事情,你可以在github上提交一个功能请求,要求概括tf.random_normal以接受具有meanstddev更一般形状的张量。我认为没有理由不支持这一点。

希望有所帮助!

答案 1 :(得分:0)

请参阅tensorflow.contrib.distributions模块,其中Normal类具有sample方法,可以为您执行此操作。