如何将张量类型值传递给Variable的shape参数?

时间:2019-06-22 09:55:36

标签: python tensorflow

我遇到了一个问题,可以总结如下:

foo = tf.constant(3)
foo_variable = tf.get_variable("foo", shape=[foo], dtype=tf.int32)

变量的形状必须取决于张量的值(foo只是对其他操作的计算结果的抽象)

这里的错误是The shape of a variable can not be a Tensor object

如何解决这个问题?

2 个答案:

答案 0 :(得分:1)

创建一个由foo张量指定形状的张量初始化器,然后使用带有validate_shape=False的该初始化器实例化新变量:

import tensorflow as tf

x = tf.placeholder(tf.int32, shape=())
shape = tf.constant([2, 3]) + x
init = tf.zeros(shape, dtype=tf.int32)
v = tf.get_variable('foo', initializer=init, validate_shape=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), {x: 1})
    print(v.eval())
    # [[0 0 0 0]
    #  [0 0 0 0]
    #  [0 0 0 0]]

答案 1 :(得分:0)

tensorflow变量不能具有动态形状,但是如果您知道会话之外的形状,则可以使用:

foo_variable = tf.get_variable("foo", shape=[], validate_shape=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(foo_variable, feed_dict={foo_variable: ones((2,2))}))