Tensorflow:带有numpy_input_fn

时间:2017-09-10 09:29:57

标签: numpy tensorflow typeerror

我正在编写一个卷积神经网络来对TensorFlow中的图像进行分类,但是存在一个问题:

当我尝试将NumPy平坦图像阵列(3个RGB值从0到255的通道)提供给tf.estimator.inputs.numpy_input_fn时,我收到以下错误:

  TypeError: Failed to convert object of type <class 'dict'> to Tensor. 
  Contents: {'x': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(8, 
  196608) dtype=uint8>}. Consider casting elements to a supported type.

我的numpy_imput_fn看起来像这样:

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': train_x},
    y=train_y,
    batch_size=8,
    num_epochs=None,
    shuffle=True)

在该函数的文档中,据说x应该是NumPy数组的字典:

  

x:numpy数组对象的字典。

1 个答案:

答案 0 :(得分:0)

没关系,对于那些遇到同样问题我修理过的人。在我的模型函数中,我有:

input_layer = tf.reshape(features, [-1, 256, 256, 1])

引发了类型错误。要解决此问题,您必须访问&#39; x&#39;功能字典中的键:

input_layer = tf.reshape(features['x'], [-1, 256, 256, 1])