我需要创建一个tf.Variable
,其形状只有在执行时才知道。
我将代码简化为以下要点。我需要在占位符中找到大于 4 的数字,并在结果张量中将第二个项目 scatter_update 设置为 24 常量。>
import tensorflow as tf
def get_variable(my_variable):
greater_than = tf.greater(my_variable, tf.constant(4))
result = tf.boolean_mask(my_variable, greater_than)
# result = tf.Variable(tf.zeros(tf.shape(result)), trainable=False, expected_shape=tf.shape(result), validate_shape=False) # doesn't work either
result = tf.get_variable("my_var", shape=tf.shape(my_variable), dtype=tf.int32)
result = tf.scatter_update(result, [1], 24)
return result
input = tf.placeholder(dtype=tf.int32, shape=[5])
created_variable = get_variable(input)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result = sess.run(created_variable, feed_dict={input: [2, 7, 4, 6, 9]})
print(result)
答案 0 :(得分:1)
我遇到了同样的问题,偶然遇到了同样的悬而未决的问题,并设法拼凑出了一种在图形创建时创建具有动态形状的变量的解决方案。请注意,必须在tf.Session.run(...)
之前或首次执行时定义形状。
import tensorflow as tf
def get_variable(my_variable):
greater_than = tf.greater(my_variable, tf.constant(4))
result = tf.boolean_mask(my_variable, greater_than)
zerofill = tf.fill(tf.shape(my_variable), tf.constant(0, dtype=tf.int32))
# Initialize
result = tf.get_variable(
"my_var", shape=None, validate_shape=False, dtype=tf.int32, initializer=zerofill
)
result = tf.scatter_update(result, [1], 24)
return result
input = tf.placeholder(dtype=tf.int32, shape=[5])
created_variable = get_variable(input)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result = sess.run(created_variable, feed_dict={input: [2, 7, 4, 6, 9]})
print(result)
诀窍是用tf.Variable
,shape=None
创建一个validate_shape=False
并移交形状未知的tf.Tensor
作为初始化程序。