Updating only part of the linear indices in a 1D Tensor

Date: 2017-02-15 17:58:22

Tags: python python-2.7 machine-learning tensorflow data-science

I want to update only part of the linear indices (in ones_flat), for example a half or a quarter of them, i.e. change [0, 1, 2, 3, ..., N] to [0, 1, ..., N/2].
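As an illustration of that goal, independent of the TensorFlow code in the question: for a row-major mask of shape [s1, s2] with N = s1 * s2 entries, the first half of the linear indices is exactly the first s1 / 2 rows (when s1 is even). The sketch below uses NumPy instead of TensorFlow purely to show the indexing; the helper `zero_first_fraction` is hypothetical, and the 784 x 10 shape is taken from the `W` line in the log. It picks the indices deterministically with `np.arange` rather than sampling them.

```python
import numpy as np


def zero_first_fraction(shape, fraction):
    """Hypothetical helper: a ones mask whose first `fraction` of the
    linear (row-major) indices are set to zero."""
    n_total = int(np.prod(shape))
    flat = np.ones(n_total, dtype=np.float32)   # 1-D mask, like ones_flat
    n_update = int(n_total * fraction)
    linear_indices = np.arange(n_update)        # [0, 1, ..., n_update - 1]
    flat[linear_indices] = 0.0                  # same effect as scattering zeros
    return flat.reshape(shape)                  # back to the original 2-D shape


s1, s2 = 784, 10
mask = zero_first_fraction((s1, s2), 0.5)

# In row-major order, the first half of the linear indices covers exactly
# the first s1 // 2 rows (s1 is even here).
print(bool(np.all(mask[: s1 // 2] == 0)), bool(np.all(mask[s1 // 2 :] == 1)))  # True True
print(int((mask == 0).sum()))  # 3920
```

In TensorFlow 1.x, the deterministic counterpart would presumably be building `linear_indices` with `tf.range` instead of `tf.random_uniform`; that substitution is a suggestion and has not been checked against the question's model.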

    temp_var = tf.get_variable("W")
    size_2a = tf.get_variable("b")
    s1 = tf.shape(temp_var).eval()[0]
    s2 = tf.shape(size_2a).eval()[0]

    ones_mask = tf.ones([s1,s2])
    indices = tf.slice(ones_mask,[0,0],[s1/2,s2])
    # turn into 1d variable since "scatter_update" supports linear indexing only
    ones_flat = tf.Variable(tf.reshape(ones_mask, [-1]))
    indices_flat = tf.Variable(tf.reshape(indices, [-1]))

    # get linear indices
    linear_indices = tf.random_uniform(tf.shape(indices_flat), dtype=tf.int32, minval=0, maxval =s1*s2)

    # no automatic promotion, so make updates float32 to match ones_mask
    updates = tf.zeros(shape=(tf.shape(linear_indices)), dtype=tf.float32)
    ones_flat_new = tf.scatter_update(ones_flat,linear_indices, updates) 

    # convert back into original shape
    ones_mask_new = tf.reshape(ones_flat_new, ones_mask.get_shape())

    W.assign(tf.mul(ones_mask_new,W))
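Separately from whether W changes at all, note that the code above draws `linear_indices` with `tf.random_uniform`, which samples each index independently. Some indices therefore repeat, and fewer distinct positions are zeroed than the number of draws. The NumPy sketch below estimates how many distinct positions a "half" draw actually hits; the seed is an arbitrary assumption, and 7840 = 784 x 10 matches the shape of `W` in the log.

```python
import numpy as np

rng = np.random.default_rng(0)   # fixed seed, for reproducibility only
n_total = 784 * 10               # number of entries in W
n_draws = n_total // 2           # as many draws as half the entries

# Analogous to tf.random_uniform(..., dtype=tf.int32, minval=0, maxval=n_total):
# independent draws from [0, n_total), so duplicates are possible.
linear_indices = rng.integers(0, n_total, size=n_draws)
n_distinct = np.unique(linear_indices).size

# Expected number of distinct positions: n_total * (1 - (1 - 1/n_total) ** n_draws),
# roughly n_total * (1 - exp(-0.5)), i.e. about 39% of the entries rather than 50%.
expected = n_total * (1 - (1 - 1 / n_total) ** n_draws)
print(n_draws, n_distinct, round(expected))
```

So even when the update does run, random sampling zeroes noticeably less than the intended half, and not the first half of the indices.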

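The problem is that after setting half of ones_mask to zero and multiplying it by W, W does not seem to change: the network's accuracy stays exactly the same. There is no error either; the output is just: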

    Accuracy_old: 0.9113
    Extracting /tmp/data/train-images-idx3-ubyte.gz
    Extracting /tmp/data/train-labels-idx1-ubyte.gz
    Extracting /tmp/data/t10k-images-idx3-ubyte.gz
    Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
    lin_ind Tensor("foo_1/random_uniform:0", shape=(1960,), dtype=int32)
    upd Tensor("foo_1/zeros:0", shape=(3920,), dtype=float32)
    W Tensor("foo/W/read:0", shape=(784, 10), dtype=float32)
    Accuracy_new: 0.9113
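The shapes printed in the log can be cross-checked with a little arithmetic. W has 784 x 10 = 7840 entries, so half is 3920 and a quarter is 1960. In the listed code, `updates` takes its shape from `tf.shape(linear_indices)`, and for a 1-D ref `tf.scatter_update` expects `updates` to have the same shape as the indices, so both lengths would be equal there. The log shows 1960 for `lin_ind` but 3920 for `upd`, which suggests (an inference, not something the question states) that the output may come from a slightly different version of the code. The snippet only restates that arithmetic.

```python
s1, s2 = 784, 10                 # shape of W, as printed in the log
n_total = s1 * s2                # 7840 entries
half = (s1 // 2) * s2            # entries in the first s1/2 rows: 3920
quarter = n_total // 4           # 1960

# Lengths copied from the log lines for lin_ind and upd.
logged = {"lin_ind": 1960, "upd": 3920}

print(logged["lin_ind"] == quarter, logged["upd"] == half)  # True True
# In the listed code both lengths would be 3920, since updates reuses the
# shape of linear_indices; the logged values differ.
print(logged["lin_ind"] == logged["upd"])  # False
```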

Can you help me find the error here?
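Because unchanged accuracy is only indirect evidence, one way to narrow this down is to fetch W before and after the masking step and compare the arrays directly. One thing worth ruling out: in TensorFlow 1.x graph mode, `W.assign(...)` (like `tf.scatter_update`) only adds an op to the graph, and the variable is updated only when that op is executed, for example with `sess.run`. The helper `changed_fraction` below is hypothetical, and the arrays are synthetic stand-ins for values that would be fetched from a session.

```python
import numpy as np


def changed_fraction(before, after):
    """Hypothetical helper: fraction of entries that differ between two
    snapshots of the same weight matrix."""
    before = np.asarray(before)
    after = np.asarray(after)
    assert before.shape == after.shape, "snapshots must have the same shape"
    return float(np.mean(before != after))


# Synthetic stand-in for W fetched before the masking step.
rng = np.random.default_rng(0)
w_before = rng.normal(size=(784, 10)).astype(np.float32)

mask = np.ones((784, 10), dtype=np.float32)
mask[: 784 // 2] = 0.0                 # zero the first half of the rows

w_applied = mask * w_before            # what W should hold if the assign op ran
w_not_applied = w_before.copy()        # what W holds if the op was never executed

print(changed_fraction(w_before, w_applied))      # 0.5
print(changed_fraction(w_before, w_not_applied))  # 0.0
```

A result of 0.0 on the real snapshots would point at the assign op never being run rather than at the mask itself.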

0 answers

No answers yet.