如何通过TensorFlow提要字典传递标量

时间:2016-08-23 23:24:21

标签: python machine-learning tensorflow

我的TensorFlow模型使用tf.random_uniform来初始化变量。我想在开始训练时指定范围,因此我为初始化值创建了一个占位符。

init = tf.placeholder(tf.float32, name="init")
v = tf.Variable(tf.random_uniform((100, 300), -init, init), dtype=tf.float32)
initialize = tf.initialize_all_variables()

我在训练开始时初始化变量就像这样。

session.run(initialize, feed_dict={init: 0.5})

这给了我以下错误:

ValueError: initial_value must have a shape specified: Tensor("Embedding/random_uniform:0", dtype=float32)

我无法确定要传递给shape的正确tf.placeholder参数。我认为标量我应该init = tf.placeholder(tf.float32, shape=0, name="init"),但这会产生以下错误:

ValueError: Incompatible shapes for broadcasting: (100, 300) and (0,)

如果我在init的调用中将0.5替换为字面值tf.random_uniform,则可以使用。

如何通过Feed字典传递此标量初始值?

2 个答案:

答案 0 :(得分:29)

TL; DR:使用标量形状定义init,如下所示:

init = tf.placeholder(tf.float32, shape=(), name="init")

这看起来像是tf.random_uniform()的一个不幸的实现细节:它目前使用tf.add()tf.multiply()将随机值从[-1,+ 1]重新调整为[minval },maxval],但如果minvalmaxval的形状未知,tf.add()tf.multiply()无法推断出正确的形状,因为可能涉及广播。

通过定义具有已知形状的init(标量为()[],而不是0),TensorFlow可以得出关于形状的正确推断tf.random_uniform()的结果,您的程序应按预期工作。

答案 1 :(得分:0)

您不需要占位符来传递标量,因为任何张量,sparsetensor或紧随其后的元组或sparsetensor元组都可以。 doc表示:

The optional `feed_dict` argument allows the caller to override
the value of tensors in the graph. Each key in `feed_dict` can be
one of the following types:
* If the key is a `tf.Tensor`, the
  value may be a Python scalar, string, list, or numpy ndarray
  that can be converted to the same `dtype` as that
  tensor. Additionally, if the key is a
  `tf.placeholder`, the shape of
  the value will be checked for compatibility with the placeholder.
* If the key is a
  `tf.SparseTensor`,
  the value should be a
  `tf.SparseTensorValue`.
* If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
  should be a nested tuple with the same structure that maps to their
  corresponding values as above.

在您的情况下,任何张量(例如常量或变量或占位符)都可能是合适的。

init = tf.constant(0)
init_1 = tf.Variable(0)
v = tf.Variable(tf.random_uniform((100, 300), -init, init), dtype=tf.float32)
initialize = tf.global_variables_initializer()
sess.run(intialize, feed_dict={init: 0.5})
sess.run(intialize, feed_dict={init_1: 0.5})

您可以将float或int传递给它,因为只有占位符会如上所述检查数据类型。