如何调用客户tensorflow op。在Keras?

时间:2018-03-12 06:47:19

标签: tensorflow deep-learning keras

我有一位客户Tensorflow op。在C ++中编写并成功构建以在Tensorflow代码中调用

from libs.customer_op import customer_op
output = customer_op(x, filter=w, rates=[1, 1, rate, rate], padding="SAME", strides=[1, 1, stride, stride])

现在,我正在使用带有Tensorflow后端的Keras。是否可以在Keras中调用我的上述功能。我们需要做一些额外的注册步骤吗?

更新:感谢Matias Valdenegro的建议。我试过了。这是我在tensorflow中的完整代码以及我在Keras中所做的。 -Tensorflow代码

def my_conv(input,num_o,kernel_size, stride):
    num_x = input.shape[3].value
    offset = slim.conv2d(input, 18, [kernel_size, kernel_size], stride=stride, activation_fn=None, scope='offset', normalizer_fn=None)
    w = tf.get_variable('weights', shape=[num_o, num_x, kernel_size, kernel_size],
                    initializer=tf.contrib.layers.xavier_initializer())
    output = customer_conv(x, filter=w, offset=offset,padding="SAME")

-Keras代码:

def my_conv(input, num_o, kernel_size, stride):
   num_x = input.shape[3].value
   offset = KL.Conv2D(18, (kernel_size, kernel_size), strides=(stride,stride))(input)
   w = KI.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)
   output = Lambda(lambda x: deform_conv_op(x, filter=w, offset=offset, padding="SAME"))(input)
   return output

所以,这是我将调用函数的地方

class CustomerCNN():
    def __init__(self, mode):        
        self.mode = mode

    def build(self, mode):

        # Inputs
        input_image = KL.Input(
            shape=config.IMAGE_SHAPE.tolist(), name="input_image")       
        f1 = Lambda(lambda x: my_conv(x, 256, 3, 1))(input_image)

对于上述解决方案,我仍然是问题所在:

  1. 如何在Keras

  2. 中将形状初始为shape=[num_o, num_x, kernel_size, kernel_size]
  3. 如何在课程my_conv中致电我的客户CustomerCNN?我是否还需要一个Lambda函数

1 个答案:

答案 0 :(得分:2)

您可以使用lambda图层调用它:

output = Lambda(lambda x: customer_op(x, filter=w, rates=[1, 1, rate, rate],
                padding="SAME", strides=[1, 1, stride, stride]))(input)