I want to update only part of the linear indices (in ones_flat), for example a half or a quarter of them, i.e. go from [0, 1, 2, 3, ..., N] to [0, 1, ..., N/2]:
temp_var = tf.get_variable("W")
size_2a = tf.get_variable("b")
s1 = tf.shape(temp_var).eval()[0]
s2 = tf.shape(size_2a).eval()[0]
ones_mask = tf.ones([s1,s2])
indices = tf.slice(ones_mask, [0, 0], [s1 // 2, s2])
# turn into 1d variable since "scatter_update" supports linear indexing only
ones_flat = tf.Variable(tf.reshape(ones_mask, [-1]))
indices_flat = tf.Variable(tf.reshape(indices, [-1]))
# get linear indices
linear_indices = tf.random_uniform(tf.shape(indices_flat), dtype=tf.int32, minval=0, maxval=s1*s2)
# no automatic promotion, so make updates float32 to match ones_mask
updates = tf.zeros(shape=(tf.shape(linear_indices)), dtype=tf.float32)
ones_flat_new = tf.scatter_update(ones_flat,linear_indices, updates)
# convert back into original shape
ones_mask_new = tf.reshape(ones_flat_new, ones_mask.get_shape())
W.assign(tf.mul(ones_mask_new,W))
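To make the intent concrete, here is a minimal plain-Python sketch of the effect I expect from the mask. It is not my network code: the helper name zero_random_half, the seed, and the toy 4 x 3 shape (standing in for the 784 x 10 shape of W) are illustrative assumptions only.

```python
import random


def zero_random_half(mask_flat, fraction=0.5, seed=0):
    """Return a copy of mask_flat with `fraction` of its entries set to 0.0.

    Hypothetical helper for illustration. Indices are drawn without
    replacement, so exactly int(len(mask_flat) * fraction) distinct
    positions are zeroed.
    """
    rng = random.Random(seed)
    n_zero = int(len(mask_flat) * fraction)
    chosen = rng.sample(range(len(mask_flat)), n_zero)
    out = list(mask_flat)
    for i in chosen:
        out[i] = 0.0
    return out


rows, cols = 4, 3  # toy stand-in for the 784 x 10 shape of W (assumption)
W = [[float(r * cols + c + 1) for c in range(cols)] for r in range(rows)]

# flat mask of ones, with half of the linear positions set to zero
ones_flat = [1.0] * (rows * cols)
mask_flat = zero_random_half(ones_flat, fraction=0.5)

# reshape the flat mask back to rows x cols and multiply elementwise with W
mask = [mask_flat[r * cols:(r + 1) * cols] for r in range(rows)]
W_masked = [[m * w for m, w in zip(mrow, wrow)] for mrow, wrow in zip(mask, W)]

# number of entries of W that ended up zeroed
print(sum(v == 0.0 for row in W_masked for v in row))
```

In this sketch exactly half of the entries of W become zero, which is the result I am after. One difference from my code above: tf.random_uniform can draw the same index more than once, so there the number of distinct zeroed positions may be smaller than the number of indices requested.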
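The problem is that after setting half of the entries in ones_mask to zero and multiplying the mask by W, W does not seem to change, because the accuracy of the network stays the same. There is no error message either; this is all the output I get: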
Accuracy_old: 0.9113
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
lin_ind Tensor("foo_1/random_uniform:0", shape=(1960,), dtype=int32)
upd Tensor("foo_1/zeros:0", shape=(3920,), dtype=float32)
W Tensor("foo/W/read:0", shape=(784, 10), dtype=float32)
Accuracy_new: 0.9113
Can you help me find the error here?