如何在Tensorflow中使用没有类的预训练模型?

时间:2018-06-01 15:10:05

标签: python tensorflow keras resnet

我尝试使用预训练网络,例如tf.keras.applications.ResNet50,但我有两个问题:

我只想获得网络末端的顶层嵌入层,因为我不想做任何图像分类。所以由于这个原因,我认为不需要课程编号。

  • tf.keras.applications.ResNet50采用默认参数'classes=1000'

    • 有没有办法省略这个参数?
  • 我的输入图片为128*128*1像素,而非224*224*3

    • 修复输入数据形状的最佳方法是什么?

我的目标是使用resnet网络的输出建立三元组丢失网络。

非常感谢!

2 个答案:

答案 0 :(得分:1)

  • ResNet50完全具有参数include_top - 将其设置为False以跳过上一个完全连接的图层。 (然后输出长度为2048的特征向量)。
  • 缩小图片尺寸的最佳方法是重新取样图片,例如使用专用函数tf.image.resample_images

    • 另外,我最初没有注意到你的输入图像只有三个频道,thx @Daniel。我建议你在GPU上构建你的3通道灰度图像(而不是使用numpy在主机上),以避免使用tf.tile将数据传输增加到GPU内存的三倍:

      im3 = tf.tile(im, (1, 1, 1, 3))
      

答案 1 :(得分:1)

作为对方答案的补充。您还需要使您的图像有三个通道,虽然技术上不是Resnet的最佳输入,但它是最简单的解决方案(如果您访问源代码并自行更改输入形状,也可以选择更改Resnet模型)。

使用numpy在三个频道中打包图像:

images3ch = np.concatenate([images,images,images], axis=-1)