在tensorflow中动态更改tf.truncated_normal的stddev

时间:2017-11-08 12:17:20

标签: python random tensorflow

我正在构建一个图表,我希望初始权重具有可变的标准偏差。 我试图使用以下命令,但它产生了一个错误:

import tensorflow as tf
import numpy as np
stddev = tf.placeholder(dtype=tf.float32)
a = tf.placeholder(dtype=tf.float32, shape=[1,50])
weight1 = tf.Variable(tf.truncated_normal(shape=[50, 30],stddev=stddev))
result = tf.reduce_sum(tf.matmul(a, weight1))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result , {a: np.random.randn(1, 50), stddev: 0.01}))

任何人都可以帮我解决这个问题吗? 我知道在定义时我可以设置stddev,但我面临的任务是在培训过程中使用变体stddev

1 个答案:

答案 0 :(得分:1)

像这样使用tf.placeholder_with_default

import numpy as np
import tensorflow as tf

stddev = tf.placeholder_with_default(0.1, shape=(), name='stddev')
weight1 = tf.Variable(tf.truncated_normal(shape=[50, 30], stddev=stddev))

a = tf.placeholder(dtype=tf.float32, shape=[1,50])
result = tf.reduce_sum(tf.matmul(a, weight1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(result , feed_dict={a: np.random.randn(1, 50), stddev: 0.01}))