I have a function that takes an (image, hidden state) pair as input and feeds its output back in as the next hidden state; depending on the length of the sequence, the last output is the one I am interested in.
The recurrence function is as follows:
from keras import backend as K
from keras import layers
from keras.layers import Input, Conv2D, Conv3D, Dense, Flatten, Reshape, MaxPooling2D, ZeroPadding2D, LeakyReLU, Activation, Lambda
from keras.models import Model

# dim1, dim2, dim3 (image dimensions) and k_init, b_init (initializers)
# are defined elsewhere in my code.
def recurrence():
    input_image = Input(shape=[dim1, dim2, dim3])
    hidden_state = Input(shape=[4, 4, 4, 128])
    # 2D CNN encoder for the input image
    conv1a = Conv2D(filters=96, kernel_size=7, padding='same', kernel_initializer=k_init, bias_initializer=b_init)(input_image)
    conv1a = LeakyReLU(0.01)(conv1a)
    conv1b = Conv2D(filters=96, kernel_size=3, padding='same', kernel_initializer=k_init, bias_initializer=b_init)(conv1a)
    conv1b = LeakyReLU(0.01)(conv1b)
    conv1b = ZeroPadding2D(padding=(1, 1))(conv1b)
    pool1 = MaxPooling2D(2)(conv1b)
    flat6 = Flatten()(pool1)
    fc7 = Dense(units=1024, kernel_initializer=k_init, bias_initializer=b_init)(flat6)
    rect7 = LeakyReLU(0.01)(fc7)
    # update gate: Conv3D over the hidden state plus Dense over the image code
    t_x_s_update_conv = Conv3D(128, 3, activation=None, padding='same', kernel_initializer=k_init, use_bias=False)(hidden_state)
    t_x_s_update_dense = Reshape((4, 4, 4, 128))(Dense(units=8192)(rect7))
    t_x_s_update = layers.add([t_x_s_update_conv, t_x_s_update_dense])
    # reset gate, built the same way
    t_x_s_reset_conv = Conv3D(128, 3, activation=None, padding='same', kernel_initializer=k_init, use_bias=False)(hidden_state)
    t_x_s_reset_dense = Reshape((4, 4, 4, 128))(Dense(units=8192)(rect7))
    t_x_s_reset = layers.add([t_x_s_reset_conv, t_x_s_reset_dense])
    update_gate = Activation(K.sigmoid)(t_x_s_update)
    comp_update_gate = Lambda(lambda x: 1 - x)(update_gate)
    reset_gate = Activation(K.sigmoid)(t_x_s_reset)
    # candidate state, computed from the reset-gated hidden state
    rs = layers.multiply([reset_gate, hidden_state])
    t_x_rs_conv = Conv3D(128, 3, activation=None, padding='same', kernel_initializer=k_init, use_bias=False)(rs)
    t_x_rs_dense = Reshape((4, 4, 4, 128))(Dense(units=8192)(rect7))
    t_x_rs = layers.add([t_x_rs_conv, t_x_rs_dense])
    tanh_t_x_rs = Activation(K.tanh)(t_x_rs)
    # GRU blend: update_gate * old state + (1 - update_gate) * candidate
    gru_out = layers.add([layers.multiply([update_gate, hidden_state]), layers.multiply([comp_update_gate, tanh_t_x_rs])])
    return Model(inputs=[input_image, hidden_state], outputs=gru_out)
Currently I instantiate the model once with a loop and then reuse it, like this:
step = recurrence()
s_update = hidden_init
for i in input_imgs:
    s_update = step([i, s_update])
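For completeness, here is a minimal sketch of how I wire this loop into a trainable outer model; the fixed length T and the zero-fed hidden_init input are illustrative assumptions, not part of my actual code:

# A minimal sketch, assuming a fixed sequence length T; hidden_init is
# fed an all-zeros array at fit/predict time.
from keras.layers import Input
from keras.models import Model

T = 5  # hypothetical sequence length
step = recurrence()
input_imgs = [Input(shape=(dim1, dim2, dim3)) for _ in range(T)]
hidden_init = Input(shape=(4, 4, 4, 128))  # fed with zeros

s_update = hidden_init
for i in input_imgs:
    s_update = step([i, s_update])  # same Model instance, so weights are shared

outer = Model(inputs=input_imgs + [hidden_init], outputs=s_update)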
But this approach seems to have two drawbacks:
i) It consumes a large amount of GPU memory, even though all the variables are shared.
ii) It does not work for sequences of arbitrary length without padding.
Is there a better, more efficient way to implement this? I tried reading the code for SimpleRNN in the recurrent.py file, but I can't figure out how to fold all these layers (such as Conv3D) into that framework. Dummy code or any help you can provide would be much appreciated. Thanks!
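To make the question concrete, here is a rough, untested sketch of the kind of thing I am after, assuming tf.keras (TF 2.x), where a custom cell can declare a TensorShape state_size and keras.layers.RNN accepts inputs shaped (batch, timesteps, H, W, C). ConvGRUCell is a made-up name and the initializers are omitted for brevity:

# Rough, untested sketch under the assumptions stated above.
import tensorflow as tf
from tensorflow.keras import layers

class ConvGRUCell(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.state_size = tf.TensorShape([4, 4, 4, 128])
        self.output_size = tf.TensorShape([4, 4, 4, 128])
        # image encoder, mirroring conv1a..fc7 above
        self.conv1a = layers.Conv2D(96, 7, padding='same')
        self.conv1b = layers.Conv2D(96, 3, padding='same')
        self.pad1 = layers.ZeroPadding2D((1, 1))
        self.pool1 = layers.MaxPooling2D(2)
        self.flatten = layers.Flatten()
        self.fc7 = layers.Dense(1024)
        # one Conv3D + Dense pair per gate
        self.update_conv = layers.Conv3D(128, 3, padding='same', use_bias=False)
        self.update_dense = layers.Dense(8192)
        self.reset_conv = layers.Conv3D(128, 3, padding='same', use_bias=False)
        self.reset_dense = layers.Dense(8192)
        self.cand_conv = layers.Conv3D(128, 3, padding='same', use_bias=False)
        self.cand_dense = layers.Dense(8192)

    def call(self, inputs, states):
        h = states[0]  # (batch, 4, 4, 4, 128)
        x = tf.nn.leaky_relu(self.conv1a(inputs), 0.01)
        x = tf.nn.leaky_relu(self.conv1b(x), 0.01)
        x = self.flatten(self.pool1(self.pad1(x)))
        x = tf.nn.leaky_relu(self.fc7(x), 0.01)
        # each gate is Conv3D over a state tensor plus Dense over the image code
        def gate(conv, dense, state):
            return conv(state) + tf.reshape(dense(x), [-1, 4, 4, 4, 128])
        u = tf.sigmoid(gate(self.update_conv, self.update_dense, h))
        r = tf.sigmoid(gate(self.reset_conv, self.reset_dense, h))
        c = tf.tanh(gate(self.cand_conv, self.cand_dense, r * h))
        h_new = u * h + (1.0 - u) * c
        return h_new, [h_new]

# The timesteps axis is None, so different batches can have different lengths
# (samples within one batch would still need padding or masking).
seq_in = layers.Input(shape=(None, dim1, dim2, dim3))
gru_out = layers.RNN(ConvGRUCell())(seq_in)  # returns only the final output by default
model = tf.keras.Model(seq_in, gru_out)

If something like this works, the single RNN layer would build the step graph once instead of unrolling it T times, which should also address the memory issue.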