在Keras中仅训练一个网络输出

时间:2016-11-06 06:01:11

标签: neural-network theano keras q-learning

我在Keras有一个有很多输出的网络,但是,我的训练数据一次只提供单个输出的信息。

目前,我的培训方法是对相关输入运行预测,更改我正在训练的特定输出的值,然后进行单个批量更新。如果我正确,这与将所有输出的损失设置为零相同,除了我试图训练的那个。

有更好的方法吗?我已经尝试过班级重量,除了输出训练外我为所有人设定了零重量,但是它没有给我预期的结果吗?

我正在使用Theano后端。

2 个答案:

答案 0 :(得分:0)

为了达到这个目的,我最终使用了'Functional API'。您基本上创建了多个模型,使用相同的图层输入和隐藏图层但不同的输出图层。

例如:

https://keras.io/getting-started/functional-api-guide/

from keras.layers import Input, Dense
from keras.models import Model

# This returns a tensor
inputs = Input(shape=(784,))

# a layer instance is callable on a tensor, and returns a tensor
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions_A = Dense(1, activation='softmax')(x)
predictions_B = Dense(1, activation='softmax')(x)

# This creates a model that includes
# the Input layer and three Dense layers
modelA = Model(inputs=inputs, outputs=predictions_A)
modelA.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
modelB = Model(inputs=inputs, outputs=predictions_B)
modelB.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

答案 1 :(得分:0)

输出多个结果并仅对其中之一进行优化

比方说,您想从多层(也许从某些中间层)返回输出,但是您只需要优化一个目标输出即可。这是您的操作方法:

让我们从这个模型开始吧:

inputs = Input(shape=(784,))
x = Dense(64, activation='relu')(inputs)

# you want to extract these values
useful_info = Dense(32, activation='relu', name='useful_info')(x)

# final output. used for loss calculation and optimization
result = Dense(1, activation='softmax', name='result')(useful_info)

编译多个输出,将 extra 输出的损耗设置为None

None用于您不想用于损失计算和优化的输出

model = Model(inputs=inputs, outputs=[result, useful_info])
model.compile(optimizer='rmsprop',
              loss=['categorical_crossentropy', None],
              metrics=['accuracy'])

训练时仅提供 target 输出。跳过 extra 输出:

model.fit(my_inputs, {'result': train_labels}, epochs=.., batch_size=...)

# this also works:
#model.fit(my_inputs, [train_labels], epochs=.., batch_size=...)

一个预测将其全部获取

只有一个模型,您只能运行一次predict才能获得所需的所有输出:

predicted_labels, useful_info = model.predict(new_x)
相关问题