使用TensorFlow的怪异铸造问题

时间:2019-02-13 09:04:04

标签: python tensorflow casting

当尝试从用户定义的函数输出要由TF包装的浮点数时,使用TF模块遇到奇怪的行为。

import tensorflow as tf
import numpy as np

def my_func(x):
  outX = 1.0
  return outX

input = tf.placeholder(tf.float32,shape=(2,))
resultFun = tf.py_func(my_func, [input], tf.float32)

with tf.Session() as sess:
  inival=[np.array(0.9),np.array(1.9)]
  print(sess.run(resultFun, feed_dict={input: inival})) 

问题出在那儿

outX=1.0

因为它生成以下内容:

  

InvalidArgumentError(请参阅上面的回溯):pyfunc_0返回的第0个值是双精度值,但期望为浮点数

我试图用np.float或float强制转换,但结果相同。 实际上,我还设法克服了一个难题:

outX = x*0+1.0

,并且有效。 因此,这显然取决于某种错误的预测。 如何在没有快捷方式的情况下解决投放问题?

1 个答案:

答案 0 :(得分:0)

通常,TensorFlow尝试使本机Python值适合所请求的数据类型,但有时会感到困惑。在这种情况下,由于Python本身不允许您指定该值是32位还是64位,因此您可以使用NumPy包装它:

import tensorflow as tf
import numpy as np

def my_func(x):
  outX = np.array(1.0, dtype=np.float32)
  return outX

input = tf.placeholder(tf.float32,shape=(2,))
resultFun = tf.py_func(my_func, [input], tf.float32)

with tf.Session() as sess:
  inival=[np.array(0.9),np.array(1.9)]
  print(sess.run(resultFun, feed_dict={input: inival})) 
  # 0.1