我正在尝试根据自定义Keras模型构建TensorFlow2估算器。该模型将形状为[batch_size,n,h,w,c]的张量作为输入。我需要从背面在每个[n,h,w,c]张量上应用CNN。为此,我正在使用tf.map_fn:
make_model(params):
batch = Input(shape=[n, h, w, c], batch_size=params.batch_size, name='inputs')
feature_extraction = SomeCustomLayer()
x = tf.map_fn(feature_extraction, batch)
...
softmax_score = softmax(x)
return tf.keras.Model(inputs=batch, outputs=softmax_score, name='custom_model')
当我编译并将模型转换为估计量时,一切运行良好:
model = make_model(params)
model.compile(optimizer=optimizer, loss=loss_function, metrics=metrics_list)
estimator = tf.keras.estimator.model_to_estimator(milcnn)
但是,当我开始训练时,它却惨败:
training_log = estimator.train(input_fn=lambda: training_dataset)
...
WARNING:tensorflow:The graph (<tensorflow.python.framework.ops.Graph object at 0x7fa5ebbdf6d0>) of the iterator is different from the graph (<tensorflow.python.framework.ops.Graph object at 0x7fa618050910>) the dataset: tf.Tensor(<unprintable>, shape=(), dtype=variant) was created in. If you are using the Estimator API, make sure that no part of the dataset returned by the `input_fn` function is defined outside the `input_fn` function. Please ensure that all datasets in the pipeline are created in the same graph as the iterator. NOTE: This warning will become an error in future versions of TensorFlow.
...
Traceback (most recent call last):
File "/opt/anaconda3/envs/direx/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 2104, in make_initializable_iterator
return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access
AttributeError: 'BatchDataset' object has no attribute '_make_initializable_iterator'
During handling of the above exception, another exception occurred:
...
RuntimeError: Attempting to capture an EagerTensor without building a function.
我在这个阶段很困惑。当我直接将其用作Keras模型时,我的数据集可以与模型完美配合。因此,我希望它在Estimator接口中也有效。问题确实是由于滥用估计器input_fn造成的,还是由于我构建估计器或Keras模型的方式造成的?
答案 0 :(得分:1)
我发现了问题。在训练循环之前,我正在初始化我的数据集:
dataset = input_fn(params)
estimator.train(input_fn=lambda: training_dataset)
实际上,您必须直接将input_fn作为参数传递:
estimator.train(input_fn=lambda: input_fn(params))