我正在张量流中的卷积网络上工作,但出现以下错误:
IndexError: tuple index out of range
当我将数据输入模型进行训练时。以下是main
功能代码。我认为我的错误是在y
(标签)中。我认为问题在于我可以从input_pipeline
函数中获取标签的格式,但是我不知道如何解决它。
def main(unused_argv):
images_batch,labels_batch=input_pipeline(train_path,batch_size,num_epochs)
with tf.Session() as sess:
#Inicializamos las variables
init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
#Corremos las filas(queue) que se crearon en el grafico computacional
tf.train.start_queue_runners(sess=sess)
detector=tf.estimator.Estimator(model_fn=cnn_model,model_dir="/tmp/gun_cnn_detector")
train_fn=tf.estimator.inputs.numpy_input_fn(
x={"x":images_batch.eval()},
y=labels_batch.eval(),
batch_size=10,
num_epochs=None,
shuffle=True)
detector.train(
input_fn=train_fn,
steps=2)
writer = tf.summary.FileWriter('.')
writer.add_graph(tf.get_default_graph())
def input_pipeline(filenames,batch_size,num_epochs):
filename_queue=tf.train.string_input_producer([filenames],num_epochs=num_epochs,shuffle=True)
images,labels=read_file(filename_queue)
return images,labels
我尝试了很多事情。以下是read
函数。我还有一个问题:解码图像时,正确的格式是什么? float32
还是uint8
?
def read_file(filename_queue):
#Funcion para leer el archivo tf.record, y retornamos el next recrod
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
#Se decodifica el tf.record retornando un diccionario
feature={'train/image':tf.FixedLenFeature([],tf.string),
'train/label':tf.FixedLenFeature([],tf.int64)}
features=tf.parse_single_example(serialized_example,features=feature)
#Convertimos el string a numeros de los decodificados features
image=tf.decode_raw(features['train/image'],tf.float32)* (1 / 255.0)
#Convertimos a datos
label=tf.cast(features['train/label'],dtype=tf.int32)
#Reshape data
image=tf.reshape(image,[224,224,3])
return image,label
这是我得到的错误:
/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
Traceback (most recent call last):
File "/Users/David/Desktop/David/General/Tesis/Practica/Programas/CNN/CNN.py", line 113, in <module>
tf.app.run()
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
_sys.exit(main(argv))
File "/Users/David/Desktop/David/General/Tesis/Practica/Programas/CNN/CNN.py", line 104, in main
steps=2)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 843, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 853, in _train_model_default
input_fn, model_fn_lib.ModeKeys.TRAIN))
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 691, in _get_features_and_labels_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 798, in _call_input_fn
return input_fn(**kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/inputs/numpy_io.py", line 175, in input_fn
if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/estimator/inputs/numpy_io.py", line 175, in <genexpr>
if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:
IndexError: tuple index out of range
[Finished in 3.9s with exit code 1]