TensorFlow Estimator: feature column size error in 1.5

Asked: 2018-02-14 06:48:50

Tags: python tensorflow

When I try to test my dataset with a premade estimator, I get this error:

Traceback (most recent call last):
  File "estimator.py", line 61, in <module>
    model.train(input_fn=train_input_fn, steps=20000)
  File "/home/sid/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 314, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/sid/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 743, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/home/sid/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 725, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/home/sid/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 324, in _model_fn
    config=config)
  File "/home/sid/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 153, in _dnn_model_fn
    'Given type: {}'.format(type(features)))
ValueError: features should be a dictionary of `Tensor`s. Given type: <class 'tensorflow.python.framework.ops.Tensor'>

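I found the same error reported online, but it appears to have been resolved in TensorFlow 1.5. However, I still get the same error while using TensorFlow 1.5 (CUDA 9.0 and the latest cuDNN). This suggests the problem most likely lies in how I create my dataset with the new Dataset API.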
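For readers who hit the same message, the check that fails can be approximated in plain Python. This is a simplified stand-in written for illustration, not TensorFlow's actual source: the names `check_features` and `FakeTensor` are hypothetical, and only the error message is taken from the traceback above.

```python
def check_features(features):
    """Hypothetical stand-in for the type check seen in the traceback."""
    # The canned DNN model function expects a mapping: feature name -> tensor.
    if not isinstance(features, dict):
        raise ValueError(
            'features should be a dictionary of `Tensor`s. '
            'Given type: {}'.format(type(features)))
    return features


class FakeTensor(object):
    """Placeholder for a real tensor, so the sketch runs without TensorFlow."""


bare = FakeTensor()
try:
    check_features(bare)  # a bare tensor is rejected
    rejected = False
except ValueError:
    rejected = True

accepted = check_features({'image': bare})  # a dict keyed by name passes
```

Under this reading, the error concerns the type of the object returned by the input function rather than the contents of the dataset itself.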

estimator.py:

import tensorflow as tf
import numpy as np
from input_pipe import tr_data, val_data

# Parameters
learning_rate = 1e-4
num_epochs = 2000
batch_size = 128
display_step = 100

# We know that the images are 128x128
img_size = 128

# Images are stored in one-dimensional arrays of this length.
img_size_flat = img_size * img_size

# Tuple with height and width of images used to reshape arrays.
img_shape = (img_size, img_size)

# Number of colour channels for the images: 3 for RGB (1 would be gray-scale).
num_channels = 3

# Number of classes, Dogs and Cats
num_classes = 2

# We import our Dataset objects from input_pipe because we need to be
# able to access them from within our input functions
train_dataset = tr_data
val_dataset = val_data

def train_input_fn():
        tr_data = train_dataset.batch(batch_size)
        tr_data = tr_data.repeat(num_epochs)
        # iterator = tr_data.make_one_shot_iterator()
        iterator = tf.data.Iterator.from_structure(tr_data.output_types, tr_data.output_shapes)

        features, labels = iterator.get_next()

        return features, labels

def test_input_fn():
        #iterator = val_dataset.make_one_shot_iterator()
        iterator = tf.data.Iterator.from_structure(tr_data.output_types, tr_data.output_shapes)

        features, labels = iterator.get_next()

        return features, labels


feature_x = tf.feature_column.numeric_column('tr_data.train_imgs',img_shape)


num_hidden_units = [512, 256, 128]

model = tf.estimator.DNNClassifier(feature_columns=[feature_x],
                hidden_units=num_hidden_units,
                activation_fn=tf.nn.relu,
                n_classes=num_classes,
                model_dir="./checkpoints")

model.train(input_fn=train_input_fn, steps=20000)

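input_pipe.py creates the TensorFlow datasets for the train and val images. I'm fairly sure my datasets work correctly, because I can initialize an iterator, iterate over the whole set, and print out the relevant data.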

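I think the problem is with the feature_columns argument of the canned estimator, but I'm not sure what I can do to fix it. I also tried returning iterator.get_next() directly from train_input_fn and test_input_fn, but that did not change the error.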

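Any help is greatly appreciated. The only estimator examples I have found online work with ready-made datasets, such as the TensorFlow MNIST one, rather than a dataset someone created from scratch with the Dataset API. Thanks!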

1 Answer:

Answer 0 (score: 1)

Allen Lavoie is correct. I changed my train_input_fn function to:

def train_input_fn():
    train_dataset = tr_data.batch(batch_size)
    # train_dataset = train_dataset.repeat()
    # train_dataset = train_dataset.shuffle(buffer_size=batch_size)

    iterator = train_dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    # The canned estimator expects features as a dict keyed by feature name.
    x = {'image': features}
    y = labels

    return x, y

It works. I needed to return a dictionary containing the images, using the same key that I used when instantiating the feature column:

feature_image = tf.feature_column.numeric_column('image', shape=[16384*3], dtype=tf.float32)
num_hidden_units = [512, 256, 128]
model = tf.estimator.DNNClassifier(feature_columns=[feature_image],
                hidden_units=num_hidden_units,
                activation_fn=tf.nn.relu,
                n_classes=num_classes,
                model_dir="./checkpoints")
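The same idea can be written as a small adapter around an existing input function. The sketch below is plain Python and rests on several assumptions: `with_feature_key` and `raw_input_fn` are hypothetical names, and strings stand in for the tensors that `iterator.get_next()` would return. It also checks the arithmetic behind `shape=[16384*3]`.

```python
def with_feature_key(input_fn, key):
    """Wrap input_fn so its features come back as {key: features}."""
    def wrapped():
        features, labels = input_fn()
        return {key: features}, labels
    return wrapped


def raw_input_fn():
    # Hypothetical stand-in: a real input_fn would return tensors here.
    return 'image_batch', 'label_batch'


FEATURE_KEY = 'image'  # must equal the key given to numeric_column
x, y = with_feature_key(raw_input_fn, FEATURE_KEY)()

# 128x128 images with 3 channels flatten to the size used in the answer.
img_size, num_channels = 128, 3
flat_size = img_size * img_size * num_channels
```

In a setup like the one above, the wrapped function would presumably be what gets passed as `input_fn` to `model.train`. Keeping the key in a single constant is one way to stop the dictionary key and the feature column key from drifting apart.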

Hope this helps someone!