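I'm trying to use a TFRecord dataset as the input to a Keras model. The network appears to start training, but then an error message appears. Here is the code I use to build and fit the model: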
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = keras.models.Sequential()
    model.add(keras.layers.Conv3D(64, (7,7,7), strides=(2,2,2), padding="same",
                                  use_bias=False, input_shape=[91,109,91,1]))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Activation("relu"))
    model.add(keras.layers.MaxPool3D(pool_size=(3,3,3), strides=(2,2,2), padding="same"))
    # ResNet-34 layout: 3, 4, 6 and 3 residual units; stride 2 whenever the filter count changes
    prev_filters = 64
    for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
        strides = 1 if filters == prev_filters else 2
        model.add(ResidualUnit(filters, strides=strides))
        prev_filters = filters
    model.add(keras.layers.GlobalAveragePooling3D())
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1, activation="sigmoid"))
model = keras.utils.multi_gpu_model(model, gpus=2)
es = keras.callbacks.EarlyStopping(monitor='val_accuracy', mode='auto', patience=20)
mc = keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', save_best_only=True)
tb = keras.callbacks.TensorBoard(log_dir='./logs', write_images=True, write_graph=True)
model.compile(loss="binary_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=0.001),
              metrics=['accuracy'])
training_set = train_input_fn('train.tfrecords', batch_size=BATCH_SIZE, num_epochs=N_EPOCHS)
validation_set = validation_input_fn('test.tfrecords', batch_size=BATCH_SIZE)
history = model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH_TRAINING,
                    epochs=N_EPOCHS, validation_data=validation_set,
                    validation_steps=STEPS_PER_EPOCH_VALIDATION, callbacks=[tb, es, mc])
Here is the output (minus the standard output about the distribution setup):
Train for 15 steps, validate for 3 steps
WARNING: Logging before flag parsing goes to stderr.
W1028 12:25:08.884782 140602772264768 summary_ops_v2.py:1110] Model failed to serialize as JSON. Ignoring... Layers with arguments in `__init__` must override `get_config`.
Epoch 1/1000
2019-10-28 12:25:29.438754: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
2019-10-28 12:25:30.317121: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2019-10-28 12:25:34.050879: I tensorflow/core/profiler/lib/profiler_session.cc:184] Profiler session started.
2019-10-28 12:25:34.053985: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcupti.so.10.0
1/15 [=>............................] - ETA: 6:04 - loss: 0.6008 - accuracy: 0.7500
2019-10-28 12:25:35.924105: I tensorflow/core/platform/default/device_tracer.cc:588] Collecting 6788 kernel records, 994 memcpy records.
W1028 12:25:39.252643 140602772264768 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (2.169159). Check your callbacks.
2/15 [===>..........................] - ETA: 3:17 - loss: 0.7089 - accuracy: 0.7969
W1028 12:25:39.688849 140602772264768 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.885054). Check your callbacks.
3/15 [=====>........................] - ETA: 2:03 - loss: 0.7249 - accuracy: 0.8333
W1028 12:25:40.090722 140602772264768 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.442553). Check your callbacks.
4/15 [=======>......................] - ETA: 1:25 - loss: 0.7012 - accuracy: 0.7734
2019-10-28 12:25:40.205127: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: train/image. Can't parse serialized Example.
7/15 [=============>................] - ETA: 37s - loss: 0.6582 - accuracy: 0.7545
2019-10-28 12:25:41.299845: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_parse_record_6895}} Key: train/image. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseSingleExample}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNextAsOptional]]
[[replica_1/Fill_32/_947]]
2019-10-28 12:25:41.299846: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_parse_record_6895}} Key: train/image. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseSingleExample}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNextAsOptional]]
[[OptionalHasValue/_10]]
2019-10-28 12:25:41.300453: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_parse_record_6895}} Key: train/image. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseSingleExample}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNextAsOptional]]
W1028 12:25:41.317599 140602772264768 callbacks.py:1250] Early stopping conditioned on metric `val_accuracy` which is not available. Available metrics are: loss,accuracy
W1028 12:25:41.317952 140602772264768 callbacks.py:990] Can save best model only with val_accuracy available, skipping.
Traceback (most recent call last):
  File "DAT_resnet34.py", line 80, in <module>
    validation_steps=STEPS_PER_EPOCH_VALIDATION, callbacks=[tb, es, mc])
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 324, in fit
    total_epochs=epochs)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 123, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 86, in execution_function
    distributed_function(input_fn))
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 487, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1823, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1141, in _filtered_call
    self.captured_inputs)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: 3 root error(s) found.
(0) Invalid argument: Key: train/image. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseSingleExample}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNextAsOptional]]
[[OptionalHasValue/_10]]
(1) Invalid argument: Key: train/image. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseSingleExample}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNextAsOptional]]
[[replica_1/Fill_32/_947]]
(2) Invalid argument: Key: train/image. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseSingleExample}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNextAsOptional]]
0 successful operations.
0 derived errors ignored. [Op:__inference_distributed_function_40587]
Function call stack:
distributed_function -> distributed_function -> distributed_function -> distributed_function -> distributed_function -> distributed_function
Sorry for the very long error message; I'm just not sure which part of it is relevant. How can I fix this?
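Update:
I was able to (painstakingly) solve this by discovering that samples in the TFRecord file were not being read correctly. I'm not sure why that happens, and no error message was produced. Is there a way to check whether a TFRecord file contains any corrupted or unusable samples and remove them?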
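One way to do this is to parse every record individually, outside the training pipeline, and copy only the records that parse into a new file. This is a minimal sketch assuming TF 2.x eager execution; `filter_tfrecord` is a hypothetical helper, and `feature_description` has to be the same feature spec that the question's `parse_record` function (not shown) passes to `tf.io.parse_single_example`.

```python
import tensorflow as tf


def filter_tfrecord(src_path, dst_path, feature_description):
    """Hypothetical helper: copy the records of src_path that parse under
    feature_description into dst_path and skip the rest.
    Returns (kept, dropped)."""
    kept, dropped = 0, 0
    # Iterate over the raw serialized records, without any parsing yet.
    raw_dataset = tf.data.TFRecordDataset(src_path)
    with tf.io.TFRecordWriter(dst_path) as writer:
        for raw_record in raw_dataset:
            try:
                # Same call the input pipeline makes, but one record at a time
                # in eager mode, so a bad record raises here instead of
                # aborting training halfway through an epoch.
                tf.io.parse_single_example(raw_record, feature_description)
            except tf.errors.InvalidArgumentError:
                dropped += 1
                continue
            # Write the original serialized bytes unchanged.
            writer.write(raw_record.numpy())
            kept += 1
    return kept, dropped
```

Usage would be something like `kept, dropped = filter_tfrecord('train.tfrecords', 'train_clean.tfrecords', feature_description)`. A record whose number of values does not match the shape of a `FixedLenFeature` in the spec is one common way to get `Key: train/image. Can't parse serialized Example.`, and this check catches that case. It does not repair a truncated file: iterating over one raises `tf.errors.DataLossError`, which this sketch lets propagate.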
Answer 0 (score: 0)
See line 3 of the output:
Layers with arguments in `__init__` must override `get_config`.
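This is probably related to the way you implemented ResidualUnit.
See the solution at https://stackoverflow.com/a/58799021/4960855
If you get stuck, share your implementation and I'll do my best to help.
It won't necessarily clear up the other errors, but it's worth fixing anyway.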
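The question doesn't show `ResidualUnit`, so the class below is a hypothetical 3D residual block, modeled loosely on the usual two-convolution ResNet unit and assuming `keras` is `tensorflow.keras` (the traceback points into `tensorflow_core/python/keras`). The part that addresses the warning is `get_config()`: it returns the constructor arguments so Keras can serialize the layer and rebuild it with `from_config()`. It also assumes `activation` is passed as a string name.

```python
from tensorflow import keras


class ResidualUnit(keras.layers.Layer):
    """Hypothetical 3D residual unit; the asker's real implementation is not shown."""

    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can return them.
        self.filters = filters
        self.strides = strides
        self.activation_name = activation  # assumed to be a string such as "relu"
        self.activation = keras.activations.get(activation)
        self.conv1 = keras.layers.Conv3D(filters, 3, strides=strides,
                                         padding="same", use_bias=False)
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = keras.layers.Conv3D(filters, 3, strides=1,
                                         padding="same", use_bias=False)
        self.bn2 = keras.layers.BatchNormalization()
        # Projection shortcut only when the spatial size changes; with
        # strides == 1 the identity shortcut requires input channels == filters.
        if strides > 1:
            self.skip_conv = keras.layers.Conv3D(filters, 1, strides=strides,
                                                 padding="same", use_bias=False)
            self.skip_bn = keras.layers.BatchNormalization()
        else:
            self.skip_conv = None
            self.skip_bn = None

    def call(self, inputs, training=None):
        z = self.activation(self.bn1(self.conv1(inputs), training=training))
        z = self.bn2(self.conv2(z), training=training)
        skip = inputs
        if self.skip_conv is not None:
            skip = self.skip_bn(self.skip_conv(inputs), training=training)
        return self.activation(z + skip)

    def get_config(self):
        # Base config (name, trainable, dtype) plus this layer's own arguments.
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "strides": self.strides,
            "activation": self.activation_name,
        })
        return config
```

Storing the constructor arguments as attributes is what makes `get_config()` possible. The default `from_config()` calls `ResidualUnit(**config)`, so the base keys (`name`, `trainable`, `dtype`) flow back into `keras.layers.Layer.__init__` through `**kwargs`.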