分配Keras张量的索引条目

时间:2019-02-15 22:05:42

标签: python keras

我是Keras的初学者,因此对于任何普遍的理解不足,我会先道歉。

我想根据存储在另一个张量中的索引手动设置我的Keras张量的某些值。我相信我知道如何使用tf.gather_nd访问张量的条目(我在下面未经测试的尝试),并且我知道我只能设置变量的值,而不能设置张量。

为清楚起见,这是在GAN的生成和区分阶段之间进行的。

gen_out = generator(inputs)

indices_to_reset = Input(shape=(1,),dtype='int32')
new_values = Input(shape=(1,), dtype='int32')

batch_size = K.shape(x)[0]

idx_0 = K.reshape(K.arange(batch_size),(1,))
indices_to_reset = K.reshape(indices_to_reset, (1,))

idx = K.stack((idx_0, indices_to_reset), axis=0)

grabbed_entries = Lambda(lambda x: tf.gather_nd(gen_out,x))(idx)

# Doesn't work
# gen_out[:,indices_to_reset] = new_values

updated_gen_out = ???

2 个答案:

答案 0 :(得分:1)

如果将所有内容转换为one_hot张量并使用switch,则容易得多

(请记住,所有操作都在lambda层内进行,否则您会遇到问题)

def replace_values(x):
    outs, indices, values = x

    #this is due to a strange bug between lambda and integers....
    indices = K.cast(indices, 'int32')


    #create one_hot indices
    one_hot_indices = K.one_hot(indices, size) #size is the size of gen_out
    one_hot_indices = K.batch_flatten(one_hot_indices)

    #have the desired values at their correct positions
    values_to_use = one_hot_indices * new_values


    #if values are 0, use gen_out, else use values
    return K.switch(K.equal(values_to_use, 0), outs, values_to_use)


updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])
  

警告:new_values不能为整数,它们必须与gen_out相同。


虚拟示例:

import numpy as np
from keras.layers import *
from keras.models import Model

size = 5
batch_size = 15

gen_out = Input((size,))
indices_to_reset = Input((1,), dtype='int32')
new_values = Input((1,))

def replace_values(x):
    outs, indices, values = x
    print(K.int_shape(outs))
    print(K.int_shape(indices))

    #this is due to a strange bug between lambda and integers....
    indices = K.cast(indices, 'int32')
    one_hot_indices = K.one_hot(indices, size)
    print(K.int_shape(one_hot_indices))
    one_hot_indices = K.batch_flatten(one_hot_indices)
    print(K.int_shape(one_hot_indices))

    values_to_use = one_hot_indices * new_values
    print(K.int_shape(values_to_use))

    return K.switch(K.equal(values_to_use, 0), outs, values_to_use)

updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])

model = Model([gen_out,indices_to_reset,new_values], updated_gen_out)

gen_outs = np.arange(batch_size * size).reshape((batch_size, size))
indices = np.concatenate([np.arange(5)]*3, axis=0)
new_vals = np.arange(15).reshape((15,1))

print('\n\ngen outs')
print(gen_outs)

print('\n\nindices')
print(indices)

print('\n\nvalues')
print(new_vals)

print('\n\n results')
print(model.predict([gen_outs, indices, new_vals]))

输出:

(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)


gen outs
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]
 [25 26 27 28 29]
 [30 31 32 33 34]
 [35 36 37 38 39]
 [40 41 42 43 44]
 [45 46 47 48 49]
 [50 51 52 53 54]
 [55 56 57 58 59]
 [60 61 62 63 64]
 [65 66 67 68 69]
 [70 71 72 73 74]]


indices
[0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]


values
[[ 0]
 [ 1]
 [ 2]
 [ 3]
 [ 4]
 [ 5]
 [ 6]
 [ 7]
 [ 8]
 [ 9]
 [10]
 [11]
 [12]
 [13]
 [14]]


 results
[[ 0.  1.  2.  3.  4.]
 [ 5.  1.  7.  8.  9.]
 [10. 11.  2. 13. 14.]
 [15. 16. 17.  3. 19.]
 [20. 21. 22. 23.  4.]
 [ 5. 26. 27. 28. 29.]
 [30.  6. 32. 33. 34.]
 [35. 36.  7. 38. 39.]
 [40. 41. 42.  8. 44.]
 [45. 46. 47. 48.  9.]
 [10. 51. 52. 53. 54.]
 [55. 11. 57. 58. 59.]
 [60. 61. 12. 63. 64.]
 [65. 66. 67. 13. 69.]
 [70. 71. 72. 73. 14.]] 

请注意,将gen_outs的对角线值替换为new_vals中的值。

答案 1 :(得分:0)

我现在没有机会尝试,但是您不能使用tf.where

updated_gen_out = tf.where(idx_mask, gen_out, new_values)

不过,您首先需要为索引创建一个布尔掩码idx_mask,并可能重复您的new_value以使其具有与gen_out相同的形状。