从keras fit_generator()过渡到模型输入层的工作原理

时间:2019-07-12 14:41:52

标签: python tensorflow keras

我正在处理图像数据和一些标量元数据(例如头发颜色,眼睛颜色等)。 我正在使用自行编写的生成器来使用Keras .fit_generator()函数。

该过程如下所示:

应用了一些数据增强后,我得到了数据集的形状((10,200,200,3),(10,),(10,),(10,),(10,))(出于想象力:我提取了形状为(200,200,3的图像)并将其中的10个叠加在一起-> (10,200,200,3)。因此,我将元数据复制了10次->每种形状(10,)

然后,我使用张量流函数dataset = dataset.apply(tf.contrib.data.unbatch()),使数据集的形状为((200,200,3),(),(),(),())。现在,我从这里与您共享代码:

编辑(更多代码):

以下代码是我的生成器函数的最后一行,将从.fit_generator()中的main()函数中调用

shape_dataset = tf.shape(dataset) # shape ((10,200,200,3),(10,),(10,),(10,),(10,)) like I mentioned above
dataset = dataset.apply(tf.contrib.data.unbatch()) # shape ((200,200,3),(),(),(),()) like I mentioned a bove 
dataset = dataset.shuffle(buffer_size = buffer_size)
dataset = dataset.batch(batch_size=batch_size) 
dataset = dataset.repeat()
iterator_all = dataset.make_one_shot_iterator()
next_all = iterator_all.get_next()

with tf.Session() as sess:
    while True:
        try:
            image, eye_color, hair_ color, labels = sess.run(next_all)
            yield [image, eye_color, hair_ color], labels

        except tf.errors.OutOfRangeError:
            print('Finished')
            break

现在该张量将通过keras .fit_generator()函数输入到我的网络中。 输入层如下所示:

input_image = Input(shape=(200, 200, 3))
input_eye_color = Input(shape=(1,), name='input_ec')
input_hair_color = Input(shape=(1,), name='input_hc')

现在我有一个问题:

  1. 来自((10,200,200,3),(10,),(10,),(10,),(10,))的10个字符在哪里通过tf.contrib.data.unbatch())函数?对我来说,我好像失去了这10个值而只得到1个?

  2. fit_generator()函数可以批量运行,但是如何?听起来很愚蠢,我感觉我的网络在一个迭代步骤中获得了形状为((200,200,3),(),(),(),())的数据。显然,当批处理大小为8时,它会获得((8,10,200,200,3),(8,10,),(8,10,),(8, 10,),(8, 10,))之类的数据。

有人可以向我解释这个问题吗? 确实,我读了很多书,但还是听不懂。

感谢您的帮助:-)

1 个答案:

答案 0 :(得分:0)

对于您在此处描述的模型

input_image = Input(shape=(200, 200, 3), name='input_img')
input_eye_color = Input(shape=(1,), name='input_ec')
input_hair_color = Input(shape=(1,), name='input_hc')

在喀拉拉邦,fit_generator接受以下两个输入之一:

  1. 张量列表[bsize x 200 x 200 x 3, bsize x 1, bsize x 1]
  2. 张量字典

    {'input_img':bsize x 200 x 200 x3
       'input_ec':bsize x 1,    'input_hc':bsize x 1

如您所见,这与您实际提供的内容完全不同。