使用tf.eager训练复杂的nn模型(更好地使用TF2符号支持)

时间:2019-08-19 23:27:00

标签: python tensorflow tensorflow2.0 eager

是否存在(或多或少)简单的方法来编写复杂的NN模型,以便可以在急切模式下进行训练?有这样的代码示例吗?

例如,我要使用InceptionResnetV2。我有使用tf.contrib.slim创建的代码。根据此链接https://github.com/tensorflow/tensorflow/issues/16182,slim已弃用,我需要使用Keras。而且我真的不能使用苗条的代码进行急切的训练,因为我无法获取变量列表并应用渐变(好的,我可以尝试将模型包装到GradientTape中,但不确定如何处理正则化损失)。

好的,让我们尝试Keras

In [30]: tf.__version__                                                                                                                                                                          
Out[30]: '1.13.1'

In [31]: tf.enable_eager_execution()

In [32]: from keras.applications.inception_resnet_v2 import InceptionResNetV2

In [33]: model = InceptionResNetV2(weights=None)
...
/usr/local/lib/python3.6/dist-packages/keras_applications/inception_resnet_v2.py in InceptionResNetV2(include_top, weights, input_tensor, input_shape, pooling, classes, **kwargs)
    246 
    247     if input_tensor is None:
--> 248         img_input = layers.Input(shape=input_shape)
    249     else:
    250         if not backend.is_keras_tensor(input_tensor):
...
RuntimeError: tf.placeholder() is not compatible with eager execution.

默认情况下不起作用。

在本教程中,他们说我需要建立自己的模型类并自己https://www.tensorflow.org/tutorials/eager/custom_training#define_the_model维护变量。我不确定是否要为Inception做。创建和维护的变量太多。就像在甚至没有苗条的日子里,回到旧版本的TF一样。

在本教程中,网络是使用Keras https://www.tensorflow.org/tutorials/eager/custom_training_walkthrough#create_a_model_using_keras创建的,但是我怀疑通过仅定义模型而不将其与Input一起使用,我是否可以以这种方式轻松维护复杂的结构。例如,在本文中,如果我理解正确,请作者初始化keras Input并将其传播通过模型(如前所述,与Eager一起使用时会导致RuntimeError)。我可以通过将模型类子类化为https://www.tensorflow.org/api_docs/python/tf/keras/Model来建立自己的模型。糟糕,以这种方式,我需要维护图层,而不是变量。在我看来,这几乎是同一个问题。

这里https://www.tensorflow.org/beta/guide/autograph#keras_and_autograph有趣地提到了AutoGrad。它们仅覆盖__call__,因此在这种情况下似乎不需要维护变量,但是我尚未对其进行测试。


那么,有什么简单的解决方案吗?

将纤细模型包裹在GradientTape中吗?然后我该如何对重量进行减重?

我自己跟踪每个变量?听起来有点痛苦。

使用Keras吗?当我在模型中有分支和复杂结构时,如何急切地使用它?

1 个答案:

答案 0 :(得分:2)

您的第一种方法可能是最常见的。错误:

  

RuntimeError:tf.placeholder()与急切执行不兼容。

是因为不能在急切模式下使用tf.placeholder。急于执行时,没有这样的事情的概念。

您可以使用tf.data API为训练数据构建数据集,并将其输入模型。将数据集替换为您的真实数据后会发生以下情况:

import tensorflow as tf
tf.enable_eager_execution()

model = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(weights=None)

model.compile(tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

### Replace with tf.data.Datasets for your actual training data!
train_x = tf.data.Dataset.from_tensor_slices(tf.random.normal((10,299,299,3)))
train_y = tf.data.Dataset.from_tensor_slices(tf.random.uniform((10,), maxval=10, dtype=tf.int32))
training_data = tf.data.Dataset.zip((train_x, train_y)).batch(BATCH_SIZE)

model.fit(training_data)

正如您的标题中所述,此方法也可以在TensorFlow 2.0中使用。