我试图通过在网络的第一卷积层上附加一个辅助分类器,然后过滤掉出现在正类别中的概率小于某个阈值的样本,来改善模型的推理时间。
我在Keras中尝试通过将一个模型分成两个模型来进行此操作,但是与不使用过滤功能的等效单个模型相比,拆分模型的推理速度较慢。
X = tf.placeholder(tf.float32, shape = (None,150,150,3))
y = tf.placeholder(tf.int64, shape=(None))
conv1 = tf.layers.conv2d(
inputs = X,
filters = 16,
kernel_size = 3)
pool1 = tf.layers.max_pooling2d(
inputs = conv1,
pool_size = 2,
strides = 2)
aux_in = tf.layers.flatten(pool1)
aux_in = tf.layers.dense(inputs = aux_in1, units = 1)
aux_out = tf.add(aux_in, 0.3)
aux_out = tf.clip_by_value(aux_out, 0, 1)
mask = tf.ones_like(aux_out)
sliced = tf.boolean_mask(pool1, mask)
conv2 = tf.layers.conv2d(
inputs = sliced,
filters = 16,
kernel_size = 3)
pool2 = tf.layers.max_pooling2d(
inputs = conv2,
pool_size = 2,
strides = 2)
pool2_flat = tf.layers.flatten(pool2)
dense = tf.layers.dense(
pool2_flat,
units = 32,
activation = tf.nn.relu)
logits = tf.layers.dense(inputs = dense, units=1)
这会导致以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-7-3b8c62bb4df4> in <module>()
21
22 mask = tf.ones_like(aux_out)
---> 23 sliced = tf.boolean_mask(pool1, mask)
24
25 conv2 = tf.layers.conv2d(
/home/greg/.local/share/canopy/edm/envs/User/lib/python3.5/site-
packages/tensorflow/python/ops/array_ops.py in boolean_mask(tensor, mask, name, axis)
1320 " are None. E.g. shape=[None] is ok, but shape=None is not.")
1321 axis = 0 if axis is None else axis
-> 1322 shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)
1323
1324 leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0])
/home/greg/.local/share/canopy/edm/envs/User/lib/python3.5/site-packages/tensorflow/
python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
1021 """
1022 if not self.is_compatible_with(other):
-> 1023 raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1024
1025 def most_specific_compatible_shape(self, other):
ValueError: Shapes (?, 74) and (?, 1) are incompatible
假设进入第一卷积层的数据是[image1,image2,image3,image4],我想过滤掉辅助分类器说不太可能包含正类的图像。这应该导致类似[image1,image3,image4]。