How to use tf.py_func to run a complex function inside a Keras Lambda layer

Time: 2018-11-28 20:10:18

Tags: python tensorflow keras

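I'm writing a custom neural network in Keras. It has to pass its input to a function that computes the optical flow between the frames in a list of frames. When I wrap that code in TensorFlow's tf.py_func(), the function does not seem to be evaluated, and the result comes back as None. I can't find a simple fix, so any help would be appreciated. Below is the code I'm using. I'm new to TensorFlow, so I'd also like to know how to do this either directly in Keras, or in TensorFlow in a way that stays compatible with Keras: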

import numpy as np
import cv2
import tensorflow as tf
from keras import layers, models

# total_frames, num_bodyparts, frame_height and frame_width are defined elsewhere in my notebook.

def optical_flow(frame_list):
    # frame_list: float32 array of shape (total_frames, num_bodyparts, height, width)
    mid = len(frame_list) // 2
    base_frame = frame_list[mid]
    # np.delete returns a new array, so no explicit copy is needed
    new_list = np.delete(frame_list, mid, axis=0)

    def to_uint8(heatmap):
        # Farneback expects 8-bit single-channel images; assumes heatmap values lie in [0, 1]
        return np.clip(heatmap * 255.0, 0, 255).astype(np.uint8)

    def calculate_optical_flow(one, two):
        # Dense flow from `one` to `two`, returned as an (H, W, 2) float32 array
        flow = cv2.calcOpticalFlowFarneback(to_uint8(one),
                                            to_uint8(two),
                                            flow=None,
                                            pyr_scale=0.5, levels=1, winsize=15,
                                            iterations=2,
                                            poly_n=5, poly_sigma=1.1, flags=0)

        return flow

    def move_image_based_on_flow(prev_img, flow):
        # Generate a cartesian grid
        height, width = flow.shape[0], flow.shape[1]
        R2 = np.dstack(np.meshgrid(np.arange(width), np.arange(height)))

        # desired mapping is simply the addition of this grid with the flow
        pixel_map = R2 + flow
        pixel_map = pixel_map.astype(np.float32)
        # perform the remapping of each pixel in the original image
        warped_frame = cv2.remap(prev_img, pixel_map[:, :, 0], pixel_map[:, :, 1], cv2.INTER_LINEAR)
        return warped_frame

    warped_frames = []
    for frame in new_list:
        warped_bps_in_frame = []
        for bp_ind, bp_heatmap in enumerate(frame):
            # compare each body-part heatmap with the same body part in the middle (base) frame
            flow = calculate_optical_flow(base_frame[bp_ind], bp_heatmap)
            new_frame = move_image_based_on_flow(bp_heatmap, flow)
            warped_bps_in_frame.append(new_frame)
        warped_frames.append(warped_bps_in_frame)
    # keep float32: tf.py_func below declares tf.float32 as the output type
    warped_frames = np.asarray(warped_frames, dtype=np.float32)
    # put the base frame back at its original position
    warped_frames = np.insert(warped_frames, mid, base_frame, axis=0)

    return warped_frames

def optical_flow_batch(batch):
    # tf.py_func passes the whole batch: (batch_size, total_frames, num_bodyparts, height, width)
    return np.stack([optical_flow(sample) for sample in batch]).astype(np.float32)

def optical_flow_n(x):
    return tf.py_func(optical_flow_batch, [x], tf.float32)

def optical_flow_net(x):
    x = layers.Lambda(optical_flow_n, output_shape=(total_frames, num_bodyparts, frame_height, frame_width))(x)
    x = layers.Conv3D(1, kernel_size=(1,1,total_frames), padding='same')(x)
    return x

optical_flow_input = layers.Input(shape=(total_frames, num_bodyparts, frame_height, frame_width))
optical_flow_output = optical_flow_net(optical_flow_input)
model_ofn = models.Model(inputs=[optical_flow_input], outputs=[optical_flow_output])

Running the code above produces the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-10-1a6519e3ddb3> in <module>()
      1 optical_flow_input = layers.Input(shape=(total_frames, num_bodyparts, frame_height, frame_width))
----> 2 optical_flow_output = optical_flow_net(optical_flow_input)
      3 model_ofn = models.Model(inputs=[optical_flow_input], outputs=[optical_flow_output])

<ipython-input-9-d6cf34c238c9> in optical_flow_net(x)
     46 def optical_flow_net(x):
     47     x = layers.Lambda(optical_flow_n, output_shape=(total_frames, num_bodyparts, frame_height, frame_width))(x)
---> 48     x = layers.Conv3D(1, kernel_size=(1,1,total_frames), padding='same')(x)
     49     return x

~\AppData\Local\conda\conda\envs\DLCdependencies\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
    412                 # Raise exceptions in case the input is not compatible
    413                 # with the input_spec specified in the layer constructor.
--> 414                 self.assert_input_compatibility(inputs)
    415 
    416                 # Collect input shapes to build layer.

~\AppData\Local\conda\conda\envs\DLCdependencies\lib\site-packages\keras\engine\base_layer.py in assert_input_compatibility(self, inputs)
    309                                      self.name + ': expected ndim=' +
    310                                      str(spec.ndim) + ', found ndim=' +
--> 311                                      str(K.ndim(x)))
    312             if spec.max_ndim is not None:
    313                 ndim = K.ndim(x)

ValueError: Input 0 is incompatible with layer conv3d_1: expected ndim=5, found ndim=None
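A likely cause, offered as an editor's sketch rather than a confirmed answer: the tensor returned by tf.py_func has no static shape. The `output_shape` argument of the Lambda layer only updates Keras's own shape bookkeeping. Conv3D, however, checks the rank of the underlying tensor (`K.ndim(x)` in the traceback), and that rank is None. Calling `set_shape` on the wrapped function's result restores a known rank. The snippet below assumes TensorFlow 2.x, where `tf.numpy_function` plays the role of `tf.py_func`. `numpy_op` is a hypothetical stand-in for the optical-flow routine, and the example dimensions (3, 2, 4, 4) are arbitrary.

```python
import numpy as np
import tensorflow as tf

def numpy_op(batch):
    # Hypothetical stand-in for the NumPy/OpenCV optical-flow routine.
    # It receives the whole batch and must return the dtype declared in Tout (float32).
    return (batch * 2.0).astype(np.float32)

def wrapped_op(x):
    y = tf.numpy_function(numpy_op, [x], tf.float32)
    # numpy_function (like tf.py_func) cannot infer the output shape, so declare it;
    # here the output is assumed to have the same shape as the input.
    y.set_shape(x.shape)
    return y

# Unknown batch size, then example (total_frames, num_bodyparts, height, width)
spec = tf.TensorSpec([None, 3, 2, 4, 4], tf.float32)

@tf.function(input_signature=[spec])
def run(x):
    return wrapped_op(x)

# Static rank seen by graph-building code such as a downstream Conv3D layer
static_rank = run.get_concrete_function().structured_outputs.shape.rank
print(static_rank)  # 5
```

Without the `set_shape` line, `static_rank` would be None, which matches the `found ndim=None` in the error. In the question's TF 1.x / standalone Keras setup, the analogous change would be to call `set_shape(x.get_shape())` on the tensor returned by `tf.py_func` inside `optical_flow_n` before returning it. This has not been tested against that exact setup.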

0 answers:

No answers yet