在下面的示例中,我尝试通过手动定义更新步骤(有时在RL模型中需要)来替换.compile和.fit函数。不管我尝试了什么,都遇到错误,指出在设置“ updates_op” 的行上未定义渐变。
某些策略梯度算法的keras实现(例如A3C或DDQN)会跳过通常与keras一起使用的.compile
和.fit
调用,而是使用具有手动损失定义的keras.backend.function
。这样可以将模型输出用作所需变量的变量。
通过这种方式编写MNIST分类器的默认示例,我试图更好地理解这种方法。可悲的是,它在行“ updates_op =” 失败,指出并非所有操作都是可区分的(未定义渐变)。
我很确定这是我的概念上的误解,但我找不到任何解决方案或很好的解释。我将如何帮助培训这个简单的mnist示例,将不胜感激。
我尝试了使用占位符对象和输入对象以及其他方法,但是始终得到相同的结果。
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
#get data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
input_node = tf.keras.Input(shape=(28,28),name = 'input_node')
flat = tf.keras.layers.Flatten(name= 'faltten1')(input_node)
dense1 = tf.keras.layers.Dense(128,activation='relu', name = 'dense1')(flat)
dropout = tf.keras.layers.Dropout(0.2, name = 'dropout1')(dense1)
output_node = tf.keras.layers.Dense(10,activation='softmax',name='out')(dropout)
model = tf.keras.Model(inputs=input_node, outputs=output_node)
model.summary()
#define loss function manually
y_true_obj = tf.keras.Input(shape=(1,),name='y_true_placeholder')
y_pred_obj = tf.keras.Input(shape=(1,),name='y_pred_placeholder')
#also tried: tf.keras.backend.placeholder instead of input
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true_obj,y_pred_obj)
#also tried: loss = tf.keras.backend.mean(tf.keras.backend.square(y_true-y_pred))
#define an update function which we get from an optimizer object
optimizer = tf.keras.optimizers.Adam()
updates_op = optimizer.get_updates(params=model.trainable_weights, loss=loss)
#the next function replaces the model.fit
train_fn = tf.keras.backend.function(inputs=[model.input, y_true_obj,y_pred_obj], outputs=[], updates=updates_op)
现在所有内容都已定义。但是由于不存在model.fit,我们需要手动循环并调用train。
for x_train_this,y_train_this in zip(x_train,y_train):
y_pred = model.predict(x_train_this)
train_fn([x_train_this,y_train_this,y_pred])
预计将运行一个训练纪元,但无法解释为updates_op = ...,并带有:
ValueError:操作具有
None
用于渐变。请确保您所有的操作都定义了渐变(即可区分)。没有渐变的常见操作:K.argmax,K.round,K.eval。