使用Dataset API和Keras编写summary.scalar

时间:2019-02-28 11:28:29

标签: python tensorflow keras tensorboard

我使用tensorflow Keras API并尝试将自定义标量添加到tensorboard,但除了损失之外,什么都没有显示。

这是模型的代码:

embedding_in = Embedding(
    input_dim=vocab_size + 1 + 1,  
    output_dim=dim,
    mask_zero=True,
)

embedding_out = Embedding(
    input_dim=vocab_size + 1 + 1,  
    output_dim=dim,
    mask_zero=True,
)

input_a = Input((None,))
input_b = Input((None,))
input_c = Input((None, None))

emb_target = embedding_in(input_a)
emb_context = embedding_out(input_b)
emb_negatives = embedding_out(input_c)

emb_gru = GRU(dim, return_sequences=True)(emb_target)

num_negatives = tf.shape(input_c)[-1]


def make_logits(tensors):
    emb_gru, emb_context, emb_negatives = tensors
    true_logits = tf.reduce_sum(tf.multiply(emb_gru, emb_context), axis=2)
    true_logits = tf.expand_dims(true_logits, -1)
    sampled_logits = tf.squeeze(
        tf.matmul(emb_negatives, tf.expand_dims(emb_gru, axis=2),
                  transpose_b=True), axis=3)
    true_logits = true_logits*0
    sampled_logits = sampled_logits*0

    logits = K.concatenate([true_logits, sampled_logits], axis=-1)
    return logits


logits = Lambda(make_logits)([emb_gru, emb_context, emb_negatives])

mean = tf.reduce_mean(logits)
tf.summary.scalar('mean_logits', mean)

model = keras.models.Model(inputs=[input_a, input_b, input_c], outputs=[logits])

尤其是,我想查看每个批次之后mean_logits标量的演变。

我这样创建和编译模型:

model = build_model(dim, vocab_size)
model.compile(loss='binary_crossentropy', optimizer='sgd')
callbacks = [
        keras.callbacks.TensorBoard(logdir, histogram_freq=1)
]

我将tf Dataset API用于模型:

iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:

        sess.run(iterator.initializer)
        sess.run(tf.tables_initializer())
        model.fit(iterator, steps_per_epoch=100, 
                  callbacks=callbacks,
                  validation_data=iterator,
                  validation_steps=1
                 )

但是,我在张量板上没有任何mean_logits图,也没有在图中。 enter image description here

如何在每批处理后在张量板上跟踪mean_logits标量?

我使用tf 1.12和keras 2.1。

1 个答案:

答案 0 :(得分:1)

我也面临同样的问题。看来Keras TensorBoard回调不会自动编写所有现有的摘要,而只会自动编写registered as metrics(并出现在logs字典中)。更新logs对象是一个不错的技巧,因为它允许使用其他回调中的值,请参见Early stopping and learning rate schedule based on custom metric in Keras。我可以看到几种可能性:

1。使用Lambda回调

类似这样的东西:

eval_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: logs.update(
        {'mean_logits': K.eval(mean)}
    ))

2。自定义TensorBoard回调

您还可以将回调子类化,并定义自己的逻辑。例如,我的学习率监控方法:

class Tensorboard(Callback):                                                                                                                                                                                                                                          
    def __init__(self,                                                                                                                                                                                                                                                
                 log_dir='./log',                                                                                                                                                                                                                                     
                 write_graph=True):                                                                                                                                                                                                                                   
        self.write_graph = write_graph                                                                                                                                                                                                                                
        self.log_dir = log_dir                                                                                                                                                                                                                                        

    def set_model(self, model):                                                                                                                                                                                                                                       
        self.model = model                                                                                                                                                                                                                                            
        self.sess = K.get_session()                                                                                                                                                                                                                                   
        if self.write_graph:                                                                                                                                                                                                                                          
            self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)                                                                                                                                                                                        
        else:                                                                                                                                                                                                                                                         
            self.writer = tf.summary.FileWriter(self.log_dir)                                                                                                                                                                                                         

    def on_epoch_end(self, epoch, logs={}):                                                                                                                                                                                                                           
        logs.update({'learning_rate': float(K.get_value(self.model.optimizer.lr))})                                                                                                                                                                                   
        self._write_logs(logs, epoch)                                                                                                                                                                                                                                 

    def _write_logs(self, logs, index):                                                                                                                                                                                                                               
        for name, value in logs.items():                                                                                                                                                                                                                              
            if name in ['batch', 'size']:                                                                                                                                                                                                                             
                continue                                                                                                                                                                                                                                              
            summary = tf.Summary()                                                                                                                                                                                                                                    
            summary_value = summary.value.add()                                                                                                                                                                                                                       
            if isinstance(value, np.ndarray):                                                                                                                                                                                                                         
                summary_value.simple_value = value.item()                                                                                                                                                                                                             
            else:                                                                                                                                                                                                                                                     
                summary_value.simple_value = value                                                                                                                                                                                                                    
            summary_value.tag = name                                                                                                                                                                                                                                  
            self.writer.add_summary(summary, index)                                                                                                                                                                                                                   

        self.writer.flush()                                                                                                                                                                                                                                           

    def on_train_end(self, _):                                                                                                                                                                                                                                        
        self.writer.close() 

在这里,我只是将'learning_rate'明确添加到logs中。但是这种方式可以更加灵活和强大。

3。指标技巧

Here是另一个有趣的解决方法。您需要做的是将自定义指标函数传递给模型的compile()调用,该调用返回汇总的汇总张量。这样做的目的是让Keras将汇总的汇总操作传递给每个session.run调用,并将其结果作为指标返回:

x_entropy_t = K.sum(p_t * K.log(K.epsilon() + p_t), axis=-1, keepdims=True)
full_policy_loss_t = -res_t + X_ENTROPY_BETA * x_entropy_t
tf.summary.scalar("loss_entropy", K.sum(x_entropy_t))
tf.summary.scalar("loss_policy", K.sum(-res_t))
tf.summary.scalar("loss_full", K.sum(full_policy_loss_t))

summary_writer = tf.summary.FileWriter("logs/" + args.name)

def summary(y_true, y_pred):
    return tf.summary.merge_all()

value_policy_model.compile(optimizer=Adagrad(), loss=loss_dict, metrics=[summary])
l = value_policy_model.train_on_batch(x_batch, y_batch)
l_dict = dict(zip(value_policy_model.metrics_names, l))

summary_writer.add_summary(l_dict['value_summary'], global_step=iter_idx)
summary_writer.flush()