概念示例:使用backend.function与keras结合使用的MNIST –尝试定义更新的梯度错误

时间:2019-04-15 21:47:34

标签: python tensorflow keras

在下面的示例中,我尝试通过手动定义更新步骤(有时在RL模型中需要)来替换.compile和.fit函数。不管我尝试了什么,都遇到错误,指出在设置“ updates_op” 的行上未定义渐变。

某些策略梯度算法的keras实现(例如A3C或DDQN)会跳过通常与keras一起使用的.compile.fit调用,而是使用具有手动损失定义的keras.backend.function。这样可以将模型输出用作所需变量的变量。

通过这种方式编写MNIST分类器的默认示例,我试图更好地理解这种方法。可悲的是,它在行“ updates_op =” 失败,指出并非所有操作都是可区分的(未定义渐变)。

我很确定这是我的概念上的误解,但我找不到任何解决方案或很好的解释。我将如何帮助培训这个简单的mnist示例,将不胜感激。

我尝试了使用占位符对象和输入对象以及其他方法,但是始终得到相同的结果。

导入并获取数据:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
#get data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

定义层,损失和训练功能:

input_node = tf.keras.Input(shape=(28,28),name = 'input_node')
flat = tf.keras.layers.Flatten(name= 'faltten1')(input_node)
dense1 = tf.keras.layers.Dense(128,activation='relu', name = 'dense1')(flat)
dropout = tf.keras.layers.Dropout(0.2, name = 'dropout1')(dense1)
output_node = tf.keras.layers.Dense(10,activation='softmax',name='out')(dropout)

model = tf.keras.Model(inputs=input_node, outputs=output_node)
model.summary()

#define loss function manually
y_true_obj = tf.keras.Input(shape=(1,),name='y_true_placeholder')
y_pred_obj = tf.keras.Input(shape=(1,),name='y_pred_placeholder')
#also tried: tf.keras.backend.placeholder instead of input
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true_obj,y_pred_obj)
#also tried: loss = tf.keras.backend.mean(tf.keras.backend.square(y_true-y_pred))

#define an update function which we get from an optimizer object
optimizer = tf.keras.optimizers.Adam()
updates_op = optimizer.get_updates(params=model.trainable_weights, loss=loss)

#the next function replaces the model.fit
train_fn = tf.keras.backend.function(inputs=[model.input, y_true_obj,y_pred_obj], outputs=[], updates=updates_op)

现在所有内容都已定义。但是由于不存在model.fit,我们需要手动循环并调用train。

for x_train_this,y_train_this in zip(x_train,y_train):
    y_pred = model.predict(x_train_this)
    train_fn([x_train_this,y_train_this,y_pred])

预计将运行一个训练纪元,但无法解释为updates_op = ...,并带有:

  

ValueError:操作具有None用于渐变。请确保您所有的操作都定义了渐变(即可区分)。没有渐变的常见操作:K.argmax,K.round,K.eval。

0 个答案:

没有答案