Tensorflow: unable to display jpeg image

Date: 2017-01-31 10:53:20

Tags: python machine-learning tensorflow deep-learning

I converted my dataset into sharded tfrecords using a script similar to here (a sketch of the writer side is included after the script, for reference). But when I try to read them with the script below, tensorflow freezes and I have to kill the process. (Note: I am working in CPU mode for now.)

    import os
    import tensorflow as tf

    def parse_example_proto(example_serialized):
        feature_map = {
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
                                                default_value=''),
            'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
                                                    default_value=-1),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
                                                   default_value=''),
        }

        features = tf.parse_single_example(example_serialized, feature_map)

        init_image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
        init_image.set_shape([800, 480, 3])
        image = tf.reshape(init_image, tf.pack([800, 480, 3]))
        float_image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        label = tf.cast(features['image/class/label'], dtype=tf.int32)

        return float_image, label, features['image/class/text']


    def batch_inputs(batch_size, train, sess, num_preprocess_threads=4,
                     num_readers=1):

        with tf.name_scope('batch_processing'):
            tf_record_pattern = os.path.join('/home/raarora/', '%s-*' % 'train')
            data_files = tf.gfile.Glob(tf_record_pattern)
            if data_files is None:
                raise ValueError('No data files found for this dataset')
            # print data_files
            # Create filename_queue
            if train:
                filename_queue = tf.train.string_input_producer(data_files,
                                                                shuffle=True,
                                                                capacity=8)
            else:
                filename_queue = tf.train.string_input_producer(data_files,
                                                                shuffle=False,
                                                                capacity=1)

            reader = tf.TFRecordReader()
            _, example_serialized = reader.read(filename_queue)

            image, label, _ = parse_example_proto(example_serialized)

            examples_per_shard = 201
            min_queue_examples = examples_per_shard * 2

            images, labels = tf.train.shuffle_batch(
                [image, label], batch_size=batch_size, num_threads=4,
                capacity=min_queue_examples + 3 * batch_size,
                min_after_dequeue=min_queue_examples)
            print images.eval(session=sess)
            return s, images, labels


    if __name__ == '__main__':

        sess = tf.Session()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        s, _, _ = batch_inputs(2, 1, sess)
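
For reference, each shard was written with the same feature keys that parse_example_proto expects. Below is a minimal sketch of that conversion step, not the actual script linked above; the (jpeg_path, label, text) tuple layout and the write_shard helper are assumptions for illustration:

    import tensorflow as tf

    def _bytes_feature(value):
        # Wrap a byte string in a tf.train.Feature
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def write_shard(examples, shard_path):
        # examples: iterable of (jpeg_path, label, text) tuples -- assumed layout
        writer = tf.python_io.TFRecordWriter(shard_path)
        for jpeg_path, label, text in examples:
            with open(jpeg_path, 'rb') as f:
                encoded = f.read()  # raw JPEG bytes; decoded later by tf.image.decode_jpeg
            example = tf.train.Example(features=tf.train.Features(feature={
                'image/encoded': _bytes_feature(encoded),
                'image/class/label': _int64_feature(label),
                'image/class/text': _bytes_feature(text),
            }))
            writer.write(example.SerializeToString())
        writer.close()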

1 Answer:

Answer 0: (score: 0)

Was able to solve this. I thought a TFRecord was a kind of dictionary where you only need to supply the keys you want, but it works once the entire feature map is given, along with a small change to how the image is processed afterwards.

The other mistake I made was that the queue runners should be started after calling tf.train.shuffle_batch(). I don't know whether this is a bug or a gap in my understanding.

Here is the working code for reading the data.

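In outline: parse with the full feature map, build the whole input pipeline including tf.train.shuffle_batch(), and only then start the queue runners before evaluating a batch. Below is a minimal sketch along those lines rather than the exact original code; the 800x480 size, file path, and queue capacities are carried over from the question, and the resize_images call stands in for the unspecified "small change" to the image processing:

    import os
    import tensorflow as tf

    def parse_example_proto(example_serialized):
        # Supply the full feature map that was used when the records were written
        feature_map = {
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
        }
        features = tf.parse_single_example(example_serialized, feature_map)

        image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # Resize so the tensor has a fully known static shape, as shuffle_batch requires
        image = tf.image.resize_images(image, [800, 480])
        label = tf.cast(features['image/class/label'], dtype=tf.int32)
        return image, label, features['image/class/text']

    def batch_inputs(batch_size, train):
        with tf.name_scope('batch_processing'):
            data_files = tf.gfile.Glob(os.path.join('/home/raarora/', 'train-*'))
            if not data_files:
                raise ValueError('No data files found for this dataset')

            filename_queue = tf.train.string_input_producer(
                data_files, shuffle=bool(train), capacity=8)
            reader = tf.TFRecordReader()
            _, example_serialized = reader.read(filename_queue)

            image, label, _ = parse_example_proto(example_serialized)

            min_queue_examples = 2 * 201
            images, labels = tf.train.shuffle_batch(
                [image, label], batch_size=batch_size, num_threads=4,
                capacity=min_queue_examples + 3 * batch_size,
                min_after_dequeue=min_queue_examples)
            return images, labels

    if __name__ == '__main__':
        # Build the full graph first ...
        images, labels = batch_inputs(batch_size=2, train=True)

        sess = tf.Session()
        coord = tf.train.Coordinator()
        # ... and only then start the queue runners, so the runner created by
        # tf.train.shuffle_batch() is actually started.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        img_batch, lbl_batch = sess.run([images, labels])
        print(img_batch.shape, lbl_batch.shape)

        coord.request_stop()
        coord.join(threads)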

Note: I'm not clear on sharded records, so I used only one shard.

Credit: https://agray3.github.io/2016/11/29/Demystifying-Data-Input-to-TensorFlow-for-Deep-Learning.html