How can I predict on a single data record with a TensorFlow model trained using shuffle_batch?

Date: 2017-11-07 08:54:56

Tags: tensorflow

I trained a CNN model, using shuffle_batch to read from a large data file, and set the batch size to 64 before training. It seems the batch size cannot be changed during or after training, so how can I use the trained model to predict on just one data record when the batch size is fixed?

I use a placeholder for batch_size; the code is below:

import datetime
import os
import time

import tensorflow as tf

# read_data_from_tfrecords, TextCNN and FLAGS are defined elsewhere in the project

def train(target_path, vocab_processor):
    with tf.Graph().as_default():
        # batch_size is fed at run time so training and inference can use different values
        batch_size = tf.placeholder(tf.int32, name='batch_size')
        data_batch, label_batch = read_data_from_tfrecords(target_path, batch_size)
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn = TextCNN(
                sequence_length=data_batch.shape[1],
                num_classes=label_batch.shape[1],
                vocab_size=len(vocab_processor.vocabulary_),
                embedding_size=FLAGS.embedding_dim,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                num_filters=FLAGS.num_filters,
                input_x=data_batch,
                input_y=label_batch,
                l2_reg_lambda=FLAGS.l2_reg_lambda
            )
            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)
            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))
            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)
            # Train Summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            sess.run(init)
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    feed_dict = {
                        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                        batch_size: 64  # feed the placeholder with the training batch size
                    }
                    _, step, summaries, loss, accuracy = sess.run(
                        [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                    time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                    train_summary_writer.add_summary(summaries, step)
                    current_step = tf.train.global_step(sess, global_step)
                    if current_step % FLAGS.checkpoint_every == 0:
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))
            except tf.errors.OutOfRangeError:
                print("done training")
            finally:
                coord.request_stop()
            coord.join(threads)
            sess.close()

Error:

    Traceback (most recent call last):
      File "/home/ubuntu/Documents/code/error-classify/cnn_classify/test_train.py", line 247, in <module>
        train(tfRecorder_path, vocab_processor)
      File "/home/ubuntu/Documents/code/error-classify/cnn_classify/test_train.py", line 82, in train
        num_threads=2)
      File "/home/ubuntu/.pyenv/versions/3.5.3/lib/python3.5/site-packages/tensorflow/python/training/input.py", line 1220, in shuffle_batch
        name=name)
      File "/home/ubuntu/.pyenv/versions/3.5.3/lib/python3.5/site-packages/tensorflow/python/training/input.py", line 765, in _shuffle_batch
        if capacity <= min_after_dequeue:
      File "/home/ubuntu/.pyenv/versions/3.5.3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 499, in __bool__
        raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
    TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

2 answers:

Answer 0 (score: 0)

You can replace the fixed batch size with a placeholder, setting it to 64 for training and to whatever you need at inference time.

batch_size = tf.placeholder(tf.int32, (), name="batch_size")
tf.train.shuffle_batch(..., batch_size=batch_size, ...)
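
With that in place, you feed a different value through the same graph at training and inference time. A minimal sketch of the two calls, assuming the TextCNN object exposes a predictions op (not shown in the question):

# training step: dequeue 64 records per run
sess.run(train_op, feed_dict={batch_size: 64,
                              cnn.dropout_keep_prob: FLAGS.dropout_keep_prob})

# inference: dequeue a single record from the same pipeline
sess.run(cnn.predictions, feed_dict={batch_size: 1,
                                     cnn.dropout_keep_prob: 1.0})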

Answer 1 (score: 0)

Setting allow_smaller_final_batch=True fixes the problem. Also, at test time you should normally use tf.train.batch instead of shuffle_batch, since evaluation does not need shuffling.
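
A minimal sketch of such an evaluation pipeline, assuming a hypothetical read_single_example parser that returns one (example, label) pair from the TFRecord file; the capacity value is illustrative:

import tensorflow as tf

# deterministic batching for evaluation; no shuffling needed
example, label = read_single_example("test.tfrecords")  # hypothetical reader
data_batch, label_batch = tf.train.batch(
    [example, label],
    batch_size=64,
    capacity=1000,
    allow_smaller_final_batch=True)  # the last batch may hold fewer than 64 records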

It fails when I use the placeholder; I haven't figured out why yet.
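
A plausible explanation, judging from the traceback in the question (an assumption, since read_data_from_tfrecords is not shown): capacity or min_after_dequeue is computed from the batch_size placeholder, e.g. capacity = min_after_dequeue + 3 * batch_size. That turns capacity into a tensor, and shuffle_batch then evaluates `if capacity <= min_after_dequeue:` as a Python bool, which raises exactly the TypeError above. The batch_size argument itself may be a scalar tensor, but capacity and min_after_dequeue must stay plain Python integers; the values below are illustrative:

batch_size = tf.placeholder(tf.int32, (), name="batch_size")
data_batch, label_batch = tf.train.shuffle_batch(
    [example, label],           # tensors produced by the record reader
    batch_size=batch_size,      # a scalar tensor is accepted here
    capacity=10000,             # must be a plain Python int, not a tensor
    min_after_dequeue=5000,     # must be a plain Python int, not a tensor
    num_threads=2)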