我正在尝试使用tf.py_func()将python函数包装到tensorflow中并获取一个我无法理解的InvalidArgumentError。 我正在通过两个二维张量,函数返回一个浮点值。
答案 0 :(得分:0)
如果没有distcorr() function
的代码,很难确定,但似乎正如错误所述,该函数返回double
/ float64
而您告诉tf.py_func()
期待float32
(参见tf.float32
参数)。
修改您的函数以在返回结果之前投射结果(例如your_result.astype(numpy.float32)
或将dtype
tf.py_func()
参数更改为tf.float64
。