如何格式化Keras上的训练输入和输出数据

时间:2017-04-10 16:40:46

标签: python numpy keras heatmap

我是深度学习的新手,我在Keras上遇到了一些数据格式。我的CNN基于A.Newell等人的Stacked Hourglass Networks for Human Pose Estimation

在此网络上,输入为256x256 RGB图像,输出应为64x64热图,突出显示身体关节(肩部,膝部......)。我设法建立网络,我有所有数据(图像)及其注释(身体关节的像素标签)。我想知道如何格式化训练集的输入和输出数据来训练我的模型。目前我使用numpy数组(256,256,3)作为图像,我不知道如何格式化我的输出。我应该创建一个表[n,64,64,7]吗? (n是训练集的大小,7是我用来获得7个关节热图的滤波器数量)

感谢您的时间。

2 个答案:

答案 0 :(得分:0)

输出也可以是一个numpy数组。 考虑这个例子: 训练集:50张图片,大小256x256x3。这可以组合成单个numpy形状的阵列(50,256,256,3)。 类似的格式化输出数据的方法。 示例代码如下:

    #a, b and c are arrays of size 256x256x3
    import numpy as np

    temp = []
    temp.append(a)
    temp.append(b)
    temp.append(c)
    output_labels = []
    output_labels = np.stack(temp)

output_labels数组的形状(3x256x256x3)。

答案 1 :(得分:0)

Keras建议创建数据生成器,以将训练数据和地面真实情况馈入网络。 特定于堆叠式沙漏网络案例,您可以参考我的实现以获取详细信息https://github.com/yuanyuanli85/Stacked_Hourglass_Network_Keras/tree/master/src/data_gen