如何在Tensorflow中更新2D张量的子集?

时间:2016-10-04 18:41:12

标签: python neural-network tensorflow deep-learning

我想更新值为0的2D张量中的索引。因此,数据是2D张量,其第2行第2列索引值将替换为0.但是,我收到类型错误。任何人都可以帮我吗?

  

TypeError:输入' ref' ' ScatterUpdate'操作需要l值输入

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])

2 个答案:

答案 0 :(得分:5)

分散更新仅适用于变量。而是尝试这种模式。

Tensorflow版本< 1.0: a = tf.concat(0, [a[:i], [updated_value], a[i+1:]])

Tensorflow版本> = 1.0: a = tf.concat(axis=0, values=[a[:i], [updated_value], a[i+1:]])

答案 1 :(得分:3)

Variable只能应用于data类型。您的代码中的Variabledata2,而tf.reshape则不是,因为Tensor的返回类型是data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]]) row = tf.gather(data, 2) new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0) sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

解决方案:

用于v1.0之后的张量流

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

for vor.0之前的张量流

if let day = timeDifference.day, minute = timeDifference.minute, second = timeDifference.second {
  WeeklyDateLabel.text = "\(timeDifference.day) Days \(timeDifference.minute) Minutes \(timeDifference.second) Seconds"
}