如何格式化目标以训练Keras CNN模型?

时间:2018-04-27 03:22:35

标签: python tensorflow computer-vision keras

我的数据集是鲸鱼的图像。我正在尝试训练一个可以在给定图像中找到2个鲸鱼点的CNN。我的训练特征是numpy数组中的图像,目标是图像中2点的x和y坐标(鲸鱼上的2个点)。

使用Keras制作神经网络的最佳方法是什么,可以从我拥有的数据集中学习,以便能够在新的未标记图像上找到这些点?

目前我的主要问题是弄清楚如何格式化目标(图像中的2个点),以便我的Keras模型可以理解/读取数据。

我破坏的代码是:

x_train = np.array([cv2.imread("1small.jpg")])
y_train = np.array([14.1, 13.5, 16.3, 14.1])

x_test = np.array([cv2.imread("0small.jpg")])
y_test = np.array([11.8, 10.8, 17.0, 16.0]) # fake data just to test

model = Sequential()
model.add(Dense(1,32,32,3))
model.add(Activation('tanh'))
model.add(Dense(1))
model.compile(loss='mean_absolute_error', optimizer='rmsprop')

model.fit(x_train, y_train, nb_epoch=1, batch_size=1)

prediction = model.predict(x_test)
print prediction

1 个答案:

答案 0 :(得分:0)

这些网络非常适合学习坐标,因此您可以将数据保留为四个坐标。然后,您可以将损失定义为距离平方的总和。关于这一点的第一件好事是,它会更多地处理大错误,而较小的错误会更少,所以它会让网络顺利完成。 :)第二件好事是,如果你考虑一下,两点的平方距离是坐标差异的平方和。所以基本上你可以使用坐标数组上的均方误差作为该网络的损失。很好,很容易。

查看您的网络,您应该在前面添加一些卷积和池化层,密集层不是第一个看到图像的好图层。他们应该走向网络的顶端。

此外,tanh激活是过去十年。 :)使用relu,但最后一层除外,它应该没有激活。

还要考虑变换坐标,使图像的中间是原点,边缘大约为-1和1.这会使预测空间标准化,可能会有所帮助。

希望这有帮助。