如何在Tensorflow中将预训练的网络用作图层?

时间:2019-05-20 12:04:05

标签: tensorflow transfer resnet

我想使用要素提取器(例如ResNet101),然后在其后添加使用要素提取器层输出的图层。但是,我似乎不知道如何。我只在网上找到了解决方案,其中使用了整个网络,而没有添加其他层。 我对Tensorflow没有经验。

在下面的代码中,您可以看到我尝试过的内容。我可以在没有附加卷积层的情况下正确运行代码,但是我的目标是在ResNet之后添加更多层。 通过尝试添加额外的conv层,将返回此类型错误: TypeError:预期为float32,得到了OrderedDict([('resnet_v1_101 / conv1',...

一旦我添加了更多的层,我想开始在一个很小的测试集上进行训练,以查看我的模型是否可以过拟合。


import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
import matplotlib.pyplot as plt

numclasses = 17

from google.colab import drive
drive.mount('/content/gdrive')

def decode_text(filename):
  img = tf.io.decode_jpeg(tf.io.read_file(filename))
  img = tf.image.resize_bilinear(tf.expand_dims(img, 0), [224, 224])
  img = tf.squeeze(img, 0)
  img.set_shape((None, None, 3))
  return img

dataset = tf.data.TextLineDataset(tf.cast('gdrive/My Drive/5LSM0collab/filenames.txt', tf.string))
dataset = dataset.map(decode_text)
dataset = dataset.batch(2, drop_remainder=True)

img_1 = dataset.make_one_shot_iterator().get_next()
net = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)


sess = tf.Session()

global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()
sess.run(global_init)
sess.run(local_init)
img_out, conv_out = sess.run((img_1, net))

1 个答案:

答案 0 :(得分:0)

resnet_v1.resnet_v1_101不仅仅返回net,而是返回一个元组net, end_points。第二个元素是字典,这大概是您收到此特定错误消息的原因。

对于documentation of this function

  

返回:

     

net:大小为[batch,height_out,width_out,   channels_out]。如果global_pool为False,   然后将height_out和width_out减少一个         与各自的height_in和width_in相比,output_stride的系数,         否则height_out和width_out等于1。如果num_classes为0或无,         然后net是最后一个ResNet块的输出,可能是在全局之后         平均池化。如果num_classes为非零整数,则net包含         softmax之前的激活。

     

end_points:网络组件到相应组件的字典         激活。

因此您可以编写例如:

net, _ = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)

您还可以选择一个中间层,例如:

_, end_points = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(end_points["main_Scope/resnet_v1_101/block3"], numclasses, 1)

(您可以查看end_points以查找端点的名称。您的作用域名称将不同于main_Scope。)