Keras渴望执行

时间:2018-04-22 16:17:16

标签: python tensorflow keras

Keras的TensorFlow后端是否依赖于急切的执行?

如果不是这样的话:

我可以基于Keras和Tensorflow操作构建TensorFlow图,然后使用Keras高级API训练整个模型吗?

1 个答案:

答案 0 :(得分:6)

  

这是出于研究目的,我不能在这里介绍。

这使得回答你的问题变得非常困难。如果你能找到一个与你的研究无关的玩具例子 - 你想要什么,我们会尝试从那里建造一些东西会更好。

  

Keras的TensorFlow后端是否依赖于急切的执行?

不,它没有。 Keras是在急切的执行介绍之前建立的。然而,Keras(tf中的那个)可以在急切的执行模式下工作(参见fchollet的answer)。

  

我可以构建TensorFlow图并将其与Keras模型相结合,然后使用Keras高级API联合训练吗?

我不确定“构建TensorFlow图”是什么意思,因为无论何时使用keras,图形都已存在。如果您正在谈论在现有图表中添加一堆操作,那么它肯定是可能的。你只需要用Lambda图层来包装它,就像你在符号模式下使用Keras一样:

import tensorflow as tf
from sacred import Experiment

ex = Experiment('test-18')

tf.enable_eager_execution()


@ex.config
def my_config():
    pass


@ex.automain
def main():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    x_train, x_test = (e.reshape(e.shape[0], -1) for e in (x_train, x_test))
    y_train, y_test = (tf.keras.utils.to_categorical(e) for e in (y_train, y_test))

    def complex_tf_fn(x):
        u, v = tf.nn.moments(x, axes=[1], keep_dims=True)
        return (x - u) / tf.sqrt(v)

    with tf.device('/cpu:0'):
        model = tf.keras.Sequential([
            tf.keras.layers.Lambda(complex_tf_fn, input_shape=[784]),
            tf.keras.layers.Dense(1024, activation='relu'),
            tf.keras.layers.Lambda(complex_tf_fn),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(optimizer=tf.train.AdamOptimizer(),
                      loss='categorical_crossentropy')

        model.fit(x_train, y_train,
                  epochs=10,
                  validation_data=(x_test, y_test),
                  batch_size=1024,
                  verbose=2)
python test-18.py with seed=21

INFO - test-18 - Running command 'main'
INFO - test-18 - Started
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
 - 9s - loss: 3.4012 - val_loss: 1.3575
Epoch 2/10
 - 9s - loss: 0.9870 - val_loss: 0.7270
Epoch 3/10
 - 9s - loss: 0.6097 - val_loss: 0.6071
Epoch 4/10
 - 9s - loss: 0.4459 - val_loss: 0.4824
Epoch 5/10
 - 9s - loss: 0.3352 - val_loss: 0.4436
Epoch 6/10
 - 9s - loss: 0.2661 - val_loss: 0.3997
Epoch 7/10
 - 9s - loss: 0.2205 - val_loss: 0.4048
Epoch 8/10
 - 9s - loss: 0.1877 - val_loss: 0.3788
Epoch 9/10
 - 9s - loss: 0.1511 - val_loss: 0.3506
Epoch 10/10
 - 9s - loss: 0.1304 - val_loss: 0.3330
INFO - test-18 - Completed after 0:01:31

Process finished with exit code 0