如何在Tensorflow中使用多个输入来应用tf.map_fn

时间:2019-06-19 09:03:55

标签: python tensorflow deep-learning tensor turning-point

我尝试使用Pycharm中的Tensorflow将tf.map_fn转换为多个输入。
但是,当我尝试这样做时,
我收到错误消息:TypeError:testzz()缺少1个必需的位置参数:“ data”

如何解决此问题? 或如何获取idxCut的大小以使用for循环?

开发内容。

  • 在数据中找到与阈值相对应的索引(idxCut)。
  • 检查与idxCut对应的数据是否为TPR。

我想使用for循环在数据中找到关于idxCut的TPR(转折点比率)。
我使用了for循环来获取idx,idx-1和idx + 1之间的TPR。
我想发现data [idx]高于其他data [idx-1,idx + 1]。

    def testtt(data):
        ### Cut-off Threshold
        newData = data[5:num_input - 5]   # shape = [1, 100]
        idxCut = tf.where(newData > cutoff) + 5
        idxCut = tf.squeeze(idxCut)   
        # The size of idxCut is always variable. shape = [1, 10] or shape = [1, 27] or etc

        tq = tf.map_fn(testzz, (idxCut, data), dtype=tf.int32)
        print('tqqqq ', tq)
    def testzz(idxCut, data):
        v1 = tf.where(data[idxCut] > data[idxCut - 1], 1, 0)
        v2 = tf.where(data[idxCut] > data[idxCut + 1], 1, 0)
        return tf.where(v1 + v2 > 1, 1, 0)
Traceback (most recent call last):
  File "D:/PycharmProject/Test_DCGAN_BioSignal/test_xcorr_all.py", line 263, in <module>
    tprX = testtt(zX)
  File "D:/PycharmProject/Test_DCGAN_BioSignal/test_xcorr_all.py", line 149, in testtt
    tq = tf.map_fn(testzz, (idxCut, data), dtype=tf.int32)
  File "C:\Users\UserName\Anaconda3\envs\TSFW_pycharm\lib\site-packages\tensorflow\python\ops\functional_ops.py", line 494, in map_fn
    maximum_iterations=n)
  File "C:\Users\UserName\Anaconda3\envs\TSFW_pycharm\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3291, in while_loop
    return_same_structure)
  File "C:\Users\UserName\Anaconda3\envs\TSFW_pycharm\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3004, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "C:\Users\UserName\Anaconda3\envs\TSFW_pycharm\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2939, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "C:\Users\UserName\Anaconda3\envs\TSFW_pycharm\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3260, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "C:\Users\UserName\Anaconda3\envs\TSFW_pycharm\lib\site-packages\tensorflow\python\ops\functional_ops.py", line 483, in compute
    packed_fn_values = fn(packed_values)

TypeError: testzz() missing 1 required positional argument: 'data'

1 个答案:

答案 0 :(得分:0)

当您给tf.map_fn赋予多个张量时,它们的元素不作为独立参数传递给给定函数,而是作为元组传递。这样做:

def testzz(inputs):
    idxCut, data = inputs
    v1 = tf.where(data[idxCut] > data[idxCut - 1], 1, 0)
    v2 = tf.where(data[idxCut] > data[idxCut + 1], 1, 0)
    return tf.where(v1 + v2 > 1, 1, 0)