使用tf.data使用我自己的图像进行自定义训练

时间:2020-05-11 07:09:16

标签: python tensorflow machine-learning keras

我是tensorflow的新手,我无法将自定义数据提供给keras模型。

我已遵循本指南:Load images将我的.jpg文件转换为tf.data。

现在,我将数据转换为(image_batch,label_batch)。 image_batch是形状为(32,224,224,3)的EagerTensor,label_batch是形状为(32,2)的EagerTensor。

然后我找到了本指南:Custom training: walkthrough,但是行会中的数据被转换为形状(32,4)的EagerTensor。

我在执行代码时得到警告:

model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(3,)),  # input shape required
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(3)
])
predictions = model(image_batch)
WARNING:tensorflow:Model was constructed with shape (None, 3) for input Tensor("dense_input:0", shape=(None, 3), dtype=float32), but it was called on an input with incompatible shape (32, 224, 224, 3).

我应该如何调整模型或如何处理数据?

编辑:

模型现在可以使用了,但是还有一个问题。

当我运行以下代码时:

print("Prediction: {}".format(tf.argmax(predictions, axis=1)))
print("    Labels: {}".format(labels_batch))

它打印:

Prediction: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
    Labels: [[ True False]
 [False  True]
 [ True False]
 [False  True]
 [ True False]...(omitted)]

但是我希望它打印出类似这样的内容:

Prediction: [0 1 0 1 1 1 0 1 0 1 1 0 0 0 0 0 1 1 0 1 0 0 1 0 0 0 0 1 0 0 1 0]
    Labels: [2 0 2 0 0 0 1 0 2 0 0 1 1 2 2 2 1 0 1 0 1 2 0 1 1 1 1 0 2 2 0 2]

带有标签的整数一维数组。

我想知道预测是否全部为1是否正常?我该怎么办?

2 个答案:

答案 0 :(得分:0)

您输入的是32个形状为(224,224,3)而不是(3,)的图像。您的输入形状必须为(224,224,3)。

我还注意到您的输出形状看起来也将是(224,224,3),这与您的标签不匹配。您需要在某个时候展平数据或执行类似操作。

let jsonObject = {};  
selectedLanguageDetails.forEach((value, key) => {  
       jsonObject[key] = value  
});  
string jsonString = JSON.stringify(jsonObject);

答案 1 :(得分:0)

Danse图层的输入形状应具有尺寸(None,n),其中None是batch_size。对于您的情况,如果要使用密集层,则应首先使用Flatten层,将图像滚动成(32, 224 * 224 * 3)的形状。该代码应为:

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(), 
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(3)
])

有关更多详细信息,请参见https://www.tensorflow.org/api_docs/python/tf/keras/layers/Flatten