在TensorFlow中更新未知数量的变量

时间:2018-11-09 20:32:20

标签: python tensorflow

我在TensorFlow图中有一个变量集合,我想根据相似的规则同时更新所有变量。一个示例图将是

a = tf.placeholder(tf.int32)
x1 = tf.Variable(0, tf.int32)
x2 = tf.Variable(1, tf.int32)

然后我想一次将数据读入a的一个值,并在每一步将x1x2更新为max(current_value,a)。这可以通过添加两个分配操作来实现

u1 = x1.assign(tf.maximum(a, x1))
u2 = x2.assign(tf.maximum(a, x2))

如果输入数据在列表data和变量vars的列表中,则可以通过循环来完成(为难看的逻辑道歉!)

with tf.Session() as sess:
    for d in data:
        if vars = []:
            continue
        if vars = ['x1']:
            sess.run(u1, {a:d})
        if vars = ['x2']:
            sess.run(u2, {a:d})
        if vars = ['x1', 'x2']:
            sess.run([u1, u2], {a:d})

但是,如果我有大量的xi并且想要避免重复的代码 手动构建每个更新的ui,是否有办法构建一个函数,该函数将变量列表作为参数并为它们生成新的赋值变量?

0 个答案:

没有答案