Tensorflow更新每行中的第一个匹配元素

时间:2017-08-14 23:24:34

标签: python tensorflow

在此question的基础上构建我希望在连续第一次满足tf.where条件时更新二维张量的值。这是我用来模拟的示例代码:

tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    val = "hello"
    new_val = "goodbye"
    matrix = tf.constant([["word","hello","hello"],
                          ["word", "other", "hello"],
                          ["hello", "hello","hello"],
                          ["word", "word", "word"]
                         ])
    matching_indices = tf.where(tf.equal(matrix, val))
    first_matching_idx = tf.segment_min(data = matching_indices[:, 1],
                                 segment_ids = matching_indices[:, 0])

sess = tf.InteractiveSession(graph=graph)
print(sess.run(first_matching_idx))

这将输出[1,2,0],其中1是第1行中第一个hello的位置,2是第2行中第一个hello的位置,0是第一个hello的位置在第3行。

但是,我无法找到一种方法来使用新值更新第一个匹配的索引 - 基本上我想要第一个"你好"变成"再见"。我尝试过使用tf.scatter_update()但它似乎不适用于2D张量。有没有办法按照描述修改二维张量?

1 个答案:

答案 0 :(得分:0)

一个简单的解决方法是将tf.py_func与numpy数组一起使用

def ch_val(array, val, new_val):
    idx = np.array([[s, list(row).index(val)]
                for s, row in enumerate(array) if val in row])
    idx = tuple((idx[:, 0], idx[:, 1]))
    array[idx] = new_val
    return array

...
matrix = tf.Variable([["word","hello","hello"],
                      ["word", "other", "hello"],
                      ["hello", "hello","hello"],
                      ["word", "word", "word"]
                     ])
matrix = tf.py_func(ch_val, [matrix, 'hello', 'goodbye'], tf.string) 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(matrix))
    # results: [['word' 'goodbye' 'hello']
      ['word' 'other' 'goodbye']
      ['goodbye' 'hello' 'hello']
      ['word' 'word' 'word']]