我参照这个 example,使用 tflearn 的 Trainer 编写了如下代码:
# Train a bidirectional grid-LSTM + CTC OCR model with tflearn's Trainer.
# Steps: read and preprocess the dataset, build the graph, then fit.

# Read the (image path, label) list and load the PNG images it refers to.
image_paths, labels = dataset_utils.read_dataset_list('../test/dummy_labels_file.txt')
data_dir = "../test/dummy_data/"
images = dataset_utils.read_images(data_dir=data_dir, image_paths=image_paths,
                                   image_extension='png')
print('Done reading images')

# Resize to a fixed (1596, 48) size, transpose, and encode the text labels.
images = dataset_utils.resize(images, (1596, 48))
images = dataset_utils.transpose(images)
labels = dataset_utils.encode(labels)

# Hold out half of the data for validation.
x_train, x_test, y_train, y_test = dataset_utils.split(features=images, test_size=0.5,
                                                       labels=labels)

...  # parameters initialized here (num_features, num_classes, num_hidden_units,
     # learning_rate, optimizer_name are used below but not defined in this snippet)

with tf.Graph().as_default():
    # Rank-3 float input whose last axis has num_features entries; axis 0 is
    # used as the batch size in dnn() below (presumably (batch, time, features)).
    X = tf.placeholder(tf.float32, [None, None, num_features])
    # Dense integer labels with no declared shape, converted to the sparse form
    # passed to the CTC loss as `labels`.
    Y = tf.placeholder(tf.int32)
    sparse_Y = network_utils.dense_to_sparse(Y, num_classes)
    # One sequence length per example (shape [None]).
    seq_lens = tf.placeholder(tf.int32, [None])

    def dnn(x):
        """Build the network: bidirectional grid LSTM, then a time-major projection."""
        layer = network_utils.bidirectional_grid_lstm(inputs=x, num_hidden=num_hidden_units)
        # num_hidden_units * 2 because the bidirectional layer's output is
        # presumably the concatenation of both directions -- TODO confirm.
        layer = network_utils.get_time_major(inputs=layer,
                                             batch_size=network_utils.get_shape(x)[0],
                                             num_classes=num_classes,
                                             num_hidden_units=num_hidden_units * 2)
        return layer

    net = dnn(X)
    cost = network_utils.cost(
        network_utils.ctc_loss(inputs=net, labels=sparse_Y, sequence_length=seq_lens))
    optimizer = network_utils.get_optimizer(learning_rate=learning_rate,
                                            optimizer_name=optimizer_name)
    train_op = tflearn.TrainOp(loss=cost, optimizer=optimizer)
    # The Trainer must be created and run inside this `with` block so that it
    # uses the graph that owns X, Y and seq_lens.
    trainer = tflearn.Trainer(train_ops=train_op)

    train_seq_lens = dataset_utils.get_seq_lens(x_train)
    test_seq_lens = dataset_utils.get_seq_lens(x_test)
    # NOTE(review): this call is reported to raise
    # InvalidArgumentError "Retval[0] does not have value" once training starts.
    # The root cause cannot be determined from this snippet -- TODO investigate.
    trainer.fit(feed_dicts={X: x_train, Y: y_train, seq_lens: train_seq_lens},
                val_feed_dicts={X: x_test, Y: y_test, seq_lens: test_seq_lens},
                n_epoch=1)  # error happens here
运行后训练能够开始,但随后出现了以下错误:
Traceback (most recent call last):
File ".../Optimized_OCR/main/train_using_tflearn_trainer.py", line 53, in <module>
tf.app.run(main=main)
File "...\tensorflow\python\platform\app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File ".../Optimized_OCR/main/train_using_tflearn_trainer.py", line 49, in main
n_epoch=1)
File "...\tflearn\helpers\trainer.py", line 338, in fit
show_metric)
File "...\tflearn\helpers\trainer.py", line 817, in _train
feed_batch)
File "...\tensorflow\python\client\session.py", line 889, in run
run_metadata_ptr)
File "...\tensorflow\python\client\session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "...\tensorflow\python\client\session.py", line 1317, in _do_run
options, run_metadata)
File "...\tensorflow\python\client\session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[0] does not have value
有人遇到过这个问题吗?该如何解决?我希望使用 tflearn 的 Trainer 来更方便地训练和测试我的 OCR 模型,而这似乎是我要使用它必须先解决的唯一问题。