张量流中的批处理结构

时间:2020-04-09 09:06:05

标签: tensorflow deep-learning neural-network mnist

我正在关注带有张量流和MNIST数据集的神经网络教程。我遇到了以下代码:

for _ in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y: batch[1]})

我在可视化批处理结构时遇到一些问题。特别是批次的索引。 batch[0]是否以某种方式表示该批次中的所有50张图像,而batch[1]表示这些图像的所有50个标签?如果有人可以直观地显示批处理的结构,那将是很好的。我进行了搜索,但找不到关于此的好的教程。

1 个答案:

答案 0 :(得分:1)

这是我用于批量显示图像的基本代码

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), 
color='red' if red else 'black', fontdict={'verticalalignment':'center'}, 
pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)

def display_batch_of_images(images,labels, predictions=None):
    """This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
   """

    # auto-squaring: this will drop data that does not fit into square or square- 
    #rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # display
    tempo=""
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else tempo 
        correct = True
        if predictions is not None:
        title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

现在我们分批显示图像

图像包含全部50张图像,标签包含全部50个标签

  display_batch_of_images(images,labels)