发电机在错误的时间调用(keras)

时间:2017-07-25 12:37:00

标签: tensorflow keras

我在keras 2.0.2中使用fit_generator()批量大小为10和步骤320,因为我有3209个样本用于培训。在第一个纪元开始之前,发电机被称为11次,显示:

Train -- get ind: 0 to 10
    ...    
Train -- get ind: 100 to 110

然后,在第一批(1/320)之后,它打印出Train -- get ind: 110 to 120,但我认为它应该是Train -- get ind: 0 to 10。我对train_generator()函数的实现是否不正确?或者为什么我有这个问题?

这是我的生成器代码:

EPOCH = 10
x_train_img = img[:train_size] # shape: (3209,512,512)
x_test_img = img[train_size:]  # shape: (357,512,512)

def train_generator():
    global x_train_img

    last_ind = 0

    while 1:
        x_train = x_train_img[last_ind:last_ind+BATCH_SIZE]
        print('Train -- get ind: ',last_ind," to ",last_ind+BATCH_SIZE)
        last_ind = last_ind+BATCH_SIZE
        x_train = x_train.astype('float32') / 255.
        x_train = np.reshape(x_train, (len(x_train), 512, 512, 1)) 
        yield (x_train, x_train)
        if last_ind >= x_train_img.shape[0]:
             last_ind = 0

def test_generator():
     ...

train_steps = x_train_img.shape[0]//BATCH_SIZE #320
test_steps = x_test_img.shape[0]//BATCH_SIZE   #35

autoencoder.fit_generator(train_generator(), 
                steps_per_epoch=train_steps, 
                epochs=EPOCH,
                validation_data=test_generator(),
                validation_steps=test_steps,
                callbacks=[csv_logger] )

更好?编写生成器的方式:

def train_generator():
    global x_train_img

    while 1:
        for i in range(0, x_train_img.shape[0], BATCH_SIZE):
            x_train = x_train_img[i:i+BATCH_SIZE]
            print('Train -- get ind: ',i," to ",i+BATCH_SIZE)
            x_train = x_train.astype('float32') / 255.
            x_train = np.reshape(x_train, (len(x_train), 512, 512, 1)) 
            yield (x_train, x_train)

1 个答案:

答案 0 :(得分:2)

默认情况下,gem install bcrypt-ruby -v '3.0.1' 使用fit_generator()。 所以你观察到的是:

  1. 在纪元开始之前,您的生成器会产生10个批次以填充队列。这是0到100的样本。
  2. 然后,纪元开始,并从队列中弹出一个批处理以进行模式拟合。
  3. 生成器生成一个新批处理以填充队列中的空白区域。那个样本100到110。
  4. 然后,更新进度条。屏幕上会显示进度max_queue_size=10
  5. 再次执行步骤2和3,打印1/320
  6. 因此这种模型拟合程序没有任何问题。生成的第一批确实是第一批用于拟合模型的批次。它只是隐藏在它后面的队列,并且在第一次模型更新发生之前,生成器被多次调用以填满队列。