滤除层之间的特定样本

时间:2019-04-12 09:34:06

标签: tensorflow

我试图通过在网络的第一卷积层上附加一个辅助分类器,然后过滤掉出现在正类别中的概率小于某个阈值的样本,来改善模型的推理时间。

我在Keras中尝试通过将一个模型分成两个模型来进行此操作,但是与不使用过滤功能的等效单个模型相比,拆分模型的推理速度较慢。

X = tf.placeholder(tf.float32, shape = (None,150,150,3))
y = tf.placeholder(tf.int64, shape=(None))

conv1 = tf.layers.conv2d(
        inputs = X,
        filters = 16,
        kernel_size = 3)

pool1 = tf.layers.max_pooling2d(
        inputs = conv1,
        pool_size = 2,
        strides = 2)

aux_in = tf.layers.flatten(pool1)
aux_in = tf.layers.dense(inputs = aux_in1, units = 1)

aux_out = tf.add(aux_in, 0.3)
aux_out = tf.clip_by_value(aux_out, 0, 1)

mask = tf.ones_like(aux_out)
sliced = tf.boolean_mask(pool1, mask)

conv2 = tf.layers.conv2d(
        inputs = sliced,
        filters = 16,
        kernel_size = 3)

pool2 = tf.layers.max_pooling2d(
        inputs = conv2,
        pool_size = 2,
        strides = 2)

pool2_flat = tf.layers.flatten(pool2)

dense = tf.layers.dense(
        pool2_flat,
        units = 32,
        activation = tf.nn.relu)

logits = tf.layers.dense(inputs = dense, units=1)

这会导致以下错误:

---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-7-3b8c62bb4df4> in <module>()

     21 

     22 mask = tf.ones_like(aux_out)

---> 23 sliced = tf.boolean_mask(pool1, mask)

     24 

     25 conv2 = tf.layers.conv2d(


/home/greg/.local/share/canopy/edm/envs/User/lib/python3.5/site-
packages/tensorflow/python/ops/array_ops.py in boolean_mask(tensor, mask, name, axis)

   1320           " are None.  E.g. shape=[None] is ok, but shape=None is not.")

   1321     axis = 0 if axis is None else axis

-> 1322     shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)

   1323 

   1324     leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0])


/home/greg/.local/share/canopy/edm/envs/User/lib/python3.5/site-packages/tensorflow/
python/framework/tensor_shape.py in assert_is_compatible_with(self, other)

   1021     """

   1022     if not self.is_compatible_with(other):

-> 1023       raise ValueError("Shapes %s and %s are incompatible" % (self, other))

   1024 

   1025   def most_specific_compatible_shape(self, other):


ValueError: Shapes (?, 74) and (?, 1) are incompatible

假设进入第一卷积层的数据是[image1,image2,image3,image4],我想过滤掉辅助分类器说不太可能包含正类的图像。这应该导致类似[image1,image3,image4]。

0 个答案:

没有答案