我是Keras的初学者,因此对于任何普遍的理解不足,我会先道歉。
我想根据存储在另一个张量中的索引手动设置我的Keras张量的某些值。我相信我知道如何使用tf.gather_nd
访问张量的条目(我在下面未经测试的尝试),并且我知道我只能设置变量的值,而不能设置张量。
为清楚起见,这是在GAN的生成和区分阶段之间进行的。
gen_out = generator(inputs)
indices_to_reset = Input(shape=(1,),dtype='int32')
new_values = Input(shape=(1,), dtype='int32')
batch_size = K.shape(x)[0]
idx_0 = K.reshape(K.arange(batch_size),(1,))
indices_to_reset = K.reshape(indices_to_reset, (1,))
idx = K.stack((idx_0, indices_to_reset), axis=0)
grabbed_entries = Lambda(lambda x: tf.gather_nd(gen_out,x))(idx)
# Doesn't work
# gen_out[:,indices_to_reset] = new_values
updated_gen_out = ???
答案 0 :(得分:1)
如果将所有内容转换为one_hot张量并使用switch,则容易得多
(请记住,所有操作都在lambda层内进行,否则您会遇到问题)
def replace_values(x):
outs, indices, values = x
#this is due to a strange bug between lambda and integers....
indices = K.cast(indices, 'int32')
#create one_hot indices
one_hot_indices = K.one_hot(indices, size) #size is the size of gen_out
one_hot_indices = K.batch_flatten(one_hot_indices)
#have the desired values at their correct positions
values_to_use = one_hot_indices * new_values
#if values are 0, use gen_out, else use values
return K.switch(K.equal(values_to_use, 0), outs, values_to_use)
updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])
警告::
new_values
不能为整数,它们必须与gen_out
相同。
import numpy as np
from keras.layers import *
from keras.models import Model
size = 5
batch_size = 15
gen_out = Input((size,))
indices_to_reset = Input((1,), dtype='int32')
new_values = Input((1,))
def replace_values(x):
outs, indices, values = x
print(K.int_shape(outs))
print(K.int_shape(indices))
#this is due to a strange bug between lambda and integers....
indices = K.cast(indices, 'int32')
one_hot_indices = K.one_hot(indices, size)
print(K.int_shape(one_hot_indices))
one_hot_indices = K.batch_flatten(one_hot_indices)
print(K.int_shape(one_hot_indices))
values_to_use = one_hot_indices * new_values
print(K.int_shape(values_to_use))
return K.switch(K.equal(values_to_use, 0), outs, values_to_use)
updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])
model = Model([gen_out,indices_to_reset,new_values], updated_gen_out)
gen_outs = np.arange(batch_size * size).reshape((batch_size, size))
indices = np.concatenate([np.arange(5)]*3, axis=0)
new_vals = np.arange(15).reshape((15,1))
print('\n\ngen outs')
print(gen_outs)
print('\n\nindices')
print(indices)
print('\n\nvalues')
print(new_vals)
print('\n\n results')
print(model.predict([gen_outs, indices, new_vals]))
输出:
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
gen outs
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]
[25 26 27 28 29]
[30 31 32 33 34]
[35 36 37 38 39]
[40 41 42 43 44]
[45 46 47 48 49]
[50 51 52 53 54]
[55 56 57 58 59]
[60 61 62 63 64]
[65 66 67 68 69]
[70 71 72 73 74]]
indices
[0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]
values
[[ 0]
[ 1]
[ 2]
[ 3]
[ 4]
[ 5]
[ 6]
[ 7]
[ 8]
[ 9]
[10]
[11]
[12]
[13]
[14]]
results
[[ 0. 1. 2. 3. 4.]
[ 5. 1. 7. 8. 9.]
[10. 11. 2. 13. 14.]
[15. 16. 17. 3. 19.]
[20. 21. 22. 23. 4.]
[ 5. 26. 27. 28. 29.]
[30. 6. 32. 33. 34.]
[35. 36. 7. 38. 39.]
[40. 41. 42. 8. 44.]
[45. 46. 47. 48. 9.]
[10. 51. 52. 53. 54.]
[55. 11. 57. 58. 59.]
[60. 61. 12. 63. 64.]
[65. 66. 67. 13. 69.]
[70. 71. 72. 73. 14.]]
请注意,将gen_outs
的对角线值替换为new_vals
中的值。
答案 1 :(得分:0)
我现在没有机会尝试,但是您不能使用tf.where
:
updated_gen_out = tf.where(idx_mask, gen_out, new_values)
不过,您首先需要为索引创建一个布尔掩码idx_mask
,并可能重复您的new_value以使其具有与gen_out
相同的形状。