将scatter_update与tensorflow中的馈送数据一起使用

时间:2019-01-09 12:23:02

标签: python tensorflow

我正在尝试使用scatter_update更新张量的切片。我的第一个熟悉该功能的代码段非常完美。

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    init_val = tf.Variable(tf.zeros((3, 2)))
    indices = tf.constant([0, 1])
    update = tf.scatter_update(init_val, indices, tf.ones((2, 2)))

    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(update))

但是当我尝试将初始值输入到像这样的图中

with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=(3, 2))
    init_val = x
    indices = tf.constant([0, 1])
    update = tf.scatter_update(init_val, indices, tf.ones((2, 2)))

    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(update, feed_dict={x: np.zeros((3, 2))}))

我收到奇怪的错误

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [3,2]
 [[{{node Placeholder_1}} = Placeholder[dtype=DT_FLOAT, shape=[3,2], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

tf.Variable分配到x时将init_val拖放到AttributeError: 'Tensor' object has no attribute '_lazy_read' 上也无济于事,因为我遇到了错误

'use strict';

function createIt(data) {
  let res = [];

  for (const i of data) {
    res.push({
      fruit: i
    });
  }

  return res;
}

var list = ["orange", "apple", "pineapple"];

console.log(createIt(list));

(请参阅Github上的this entry)。有人知道吗?预先感谢!

我在CPU上使用Tensorflow 1.12。

1 个答案:

答案 0 :(得分:1)

您可以通过构建和更新张量和蒙版张量来通过散射替换张量:

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=(3, 2))
    init_val = x
    indices = tf.constant([0, 1])
    x_shape = tf.shape(x)
    indices = tf.expand_dims(indices, 1)
    replacement = tf.ones((2, 2))
    update = tf.scatter_nd(indices, replacement, x_shape)
    mask = tf.scatter_nd(indices, tf.ones_like(replacement, dtype=tf.bool), x_shape)
    result = tf.where(mask, update, x)
    print(sess.run(result, feed_dict={x: np.arange(6).reshape((3, 2))}))

输出:

[[1. 1.]
 [1. 1.]
 [4. 5.]]