Tensorflow:如何根据条件从多个变量中一次仅更新单个变量

时间:2018-06-09 05:14:09

标签: tensorflow

k1 = tf.Variable(10.0)
k2 = tf.Variable(10.0)

pred = tf.pow(B, ?) / C
cost = tf.pow(pred_s1 - Y, 2)
optimizer = tf.train.AdamOptimizer(0.01).minimize(cost)

sess.run(optimizer, feed_dict{A:a, B:b, C:c})

更新

pred = tf.pow(B, k1) / C if A == 0
pred = tf.pow(B, k2) / C if A == 1

单个预测函数,它根据输入占位符'A'

的值仅更新一个变量

2 个答案:

答案 0 :(得分:0)

s1 = tf.Variable(tf.random_normal([1]))
s2 = tf.Variable(tf.random_normal([1]))
s3 = tf.Variable(tf.random_normal([1]))
s4 = tf.Variable(tf.random_normal([1]))
s5 = tf.Variable(tf.random_normal([1]))

D = tf.placeholder("float")

s2_s = tf.where(tf.logical_and(1.9<D,D<2.1),x=s2,y=s1)
s3_s = tf.where(tf.logical_and(2.9<D,D<3.1),x=s3,y=s2_s)
s4_s = tf.where(tf.logical_and(3.9<D,D<4.1),x=s4,y=s3_s)
s5_s = tf.where(tf.logical_and(4.9<D,D<5.1),x=s5,y=s4_s)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([s1])[0], sess.run([s2])[0], sess.run([s3])[0], sess.run([s4])[0], sess.run([s5])[0])

print(sess.run(s5_s, feed_dict={D:5}))

sess.close()

答案 1 :(得分:-1)

只需使用

pred = tf.pow(B, A*k2 + (1-A)* k1) / C

给出了开关。另一种选择是tf.where