我遇到了一个问题,可以总结如下:
foo = tf.constant(3)
foo_variable = tf.get_variable("foo", shape=[foo], dtype=tf.int32)
变量的形状必须取决于张量的值(foo
只是对其他操作的计算结果的抽象)
这里的错误是The shape of a variable can not be a Tensor object
如何解决这个问题?
答案 0 :(得分:1)
创建一个由foo
张量指定形状的张量初始化器,然后使用带有validate_shape=False
的该初始化器实例化新变量:
import tensorflow as tf
x = tf.placeholder(tf.int32, shape=())
shape = tf.constant([2, 3]) + x
init = tf.zeros(shape, dtype=tf.int32)
v = tf.get_variable('foo', initializer=init, validate_shape=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), {x: 1})
print(v.eval())
# [[0 0 0 0]
# [0 0 0 0]
# [0 0 0 0]]
答案 1 :(得分:0)
tensorflow变量不能具有动态形状,但是如果您知道会话之外的形状,则可以使用:
foo_variable = tf.get_variable("foo", shape=[], validate_shape=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(foo_variable, feed_dict={foo_variable: ones((2,2))}))