Tensorflow 2:在训练期间获取张量值

时间:2019-12-12 17:40:16

标签: tensorflow keras deep-learning tensorflow-datasets tensorflow2.0

我想在训练期间获得张量的值,以计算表示之间的相互信息;输入和输出。

下面的代码中的get_mutual_information(get_tensors())应该获得张量的值来执行计算。

我编写了另一个函数get_tensors()以使用tf.compat.v1.get_default_graph().get_tensor_by_name()获取值,但是此方法总是重新运行空列表。

您知道在每个时期后获取张量的值吗?谢谢,马哈茂德

这是代码:

def get_tensors():
    tensors = []
    names = []
#     for tensor in tf.compat.v1.get_default_graph().as_graph_def().node: 
#         names.append(tensor.name)
    for op in tf.compat.v1.get_default_graph().get_operations():
        names.append(op.name)
    names = [layer.name for layer in model.layers]
    for name in names:
        tensors.append(tf.compat.v1.get_default_graph().get_tensor_by_name("%s:0" % name))
    return tensors

class Callback_1(tf.keras.callbacks.Callback): 
    def on_train_begin(self, logs={}):
        self.mi_xt_all = []
        self.mi_ty_all = []
        self.epochs = []
    def on_epoch_begin(self, epoch, logs={}):
        print(get_tensors())
        mi_xt, mi_ty = get_mutual_information(get_tensors())
        self.mi_xt_all.append(mi_xt)
        self.mi_ty_all.append(mi_ty)
        self.epochs.append(epoch)
class Callback_2(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if(logs.get('acc')>0.98):
            self.model.stop_training = True
            print("\nReached 99.8% accuracy so cancelling training!")
callbacks_list = [Callback_1(), Callback_2()]

def train_with_mi():
    model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=100, callbacks=callbacks_list)

0 个答案:

没有答案