Tensorflow:使用tf.train.batch批处理标签

时间:2017-11-19 06:34:44

标签: tensorflow

我有一段代码

la=[0,0,0,0,0,0,1,1,1,1]
onehot = tf.one_hot(la, depth=2)   #[[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1]]
image_batch,labels_batch=tf.train.batch([resized_image,onehot],batch_size=2,num_threads=1)

我跑的时候

  

打印(s.run([tf.shape(image_batch),labels_batch]))

它一次批量处理所有实验室,如

[array([ 2, 50, 50,  3]), array([[[ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 0.,  1.],
    [ 0.,  1.],
    [ 0.,  1.],
    [ 0.,  1.]],

   [[ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 1.,  0.],
    [ 0.,  1.],
    [ 0.,  1.],
    [ 0.,  1.],
    [ 0.,  1.]]], dtype=float32)]

它应输出类似

的内容
[array([ 2, 50, 50,  3]), array([[[ 1.,  0.],

   [[ 1.,  0.]]], dtype=float32)]
不是吗?批量大小为2,拍摄2张图像,一次是相应的标签。 我是CNN和机器学习的新手。事先谢谢。

1 个答案:

答案 0 :(得分:1)

根据 tf.train.batch https://www.tensorflow.org/api_docs/python/tf/train/batch)的 Tensorflow 文档,

enter image description here

由于 enqueue_many = False 默认情况下输入 onehot 的形状为 [10,2] ,因此输出(此处) labels_batch )形状变为 [batch_size,10,2]

如果 enqueue_many = True ,则只有输出(此处 labels_batch )将变为 [batch_size,2]

希望这有帮助。