keras中的TimeDistributed包装器如何工作?

时间:2018-09-13 16:10:12

标签: python keras rnn

根据keras documentation

  

此包装器将一层应用于输入的每个时间片。

所以我的解释是用伪代码编写的,它看起来像这样:如果我有一个形状为 [batchsize,timesteps,width,height,channel]的数组 input < / em>和功能 conv 的TimeDistributed包装器将首先将输入转换为 [时间步长,批大小,宽度,高度,通道] 的形状然后:

l = []
for elem in input:
    l.append(conv(elem))
return l

给出以下示例,其中 k 是内核矩阵, b 是偏差矩阵,而 i 是输入序列:

import os
import numpy as np
from keras.layers.convolutional import Conv2D
from keras.models import Model
from keras.layers import Input, TimeDistributed

i = np.random.rand(1,16,92,92,128)
k = np.random.rand(3,3,128,256)
b = np.random.rand(256,)

input1 = Input(shape=(None, None, None, 128))
conv1 = TimeDistributed(Conv2D(256, (3, 3), activation='relu', padding="same"), name="conv")(input1)
model1 = Model(inputs=input1, outputs=conv1)
model1.get_layer("conv").layer.set_weights([k,b])
model1.compile(loss='mean_squared_error',
              optimizer='adam',
              metrics=['accuracy'])
res1 = model1.predict(i)

input2 = Input(shape=(None, None, 128))
conv2 = Conv2D(256, (3, 3), activation='relu', padding="same", name="conv")(input2)
model2 = Model(inputs=input2, outputs=conv2)
model2.get_layer("conv").set_weights([k,b])
model2.compile(loss='mean_squared_error',
              optimizer='adam',
              metrics=['accuracy'])
res2 = model2.predict(np.asarray([i[0,-1]]))

我希望:

np.all(res1[0,-1] == res2[0])

为True,但这不是我的测试用例。为什么?

编辑

经过更多的运行,我才意识到,这似乎是一个随机因素:

有时np.all(res1[0,-1] == res2[0])为True,有时为False。

我遇到的另一个问题:

我正在使用的权重来自具有相同TimeDistributed(Conv2D(256, (3, 3), activation='relu', padding="same"), name="conv")层的另一个网络。所以当我服用:

layer_model = Model(inputs = model.input,
                    outputs = model.get_layer("conv").input)
output = layer_model.predict(np.expand_dims(np.asarray(input), axis=0))
np.save("input", output)

作为上述示例的输入,并将res1与

进行比较
layer_model = Model(inputs = model.input,
                    outputs = model.get_layer("conv").output)
output = layer_model.predict(np.expand_dims(np.asarray(input), axis=0))
np.save("output", output)

我得到不同的结果

0 个答案:

没有答案