我想使用要素提取器(例如ResNet101),然后在其后添加使用要素提取器层输出的图层。但是,我似乎不知道如何。我只在网上找到了解决方案,其中使用了整个网络,而没有添加其他层。 我对Tensorflow没有经验。
在下面的代码中,您可以看到我尝试过的内容。我可以在没有附加卷积层的情况下正确运行代码,但是我的目标是在ResNet之后添加更多层。 通过尝试添加额外的conv层,将返回此类型错误: TypeError:预期为float32,得到了OrderedDict([('resnet_v1_101 / conv1',...
一旦我添加了更多的层,我想开始在一个很小的测试集上进行训练,以查看我的模型是否可以过拟合。
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
import matplotlib.pyplot as plt
numclasses = 17
from google.colab import drive
drive.mount('/content/gdrive')
def decode_text(filename):
img = tf.io.decode_jpeg(tf.io.read_file(filename))
img = tf.image.resize_bilinear(tf.expand_dims(img, 0), [224, 224])
img = tf.squeeze(img, 0)
img.set_shape((None, None, 3))
return img
dataset = tf.data.TextLineDataset(tf.cast('gdrive/My Drive/5LSM0collab/filenames.txt', tf.string))
dataset = dataset.map(decode_text)
dataset = dataset.batch(2, drop_remainder=True)
img_1 = dataset.make_one_shot_iterator().get_next()
net = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8)
net = slim.conv2d(net, numclasses, 1)
sess = tf.Session()
global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()
sess.run(global_init)
sess.run(local_init)
img_out, conv_out = sess.run((img_1, net))
答案 0 :(得分:0)
resnet_v1.resnet_v1_101
不仅仅返回net
,而是返回一个元组net, end_points
。第二个元素是字典,这大概是您收到此特定错误消息的原因。
对于documentation of this function:
返回:
net:大小为[batch,height_out,width_out, channels_out]。如果global_pool为False, 然后将height_out和width_out减少一个 与各自的height_in和width_in相比,output_stride的系数, 否则height_out和width_out等于1。如果num_classes为0或无, 然后net是最后一个ResNet块的输出,可能是在全局之后 平均池化。如果num_classes为非零整数,则net包含 softmax之前的激活。
end_points:网络组件到相应组件的字典 激活。
因此您可以编写例如:
net, _ = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8)
net = slim.conv2d(net, numclasses, 1)
您还可以选择一个中间层,例如:
_, end_points = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8)
net = slim.conv2d(end_points["main_Scope/resnet_v1_101/block3"], numclasses, 1)
(您可以查看end_points以查找端点的名称。您的作用域名称将不同于main_Scope。)