我使用以下代码(tensorflow == 1.14)构建模型:
class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.embedding = tf.keras.layers.Embedding(10, 5)
self.rnn = tf.keras.layers.GRU(100) # neither GRU nor LSTM works
self.final_layer = tf.keras.layers.Dense(10)
self.loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def call(self, inp):
inp_em = self.embedding(inp) # (batch_size, seq_len, embedding_size)
inp_enc = self.rnn(inp_em) # (batch_size, hidden_size)
logits = self.final_layer(inp_enc) # (batch_size, class_num)
return logits
model = Model()
inp = np.random.randint(0, 10, [5, 50], dtype=np.int32)
out = np.random.randint(0, 10, [5], dtype=np.int32)
with tf.GradientTape() as tape:
logits = model(inp)
loss = model.loss_obj(out, logits)
print(loss)
gradients = tape.gradient(tf.reduce_mean(loss), model.trainable_variables)
print('========== Trainable Variables ==========')
for v in model.trainable_variables:
print(v)
print('========== Gradients ==========')
for g in gradients:
print(g)
但是当我打印网格物体时,输出为:
Tensor("categorical_crossentropy/weighted_loss/Mul:0", shape=(5,), dtype=float32)
========== Trainable Variables ==========
<tf.Variable 'model/embedding/embeddings:0' shape=(10, 5) dtype=float32>
<tf.Variable 'model/gru/kernel:0' shape=(5, 300) dtype=float32>
<tf.Variable 'model/gru/recurrent_kernel:0' shape=(100, 300) dtype=float32>
<tf.Variable 'model/gru/bias:0' shape=(300,) dtype=float32>
<tf.Variable 'model/dense/kernel:0' shape=(100, 10) dtype=float32>
<tf.Variable 'model/dense/bias:0' shape=(10,) dtype=float32>
========== Gradients ==========
None
None
None
None
Tensor("MatMul:0", shape=(100, 10), dtype=float32)
Tensor("BiasAddGrad:0", shape=(10,), dtype=float32)
用于最后一层的网格物体工作良好,但是对于GRU层则没有,等等。
我已经尝试过tf.keras.layers.LSTM
和tf.keras.layers.GRU
,但存在相同的问题。
最后,我将tf.GradientTape().gradient()
替换为tf.graidents()
:
logits = model(inp)
loss = model.loss_obj(out, logits)
gradients = tf.gradients(tf.reduce_mean(loss), model.trainable_variables)
渐变起作用。但是我仍然不知道这两种工具有什么区别。