我的TensorFlow模型使用tf.random_uniform
来初始化变量。我想在开始训练时指定范围,因此我为初始化值创建了一个占位符。
init = tf.placeholder(tf.float32, name="init")
v = tf.Variable(tf.random_uniform((100, 300), -init, init), dtype=tf.float32)
initialize = tf.initialize_all_variables()
我在训练开始时初始化变量就像这样。
session.run(initialize, feed_dict={init: 0.5})
这给了我以下错误:
ValueError: initial_value must have a shape specified: Tensor("Embedding/random_uniform:0", dtype=float32)
我无法确定要传递给shape
的正确tf.placeholder
参数。我认为标量我应该init = tf.placeholder(tf.float32, shape=0, name="init")
,但这会产生以下错误:
ValueError: Incompatible shapes for broadcasting: (100, 300) and (0,)
如果我在init
的调用中将0.5
替换为字面值tf.random_uniform
,则可以使用。
如何通过Feed字典传递此标量初始值?
答案 0 :(得分:29)
TL; DR:使用标量形状定义init
,如下所示:
init = tf.placeholder(tf.float32, shape=(), name="init")
这看起来像是tf.random_uniform()
的一个不幸的实现细节:它目前使用tf.add()
和tf.multiply()
将随机值从[-1,+ 1]重新调整为[minval
},maxval
],但如果minval
或maxval
的形状未知,tf.add()
和tf.multiply()
无法推断出正确的形状,因为可能涉及广播。
通过定义具有已知形状的init
(标量为()
或[]
,而不是0
),TensorFlow可以得出关于形状的正确推断tf.random_uniform()
的结果,您的程序应按预期工作。
答案 1 :(得分:0)
您不需要占位符来传递标量,因为任何张量,sparsetensor或紧随其后的元组或sparsetensor元组都可以。 doc表示:
The optional `feed_dict` argument allows the caller to override the value of tensors in the graph. Each key in `feed_dict` can be one of the following types: * If the key is a `tf.Tensor`, the value may be a Python scalar, string, list, or numpy ndarray that can be converted to the same `dtype` as that tensor. Additionally, if the key is a `tf.placeholder`, the shape of the value will be checked for compatibility with the placeholder. * If the key is a `tf.SparseTensor`, the value should be a `tf.SparseTensorValue`. * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value should be a nested tuple with the same structure that maps to their corresponding values as above.
在您的情况下,任何张量(例如常量或变量或占位符)都可能是合适的。
init = tf.constant(0)
init_1 = tf.Variable(0)
v = tf.Variable(tf.random_uniform((100, 300), -init, init), dtype=tf.float32)
initialize = tf.global_variables_initializer()
sess.run(intialize, feed_dict={init: 0.5})
sess.run(intialize, feed_dict={init_1: 0.5})
您可以将float或int传递给它,因为只有占位符会如上所述检查数据类型。