我尝试使用此代码 - https://github.com/KGPML/Hyperspectral
def run_training():
    """Train the IndianPines MLP for FLAGS.max_steps steps.

    Loads the training and test ``.mat`` files, builds the inference /
    loss / training graph, runs the optimization loop, and periodically
    checkpoints and evaluates the model.

    Relies on module-level names: ``input_data``, ``IndianPinesMLP``,
    ``FLAGS``, ``DATA_PATH``, ``IMAGE_SIZE``, ``TRAIN_FILES``,
    ``TEST_FILES``, ``placeholder_inputs``, ``fill_feed_dict``, ``do_eval``.
    """
    # Load the training and test mat files.
    # BUG FIX: the original built the file name with str(1+1) / str(0+1),
    # so the loop index `i` was ignored and every iteration re-read the
    # same file. Use str(i + 1) to visit Train_/Test_ files 1..N.
    # NOTE(review): each iteration still *overwrites* Training_data /
    # Test_data, so only the last file is kept — confirm whether the data
    # sets should be concatenated instead, as the original comment implies.
    for i in range(TRAIN_FILES):
        Training_data = input_data.read_data_sets(
            os.path.join(DATA_PATH,
                         'Train_' + str(IMAGE_SIZE) + '_' + str(i + 1) + '.mat'),
            'train')
    for i in range(TEST_FILES):
        Test_data = input_data.read_data_sets(
            os.path.join(DATA_PATH,
                         'Test_' + str(IMAGE_SIZE) + '_' + str(i + 1) + '.mat'),
            'test')

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Generate placeholders for the images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

        # Build a Graph that computes predictions from the inference model.
        logits = IndianPinesMLP.inference(images_placeholder,
                                          FLAGS.hidden1,
                                          FLAGS.hidden2,
                                          FLAGS.hidden3)

        # Add to the Graph the Ops for loss calculation.
        loss = IndianPinesMLP.loss(labels=labels_placeholder, logits=logits)

        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = IndianPinesMLP.training(loss, FLAGS.learning_rate)

        # Add the Op to compare the logits to the labels during evaluation.
        eval_correct = IndianPinesMLP.evaluation(labels=labels_placeholder,
                                                 logits=logits)

        # Add the variable initializer Op (pre-1.0 TF API, matching the file).
        init = tf.initialize_all_variables()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        sess = tf.Session()

        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # Fill a feed dictionary with the actual set of images and
            # labels for this particular training step.
            feed_dict = fill_feed_dict(Training_data,
                                       images_placeholder,
                                       labels_placeholder)

            # Run one step of the model.  The train_op activation is
            # discarded; loss_value is kept for logging below.
            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            duration = time.time() - start_time

            # Print a status overview fairly often.
            if step % 50 == 0:
                print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                # NOTE(review): this is a Windows-style relative path;
                # the backslash works here only because '\m' is not a
                # recognized escape — confirm and prefer os.path.join.
                saver.save(sess,
                           '.\model-MLP-' + str(IMAGE_SIZE) + 'X' + str(IMAGE_SIZE) + '.ckpt',
                           global_step=step)
                # Evaluate against the training set.
                print('Training Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        Training_data)
                # Evaluate against the test set.
                print('Test Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        Test_data)
并收到错误:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-23-0683f80cdbe4> in <module>()
----> 1 run_training()
<ipython-input-22-b34daa52b702> in run_training()
60 feed_dict = fill_feed_dict(Training_data,
61 images_placeholder,
---> 62 labels_placeholder)
63
64 # Run one step of the model. The return values are the activations
如果手动运行这些部件,我没有错误:
<ipython-input-5-f04ef9a1e6b2> in fill_feed_dict(data_set, images_pl, labels_pl)
15 # Create the feed_dict for the placeholders filled with the next
16 # `batch size ` examples.
---> 17 images_feed, labels_feed = data_set.next_batch(batch_size)
18 feed_dict = {
19 images_pl: images_feed,
同样的问题在这里:
~\Path to: \Spatial_dataset.py in next_batch(self, batch_size)
87 start = 0
88 self._index_in_epoch = batch_size
---> 89 assert batch_size <= self._num_examples
90 end = self._index_in_epoch
91 return self._images[start:end], np.reshape(self._labels[start:end],len(self._labels[start:end]))
AssertionError:
当我现在运行run_training()
时,会出现上述错误。
这是什么意思？我该如何解决？在这种情况下谷歌也帮不上忙。感谢您的帮助。
答案 0（得分：0）
主要错误是由于:
---> 89 assert batch_size <= self._num_examples
更改 batch_size，使其既是训练图像数量（不含验证集）的一个因数，也是训练集图像总数（含验证集）的一个因数。
例如，如果您的训练集中有100张图像，而 validation_size
是0.2，那么将用80张图像进行训练，并用20张图像进行验证。因此，选择 batch_size
使其为80的因数，例如20。20既是80的因数，也是100的因数。