循环计算图

时间:2017-07-17 15:22:40

标签: tensorflow

我想创建一个循环计算图。这个想法很简单,详情如下:

  • 初始化网络的权重。
  • 从多元高斯中抽取N个权重,其中初始权重是高斯的平均值。
  • 评估每组权重的一些损失函数。
  • 适当更新权重。

基本方法的图像可以看作如下:

enter image description here

我目前的方法是在循环训练期间采样和更新权重。然而,这很慢,我想知道我是否可以将此功能构建到计算图中并加快我的训练速度。

1 个答案:

答案 0 :(得分:0)

您应该能够在计算图中完成所有操作。例如,权重变量W

NUM_SAMPLES = 10
STDDEV = 1

# Assuming W statically shaped, otherwise you'd use tf.shape and tf.concat
samples_shape = [0] + W.shape.as_list()
# Generate random numbers with W as mean
samples = tf.random_normal(samples_shape,
                           stddev=tf.constant(STDDEV, dtype=W.dtype),
                           dtype=W.dtype)
samples += W[tf.newaxis, :]
# The loss function should return a vector the size of
# the first dimension of samples
samples_loss = loss(samples)
idx = tf.argmin(samples_loss, axis=0)
# Update W
update_op = tf.assign(W, samples[idx])

然后你运行update_op来执行一个更新步骤,或者继续使用它作为控件依赖项的其他操作:

with tf.control_dependencies([update_op]):
    # More ops...