如何使用张量流后端在Keras的下一层重用最后一层的偏差

时间:2018-08-03 16:11:16

标签: keras

我是Keras的新手 我的神经网络结构在这里: neural network structure

我的想法是:

import keras.backend as KBack
import tensorflow as tf

#...some code here

model = Sequential()
hidden_units = 4
layer1 = Dense(
    hidden_units,
    input_dim=len(InputIndex),
    activation='sigmoid'
)
model.add(layer1)
# layer1_bias = layer1.get_weights()[1][0]

layer2 = Dense(
    1, activation='sigmoid',
    use_bias=False
)
model.add(layer2)
# KBack.bias_add(model.output, layer1_bias[0])

我知道这是行不通的,因为layer1_bias [0]不是张量,但是我不知道如何解决它。或有人有其他解决方案。

谢谢。

1 个答案:

答案 0 :(得分:0)

由于bias_add期望Tensor并向其传递一个浮点数(偏差的实际值),因此您得到错误。另外,请注意,您的隐藏层实际上有3个偏差(每个节点一个)。如果要将第一个节点的偏差添加到输出层,这应该可以工作:

import keras.backend as K
from keras.layers import Dense, Activation
from keras.models import Sequential

model = Sequential()

layer1 = Dense(3, input_dim=2, activation='sigmoid')
layer2 = Dense(1, activation=None, use_bias=False)
activation = Activation('sigmoid')

model.add(layer1)
model.add(layer2)
K.bias_add(model.output, layer1.bias[0:1]) # slice like this to not lose a dimension
model.add(activation)

print(model.summary())

请注意,要“正确”(根据致密层的定义),您应该先添加偏差,然后再添加激活。

此外,您的代码与网络情况并不完全一致。在图中,一个单一的共享偏差被添加到网络中的每个节点。您可以使用功能性API来执行此操作。这个想法是在隐藏层和输出层中禁用偏见的使用,并手动添加一个您自己定义的偏见变量,该变量将由各层共享。我正在为tf.add()使用tensorflow,因为它支持广播:

from keras.layers import Dense, Lambda, Input, Add
from keras.models import Model
import keras.backend as K
import tensorflow as tf

# Define the shared bias as a custom keras variable
shared_bias = K.variable(value=[0], name='shared_bias')

input_layer = Input(shape=(2,))
# Disable biases in the hidden layer
dense_1 = Dense(units=3, use_bias=False, activation=None)(input_layer)
# Manually add the shared bias
dense_1 = Lambda(lambda x: tf.add(x, shared_bias))(dense_1)
# Disable bias in output layer
output_layer = Dense(units=1, use_bias=False)(dense_1)
# Manually add the bias variable
output_layer = Lambda(lambda x: tf.add(x, shared_bias))(output_layer)

model = Model(inputs=input_layer, outputs=output_layer)
print(model.summary())

这假设您的共享偏见是不可训练的。