如何在Keras中将回归输出限制在0到1之间

时间:2018-08-04 08:39:47

标签: python tensorflow machine-learning keras conv-neural-network

我正在尝试检测图像中单个对象的单个像素位置。我有一个keras CNN回归网络,其图像张量作为输入,而3项向量作为输出。

第一项:是1(如果找到对象)还是0(未找到对象)

第二项:是介于0和1之间的数字,表示对象沿x轴的距离

第三项:是介于0和1之间的数字,表示对象沿y轴的距离

我已经在2000个测试图像和500个验证图像上训练了网络,并且val_loss远远小于1,并且val_acc最好在0.94左右。很棒。

但是当我预测输出时,我发现所有三个输出项的值都不在0和1之间,它们实际上大约在-2和3之间。这三个项目都应在0到1之间。

我没有在输出层上使用任何非线性激活函数,并且对所有非输出层都使用了relus。即使它是非线性的,我也应该使用softmax吗?第二和第三项预测图像的x和y轴,在我看来,它们是线性量。

这是我的keras网络:

inputs = Input((256, 256, 1))

base_kernels = 64 

 # 256
conv1 = Conv2D(base_kernels, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
conv1 = BatchNormalization()(conv1)
conv1 = Conv2D(base_kernels, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
conv1 = BatchNormalization()(conv1)
conv1 = Dropout(0.2)(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

# 128
conv2 = Conv2D(base_kernels * 2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
conv2 = BatchNormalization()(conv2)
conv2 = Conv2D(base_kernels * 2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = BatchNormalization()(conv2)
conv2 = Dropout(0.2)(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

# 64
conv3 = Conv2D(base_kernels * 4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
conv3 = BatchNormalization()(conv3)
conv3 = Conv2D(base_kernels * 4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = BatchNormalization()(conv3)
conv3 = Dropout(0.2)(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

flat = Flatten()(pool3)

dense = Dense(256, activation='relu')(flat)
output = Dense(3)(dense)

model = Model(inputs=[inputs], outputs=[output])

optimizer = Adam(lr=1e-4)
model.compile(optimizer=optimizer, loss='mean_absolute_error', metrics=['accuracy'])

有人可以帮忙吗?谢谢! :) 克里斯

1 个答案:

答案 0 :(得分:4)

S型激活会产生零到一之间的输出,因此,如果将它用作最后一层的激活(输出),则网络的输出将介于零到一之间。

  @result = HTTParty.post(
    'https://test.com/search', 
    :body => [...]