假设我想要采样矩阵,其中每个条目都是从另一个矩阵中的条目定义的分布中采样的。我展开我的矩阵并将map_fn应用于每个元素。使用相对较小的矩阵(128 x 128),以下内容为我提供了几个PoolAllocator警告(GTX TITAN Black),并且不会在任何合理的时间内进行训练。
def sample(x):
samples = tf.map_fn(lambda z:
tf.random_normal([1], mean=z,
stddev=tf.sqrt(z * (1 - z))),
tf.reshape(x, [-1])) # apply to each element
return tf.cond(is_training, lambda: tf.reshape(samples, shape=tf.shape(x)),
lambda: tf.tanh(x))
是否有更好的方法来应用这样的元素运算?
答案 0 :(得分:2)
如果您可以使用Tensor-at-a-time操作而不是像tf.map_fn这样的元素操作,那么您的代码将运行得更快。
这里看起来你想从每个元素的正态分布中进行采样,其中分布的参数对于输入张量中的每个值是不同的。尝试这样的事情:
def sample(x):
samples = tf.random_normal(shape=[128, 128]) * tf.sqrt(x * (1 - x)) + x
tf.random_normal()默认生成平均值为0.0且标准差为1.0的正态分布。您可以使用逐点张量操作来确定每个元素的标准偏差(乘以)和平均值(通过相加)。事实上,如果你看一下tf.random_normal()是如何实现的,那正是它在内部的作用。
(您可能也会使用Python条件来区分训练和测试时间。)
如果你计划做很多这样的事情,你可以在github上提交一个功能请求,要求概括tf.random_normal以接受具有mean
和stddev
更一般形状的张量。我认为没有理由不支持这一点。
希望有所帮助!
答案 1 :(得分:0)
请参阅tensorflow.contrib.distributions
模块,其中Normal
类具有sample
方法,可以为您执行此操作。