I want to train my ConvNet with Keras. After working through some tutorials I wrote the code below. I don't know whether it is any good; in particular, I have some doubts about the way I train the model with a generator.
Previously I fed the training loop with a generator of numpy arrays, but I read that TFRecords can improve performance. At first, in the create_dataset function below, I converted the tensors to numpy arrays before yielding them, but then I read that there is a more efficient way to consume the dataset without converting the tensors to numpy arrays at all.
So I tried to edit my code to use input_image=tf.keras.Input(tensor=x) and model.compile(optimizer=optimizer, loss=compute_loss, target_tensors=[y]). Before, I used neither target_tensors inside model.compile nor tensor=x inside tf.keras.Input (I only specified the input shape).
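For context, the numpy-array version of the generator looked roughly like this (a sketch from memory, so the details may not be exact; it uses the same _parse_function and constants as in the full listing below, and the sess.run call is the conversion I wanted to get rid of):

def create_dataset_numpy(filepath, batch_size=BATCH_SIZE):
    # Old approach: evaluate the pipeline tensors in a Session and
    # yield plain numpy batches to fit_generator
    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat().shuffle(100).batch(batch_size)
    image, label = dataset.make_one_shot_iterator().get_next()
    image = tf.reshape(image, [batch_size, 416, 416, 3])
    label = tf.reshape(label, [batch_size, 75, 25])
    with tf.Session() as sess:
        while True:
            # sess.run copies every batch out of the graph into numpy
            img_np, lbl_np = sess.run([image, label])
            yield img_np, lbl_np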
import tensorflow as tf
from compute_loss import compute_loss  # my custom loss function
dataset_train_path = "dataset_train.tfrecords"
dataset_val_path = "dataset_val.tfrecords"
filepath_checkpoint = "weights-best.hdf5"
Adam = tf.keras.optimizers.Adam
optimizer = Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
BATCH_SIZE = 32
TRAINING_SIZE = 5717
VALIDATION_SIZE = 5823
TRAINING_STEPS = TRAINING_SIZE // BATCH_SIZE
VALIDATION_STEPS = VALIDATION_SIZE // BATCH_SIZE
"""-----------------Here I define my generator-----------------"""
def _parse_function(proto):
    # Feature spec must match how the TFRecords were written:
    # image and label are stored as raw float16 bytes
    keys_to_features = {'image': tf.FixedLenFeature([], tf.string),
                        'label': tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(proto, keys_to_features)
    parsed_features['image'] = tf.decode_raw(parsed_features['image'], tf.float16)
    parsed_features['label'] = tf.decode_raw(parsed_features['label'], tf.float16)
    return parsed_features['image'], parsed_features['label']
def create_dataset(filepath, batch_size=BATCH_SIZE):
    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(100)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image, label = iterator.get_next()
    image = tf.reshape(image, [batch_size, 416, 416, 3])
    label = tf.reshape(label, [batch_size, 75, 25])
    # Note: this yields the same two symbolic tensors on every call;
    # the data only flows through them when the graph is executed
    while True:
        yield image, label
"""-----------------Here I create my train/val generators-----------------"""
training_generator=create_dataset(dataset_train_path)
validation_generator=create_dataset(dataset_val_path)
"""-----------------Now I can define my model-----------------"""
x, y = next(training_generator)  # symbolic batch tensors from the input pipeline
def net():
    # Bind the pipeline tensor directly to the model input instead of a placeholder
    input_image = tf.keras.Input(tensor=x)
    inputs = tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu', name='conv_1')(input_image)
    inputs = tf.keras.layers.BatchNormalization(name='norm_1')(inputs)
    ...
    ...
    outputs = tf.keras.layers.Conv2D(75, 1, name='conv_13')(inputs)
    model = tf.keras.Model(inputs=input_image, outputs=outputs)
    return model
if __name__ == '__main__':
    model = net()
    # target_tensors wires the label tensor from the pipeline into the loss
    model.compile(optimizer=optimizer, loss=compute_loss, target_tensors=[y])
    # callbacks_list (checkpointing to filepath_checkpoint, etc.) is defined
    # elsewhere in my script
    model.fit_generator(generator=training_generator,
                        validation_data=validation_generator,
                        epochs=1000,
                        max_queue_size=1000,
                        steps_per_epoch=TRAINING_STEPS,
                        validation_steps=VALIDATION_STEPS,
                        callbacks=callbacks_list)
Training now runs very fast, but I suspect something is wrong somewhere. Can you help me?
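One thing that makes me suspicious: the generator yields the same pair of symbolic tensors on every call rather than fresh numpy batches, which I can verify like this:

a, b = next(training_generator)
c, d = next(training_generator)
print(type(a))         # <class 'tensorflow.python.framework.ops.Tensor'>
print(a is c, b is d)  # True True -- every "batch" is the same graph node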
EDIT: if I pass the dataset directly to fit_generator, I get the following:
>>> train(model, DataGenerator, filepath_checkpoint="weights-best-tiny-test.hdf5")
Epoch 1/5000
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "tiny.py", line 239, in train
model.fit_generator(generator=training_generator,validation_data=validation_generator, epochs=5000, max_queue_size=100, steps_per_epoch=TRAINING_STEPS, validation_steps=VALIDATION_STEPS, callbacks=callbacks_list)
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1586, in fit_generator
steps_name='steps_per_epoch')
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 211, in model_iteration
batch_data = _get_next_batch(output_generator, mode)
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 323, in _get_next_batch
generator_output = next(output_generator)
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 767, in get
six.reraise(*sys.exc_info())
File "C:\Program Files\Python35\lib\site-packages\six.py", line 693, in reraise
raise value
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 743, in get
inputs = self.queue.get(block=True).get()
File "C:\Program Files\Python35\lib\multiprocessing\pool.py", line 644, in get
raise self._value
File "C:\Program Files\Python35\lib\multiprocessing\pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 680, in next_sample
return six.next(_SHARED_SEQUENCES[uid])
TypeError: 'Iterator' object is not an iterator
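My guess is that fit_generator expects a plain Python generator, while make_one_shot_iterator() returns a TensorFlow Iterator object that has get_next() but no __next__(), hence the TypeError. From what I have read, tf.keras's model.fit can take a tf.data.Dataset directly (together with steps_per_epoch), which would avoid both the generator and the target_tensors trick. A rough, untested sketch of what I think that would look like, assuming the model is built with Input(shape=(416, 416, 3)) instead of Input(tensor=x):

def create_dataset_v2(filepath, batch_size=BATCH_SIZE):
    # Return the Dataset itself instead of yielding tensors from an iterator
    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat().shuffle(100).batch(batch_size)
    dataset = dataset.map(
        lambda img, lbl: (tf.reshape(img, [batch_size, 416, 416, 3]),
                          tf.reshape(lbl, [batch_size, 75, 25])))
    return dataset

train_ds = create_dataset_v2(dataset_train_path)
val_ds = create_dataset_v2(dataset_val_path)
model.fit(train_ds,
          validation_data=val_ds,
          epochs=1000,
          steps_per_epoch=TRAINING_STEPS,
          validation_steps=VALIDATION_STEPS,
          callbacks=callbacks_list)

Is this the right approach, or is my current target_tensors version fine?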