如何在TensorFlow中实现二值化器层?

时间:2016-10-05 07:46:30

标签: tensorflow

我正在尝试在paper的第4页中实现二进制文件。功能并不太难。就是这样:

enter image description here

此功能无需反向传播渐变。我想在TensorFlow中做这件事。有两种方法可以解决这个问题:

  1. 使用TensorFlow在C ++中实现它。但是,instructions对我来说还不太清楚。如果有人能引导我通过它会很棒。我不清楚的一件事是为什么ZeroOutOp的渐变在Python中实现?
  2. 我决定采用纯Python方法。
  3. 以下是代码:

    import tensorflow as tf
    import numpy as np
    
    def py_func(func, inp, out_type, grad):
        grad_name = "BinarizerGradients_Schin"
        tf.RegisterGradient(grad_name)(grad)
        g = tf.get_default_graph()
        with g.gradient_override_map({"PyFunc": grad_name}):
            return tf.py_func(func, inp, out_type)
    
    '''
    This is a hackish implementation to speed things up. Doesn't directly follow the formula.
    '''
    def _binarizer(x):
        probability_matrix = (x + 1) / float(2)
        probability_matrix = np.matrix.round(probability_matrix, decimals=0)
        np.putmask(probability_matrix, probability_matrix==0.0, -1.0)
        return probability_matrix
    
    def binarizer(x):
        return py_func(_binarizer, [x], [tf.float32], _BinarizerNoOp)
    
    def _BinarizerNoOp(op, grad):
        return grad
    

    问题发生在这里。输入是32x32x3 CIFAR图像,它们在最后一层减少到4x4x64。我的最后一层的形状是(?,4,4,64),在哪里?是批量大小。通过调用:

    完成此操作后
    binarized = binarizer.binarizer(h_pool3)
    h_deconv1 = tf.nn.conv2d_transpose(h_pool3, W_deconv1, output_shape=[batch_size, img_height/4, img_width/4, 64], strides=[1,2,2,1], padding='SAME') + b_deconv1
    

    发生以下错误:

      

    ValueError:形状(4,4,64)和(?,4,4,64)不兼容

    我可以猜到为什么会这样。的?代表批量大小,在最后一层通过二值化器后,?维度似乎消失了。

1 个答案:

答案 0 :(得分:0)

我认为您可以按照in this answer所述进行操作。适用于我们的问题:

def binarizer(input):
    prob = tf.truediv(tf.add(1.0, input), 2.0)
    bernoulli = tf.contrib.distributions.Bernoulli(p=prob, dtype=tf.float32)
    return 2 * bernoulli.sample() - 1

然后,您在哪里设置网络:

W_h1, bias_h1 = ...
h1_before_bin = tf.nn.tanh(tf.matmul(x, W_h1) + bias_h1)

# The interesting bits:
t = tf.identity(h1_before_bin)
h1 = t + tf.stop_gradient(binarizer(h1_before_bin) - t)

但是,我不确定如何验证这是否有效......