计算线性指数Tensorflow

时间:2017-02-10 00:49:28

标签: python numpy machine-learning tensorflow data-science

下午好。 我继续遇到通过索引更新张量流中的随机元素的问题。 我想随机选择索引(例如,一半),然后设置为零元素对应于那些索引。 这是有问题的部分:

with tf.variable_scope("foo", reuse=True):
    temp_var = tf.get_variable("W")
    size_2a = tf.get_variable("b")
    s1 = tf.shape(temp_var).eval()[0]
    s2 = tf.shape(size_2a).eval()[0]


    row_indices = tf.random_uniform(dtype=tf.int32, minval=0, maxval = s1 - 1, shape=[s1]).eval()
    col_indices = tf.random_uniform(dtype=tf.int32, minval=0, maxval = s2 - 1, shape=[s2]).eval()

    ones_mask = tf.ones([s1,s2])

    # turn 'ones_mask' into 1d variable since "scatter_update" supports linear indexing only
    ones_flat = tf.Variable(tf.reshape(ones_mask, [-1]))

    # no automatic promotion, so make updates float32 to match ones_mask
    updates = tf.zeros(shape=(s1,), dtype=tf.float32)


    # get linear indices
    linear_indices = row_indices*s2 + tf.reshape(col_indices,s1*s2)
    ones_flat = tf.scatter_update(ones_flat, linear_indices/2, updates) 
    #I want to set to zero only half of all elements,that's why linear_indices/2

    # convert back into original shape
    ones_mask = tf.reshape(ones_flat, ones_mask.get_shape())

它给了我ValueError:无法重塑具有10个元素的张量来塑造[784,10](7840个元素)以用于' foo_1 / Reshape_1' (op:' Reshape')输入形状:[10],[2]。但是我不知道如何在没有重塑的情况下来到这里(我试图重塑s1和s2,没用)

我已经阅读了这些主题:Update values of a matrix variable in tensorflow, advanced indexing(feed_dict似乎不适用于我的情况),python numpy ValueError: operands could not be broadcast together with shapes几乎所有关于stackoverflow上主题的内容=(

0 个答案:

没有答案