我有一位客户Tensorflow op。在C ++中编写并成功构建以在Tensorflow代码中调用
from libs.customer_op import customer_op
output = customer_op(x, filter=w, rates=[1, 1, rate, rate], padding="SAME", strides=[1, 1, stride, stride])
现在,我正在使用带有Tensorflow后端的Keras。是否可以在Keras中调用我的上述功能。我们需要做一些额外的注册步骤吗?
更新:感谢Matias Valdenegro的建议。我试过了。这是我在tensorflow中的完整代码以及我在Keras中所做的。 -Tensorflow代码
def my_conv(input,num_o,kernel_size, stride):
num_x = input.shape[3].value
offset = slim.conv2d(input, 18, [kernel_size, kernel_size], stride=stride, activation_fn=None, scope='offset', normalizer_fn=None)
w = tf.get_variable('weights', shape=[num_o, num_x, kernel_size, kernel_size],
initializer=tf.contrib.layers.xavier_initializer())
output = customer_conv(x, filter=w, offset=offset,padding="SAME")
-Keras代码:
def my_conv(input, num_o, kernel_size, stride):
num_x = input.shape[3].value
offset = KL.Conv2D(18, (kernel_size, kernel_size), strides=(stride,stride))(input)
w = KI.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)
output = Lambda(lambda x: deform_conv_op(x, filter=w, offset=offset, padding="SAME"))(input)
return output
所以,这是我将调用函数的地方
class CustomerCNN():
def __init__(self, mode):
self.mode = mode
def build(self, mode):
# Inputs
input_image = KL.Input(
shape=config.IMAGE_SHAPE.tolist(), name="input_image")
f1 = Lambda(lambda x: my_conv(x, 256, 3, 1))(input_image)
对于上述解决方案,我仍然是问题所在:
如何在Keras
shape=[num_o, num_x, kernel_size, kernel_size]
如何在课程my_conv
中致电我的客户CustomerCNN
?我是否还需要一个Lambda函数
答案 0 :(得分:2)
您可以使用lambda图层调用它:
output = Lambda(lambda x: customer_op(x, filter=w, rates=[1, 1, rate, rate],
padding="SAME", strides=[1, 1, stride, stride]))(input)