Ordering from a TensorFlow queue

Time: 2017-01-18 11:05:57

Tags: tensorflow

I am trying to understand queues in more detail. With the code below I expected that, since I do not shuffle the list of letters, the output collection would be in alphabetical order. This appears to be the case for every epoch except the first. Am I misunderstanding something?

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import tensorflow as tf
import numpy as np
import string


# Basic model parameters as external flags.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
                     'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', '/tmp/data',
                    'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                     'for unit testing.')


def run_training():
  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    with tf.name_scope('input'):
      # Input data
      images_initializer = tf.placeholder(
          dtype=tf.int64,
          shape=[52,1])
      input_images = tf.Variable(
          images_initializer, trainable=False, collections=[])

      image = tf.train.slice_input_producer(
          [input_images], num_epochs=2)
      images = tf.train.batch(
          [image], batch_size=1)

      alph_initializer = tf.placeholder(
          dtype=tf.string,
          shape=[26,1])
      input_alph = tf.Variable(
          alph_initializer, trainable=False, collections=[])

      alph = tf.train.slice_input_producer(
          [input_alph], shuffle=False, capacity=26)
      alphs = tf.train.batch(
          [alph], batch_size=1)


    my_list = np.array(list(range(0,52))).reshape(52,1)
    my_list_val = np.array(list(string.ascii_lowercase)).reshape(26,1)


    # Create the op for initializing variables.
    init_op = tf.initialize_all_variables()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Run the Op to initialize the variables.
    sess.run(init_op)
    sess.run(input_images.initializer,
             feed_dict={images_initializer: my_list})
    sess.run(input_alph.initializer,
             feed_dict={alph_initializer: my_list_val})

    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # And then after everything is built, start the training loop.
    collection = []
    try:
      step = 0
      while not coord.should_stop():
        start_time = time.time()

        # Run one step of the model.
        integer = sess.run(image)
        #print("Integer val", integer)

        char = sess.run(alph)
        collection.append(char[0][0])
        print("String val", char)


        duration = time.time() - start_time

    except tf.errors.OutOfRangeError:
      print('Saving')
      print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()
    print(str(collection))


    # Wait for threads to finish.
    coord.join(threads)
    sess.close()


def main(_):
  run_training()


if __name__ == '__main__':
    tf.app.run()

1 Answer:

Answer 0 (score: 0)

Changing the above to the following cleared up my confusion. Fetching image and alph dequeues straight from the slice_input_producer queues, while the background threads started for tf.train.batch are dequeuing from those same queues at the same time, so the main loop only sees the subset of elements it wins from that race. Fetching the batched tensors images and alphs instead reads from the batch queues, and the elements come out in order.

    try:
      step = 0
      while not coord.should_stop():
        start_time = time.time()

        # Run one step of the model.
        integer = sess.run(images)
        #print("Integer val", integer)

        char = sess.run(alphs)
        collection.append(char[0][0])
        print("String val", char)

        duration = time.time() - start_time

    except tf.errors.OutOfRangeError:
      print('Saving')
      print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()
    print(str(collection))
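For reference, here is a minimal standalone sketch of the same idea (my own illustrative names, not from the question; it assumes the TF 1.x queue-runner API): only the batch op is fetched, so its queue runner is the sole consumer of the producer queue and the letters arrive in order.

from __future__ import print_function

import string
import tensorflow as tf

with tf.Graph().as_default():
  letters = tf.constant(list(string.ascii_lowercase))
  # shuffle=False keeps the producer queue in list order; num_epochs creates a
  # local variable, hence the local_variables_initializer call below.
  letter = tf.train.slice_input_producer([letters], shuffle=False, num_epochs=1)
  # Fetch only this batch op: its queue runner is then the sole consumer of
  # the producer queue, so no elements are lost to a second consumer.
  batched = tf.train.batch(letter, batch_size=1)[0]

  sess = tf.Session()
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      print(sess.run(batched))  # e.g. [b'a'], [b'b'], ... in alphabetical order
  except tf.errors.OutOfRangeError:
    pass
  finally:
    coord.request_stop()
  coord.join(threads)
  sess.close()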