IndexError: tuple index out of range CNN TensorFlow

Time: 2018-06-19 18:59:25

Tags: tensorflow training-data convolutional-neural-network tensorflow-estimator

I am working on a convolutional network in TensorFlow, but I get the following error:

IndexError: tuple index out of range

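This happens when I feed the data into the model for training. Below is the code of the main function. I think my error is in y (the labels). I think the problem is the format in which I get the labels from the input_pipeline function, but I don't know how to fix it.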

def main(unused_argv):

    images_batch,labels_batch=input_pipeline(train_path,batch_size,num_epochs)

    with tf.Session() as sess:

        # Initialize the variables
        init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)

        # Start the queue runners that were created in the computational graph
        tf.train.start_queue_runners(sess=sess)

        detector=tf.estimator.Estimator(model_fn=cnn_model,model_dir="/tmp/gun_cnn_detector")

        train_fn=tf.estimator.inputs.numpy_input_fn(
            x={"x":images_batch.eval()},
            y=labels_batch.eval(),
            batch_size=10,
            num_epochs=None,
            shuffle=True)
        detector.train(
            input_fn=train_fn,
            steps=2)

        writer = tf.summary.FileWriter('.')
        writer.add_graph(tf.get_default_graph())


def input_pipeline(filenames,batch_size,num_epochs):
    filename_queue=tf.train.string_input_producer([filenames],num_epochs=num_epochs,shuffle=True)
    images,labels=read_file(filename_queue)

    return images,labels
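As posted, input_pipeline returns the output of read_file unchanged, so each evaluation yields a single example: an image of shape [224, 224, 3] and a scalar label. numpy_input_fn expects arrays whose first axis indexes examples. The NumPy sketch below shows the difference in shapes; fake_example is a hypothetical stand-in for one decoded record, with zero-filled data. In the queue-based TF 1.x API, tf.train.batch and tf.train.shuffle_batch are the functions that perform this stacking on tensors.

```python
import numpy as np


def fake_example(i):
    """Hypothetical stand-in for one evaluation of read_file: one image, one scalar label."""
    image = np.zeros((224, 224, 3), dtype=np.float32)
    label = np.int32(i % 2)
    return image, label


# Unbatched: the shapes one .eval() call produces in the posted pipeline.
image, label = fake_example(0)
print(image.shape, label.shape)      # (224, 224, 3) ()

# Batched: stack N examples along a new leading axis (axis 0 = example index).
examples = [fake_example(i) for i in range(10)]
images = np.stack([img for img, _ in examples])
labels = np.array([lbl for _, lbl in examples], dtype=np.int32)
print(images.shape, labels.shape)    # (10, 224, 224, 3) (10,)
```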

I have tried many things. Below is the read function. I also have another question: when decoding the image, which format is correct, float32 or uint8?

def read_file(filename_queue):

    # Read the TFRecord file and return the next record
    reader=tf.TFRecordReader()
    _,serialized_example=reader.read(filename_queue)

    # Decode the TFRecord, which returns a dictionary of features
    feature={'train/image':tf.FixedLenFeature([],tf.string),
             'train/label':tf.FixedLenFeature([],tf.int64)}
    features=tf.parse_single_example(serialized_example,features=feature)

    # Convert the decoded feature string into numbers
    image=tf.decode_raw(features['train/image'],tf.float32)* (1 / 255.0)

    # Cast the label to int32
    label=tf.cast(features['train/label'],dtype=tf.int32)

    #Reshape data
    image=tf.reshape(image,[224,224,3]) 

    return image,label
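The script that wrote the TFRecord file is not shown, so the float32-versus-uint8 question can't be settled from this code alone. tf.decode_raw reinterprets the raw bytes as the requested out_type, much like NumPy's np.frombuffer, so the dtype passed to it has to match the dtype the bytes were written with. The sketch below assumes, hypothetically, that the writer stored uint8 pixels with ndarray.tobytes(). It shows that decoding such bytes as float32 yields a quarter of the expected element count, so the reshape to [224, 224, 3] could not succeed.

```python
import numpy as np

H, W, C = 224, 224, 3

# Assumption: the (unseen) TFRecord writer stored uint8 pixels via ndarray.tobytes().
rng = np.random.RandomState(0)
original = rng.randint(0, 256, size=(H, W, C)).astype(np.uint8)
raw = original.tobytes()                 # 224 * 224 * 3 = 150528 bytes

# Wrong dtype: every 4 bytes are reinterpreted as one float32 value.
as_float = np.frombuffer(raw, dtype=np.float32)
print(as_float.size)                     # 37632, too few elements for a (224, 224, 3) reshape

# Matching dtype: decode as uint8 first, then cast to float32 and scale to [0, 1].
pixels = np.frombuffer(raw, dtype=np.uint8).reshape(H, W, C)
image = pixels.astype(np.float32) * (1.0 / 255.0)
print(image.dtype, image.shape)          # float32 (224, 224, 3)
```

If the writer instead stored float32 arrays, decoding with tf.float32 is the matching choice. In that case, the 1 / 255.0 scaling only makes sense if those floats still hold values from 0 to 255.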

This is the error I get:

/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Traceback (most recent call last):
  File "/Users/David/Desktop/David/General/Tesis/Practica/Programas/CNN/CNN.py", line 113, in <module>
    tf.app.run()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "/Users/David/Desktop/David/General/Tesis/Practica/Programas/CNN/CNN.py", line 104, in main
    steps=2)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 843, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 853, in _train_model_default
    input_fn, model_fn_lib.ModeKeys.TRAIN))
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 691, in _get_features_and_labels_from_input_fn
    result = self._call_input_fn(input_fn, mode)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 798, in _call_input_fn
    return input_fn(**kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/inputs/numpy_io.py", line 175, in input_fn
    if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/inputs/numpy_io.py", line 175, in <genexpr>
    if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:
IndexError: tuple index out of range
[Finished in 3.9s with exit code 1]
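The last frame of the traceback is a consistency check inside numpy_input_fn: it reads v.shape[0] for every feature array and for the target. The sketch below is a simplified re-creation of that check, not TensorFlow's actual code, applied to values shaped like the ones the posted main() passes in. For a NumPy scalar such as np.int32(1), .shape is the empty tuple (), so indexing it with [0] raises exactly this IndexError.

```python
import numpy as np


def first_dims(x, y):
    """Simplified re-creation of the check at numpy_io.py line 175 in the traceback:
    collect the leading dimension of every feature array and of the target."""
    arrays = list(x.values()) + [y]
    return set(v.shape[0] for v in arrays)


# Shaped like the posted inputs: one unbatched image and one scalar label.
single_image = np.zeros((224, 224, 3), dtype=np.float32)
single_label = np.int32(1)               # single_label.shape == ()

try:
    first_dims({"x": single_image}, single_label)
except IndexError as err:
    print("IndexError:", err)            # IndexError: tuple index out of range

# With a leading batch axis on both, the check finds one common size.
batch_images = np.zeros((10, 224, 224, 3), dtype=np.float32)
batch_labels = np.zeros((10,), dtype=np.int32)
print(first_dims({"x": batch_images}, batch_labels))   # {10}
```

Even a one-element label array would not pass with the unbatched image, because that image's first dimension is 224, not 1.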

0 answers:

No answers