我有一个用TF 1.0.1编写的seq2seq模型代码,正在尝试使用TensorFlow Keras重写它。
TF 1.0.1代码具有以下解码器架构:
with tf.variable_scope("decoder_scope") as decoder_scope:
    # Output projection. The variables are created by hand because sampled
    # softmax needs direct access to the projection matrix.
    output_projection_w_t = tf.get_variable(
        name="output_projection_w",
        shape=[vocabulary_size, state_size],
        dtype=DTYPE,
    )
    # Transposed view of the projection matrix, i.e. [state_size, vocabulary_size].
    output_projection_w = tf.transpose(output_projection_w_t)
    output_projection_b = tf.get_variable(
        name="output_projection_b",
        shape=[vocabulary_size],
        dtype=DTYPE,
    )

    # Decoder cell: LSTM -> dropout on the outputs -> stack of layers.
    lstm_cell = tf.contrib.rnn.LSTMCell(num_units=state_size)
    dropout_cell = DtypeDropoutWrapper(
        cell=lstm_cell,
        output_keep_prob=tf_keep_probabiltiy,
        dtype=DTYPE,
    )
    # NOTE(review): the very same wrapped cell object is repeated for every
    # layer; some later TF 1.x releases are believed to reject this pattern --
    # confirm against the TF version actually in use.
    decoder_cell = contrib_rnn.MultiRNNCell(
        cells=[dropout_cell] * num_lstm_layers,
        state_is_tuple=True,
    )

    # Training network: only the first of the three returned values (the
    # decoder outputs) is kept.
    decoder_outputs_tr, _, _ = dynamic_rnn_decoder(
        cell=decoder_cell,
        decoder_fn=simple_decoder_fn_train(last_encoder_state, name=None),
        inputs=decoder_inputs,
        sequence_length=decoder_sequence_lengths,
        parallel_iterations=None,
        swap_memory=False,
        time_major=False,
    )

    # Inference network: switch the scope to reuse mode so that anything
    # built after this point shares the variables created above.
    decoder_scope.reuse_variables()
以下是sampled_softmax_loss的计算方式:
# Flatten the decoder outputs into rows of width state_size and the targets
# into a single column, as required by the sampled softmax op.
decoder_forward_outputs = tf.reshape(
    tensor=decoder_outputs_tr,
    shape=[-1, state_size],
)
# decoder_labels is the target sequence of the decoder.
decoder_target_labels = tf.reshape(
    tensor=decoder_labels,
    shape=[-1, 1],
)

# One loss value per flattened row, computed against 500 sampled classes.
sampled_softmax_losses = tf.nn.sampled_softmax_loss(
    weights=output_projection_w_t,
    biases=output_projection_b,
    labels=decoder_target_labels,
    inputs=decoder_forward_outputs,
    num_sampled=500,
    num_classes=vocabulary_size,
    num_true=1,
)

# Scalar training objective: mean over all flattened rows.
total_loss_op = tf.reduce_mean(sampled_softmax_losses)
这是我在Keras中使用的解码器:
# Decoder input: a variable-length sequence per example (presumably token
# ids, since it is fed straight into an Embedding layer).
decoder_inputs = tf.keras.Input(
    shape=(None,),
    name='decoder_input',
)

# Token embedding with output width state_size.
emb_layer = tf.keras.layers.Embedding(
    input_dim=vocabulary_size,
    output_dim=state_size,
)
x_d = emb_layer(decoder_inputs)

# NOTE(review): the LSTM width is embed_dim while the embedding width is
# state_size, and the loss code reshapes d_lstm_out to [-1, state_size].
# Unless embed_dim == state_size these look swapped -- confirm.
d_lstm_layer = tf.keras.layers.LSTM(
    units=embed_dim,
    return_sequences=True,
)
# The decoder starts from the encoder's final states.
d_lstm_out = d_lstm_layer(x_d, initial_state=encoder_states)
这是我用于Keras模型的sampled_softmax_loss函数:
class SampledSoftmaxLoss(object):
    """Sampled-softmax loss bound to the final projection layer of a model.

    The loss is computed from the *input* of the model's last layer and that
    layer's own weights, so the full-vocabulary logits in ``y_pred`` are not
    used during training.

    Assumptions (confirm against the actual model):
      * ``model.layers[-1]`` is a Dense-style projection whose ``weights`` are
        ``[kernel, bias]`` with ``kernel`` shaped ``[hidden_dim, num_classes]``.
      * NOTE(review): referencing ``output_layer.input`` inside the loss
        presumably requires graph-mode Keras (symbolic tensors); verify for
        the TensorFlow version in use.

    Args:
        model: Keras model whose last layer is the output projection.
        num_sampled: Number of classes to sample per batch (default 500,
            the value that used to be hard-coded).
    """

    def __init__(self, model, num_sampled=500):
        self.model = model
        output_layer = model.layers[-1]
        # Tensor that feeds the projection layer and that layer's weights.
        self.input = output_layer.input
        self.weights = output_layer.weights
        self.num_sampled = num_sampled

    def loss(self, y_true, y_pred, **kwargs):
        """Return the sampled-softmax loss for every flattened target.

        Args:
            y_true: Target class ids; flattened to shape ``[N, 1]``.
            y_pred: Unused; accepted to satisfy the Keras loss signature.
            **kwargs: Ignored.

        Returns:
            The tensor produced by ``tf.nn.sampled_softmax_loss`` (one value
            per row of the flattened input).
        """
        kernel, bias = self.weights[0], self.weights[1]
        # Sizes are read from the kernel itself so they cannot drift from the
        # layer (int() accepts both TF1 Dimension objects and plain ints).
        hidden_dim = int(kernel.shape[0])
        num_classes = int(kernel.shape[1])

        # sampled_softmax_loss expects int64 labels of shape [N, num_true].
        labels = tf.cast(tf.reshape(y_true, [-1, 1]), tf.int64)
        # Use the projection layer's input stored in __init__ rather than a
        # module-level tensor.
        inputs = tf.reshape(self.input, [-1, hidden_dim])

        # The op wants weights as [num_classes, dim]; the kernel is assumed
        # to be [dim, num_classes], hence the transpose.
        return tf.nn.sampled_softmax_loss(
            weights=tf.transpose(kernel),
            biases=bias,
            labels=labels,
            inputs=inputs,
            num_sampled=self.num_sampled,
            num_classes=num_classes,
        )
但是,它不起作用。谁能帮助我在Keras中正确实现sampled_softmax_loss函数?