Tensorflow:将值输入现有占位符或删除它

时间:2018-06-03 11:26:58

标签: tensorflow

我正在实施自定义keras回调,并且我正在同一个模型上执行两个连续的训练阶段。

在回调中,我创建了几个占位符,以便在培训结束时提供一些度量值以进行评估。对于第一个训练阶段,这很好,因为占位符不存在,但是,在第二个训练阶段,这将导致错误,因为tensorflow将创建第二组占位符但具有索引名称。

因此,我正在寻找一种解决方案,可以在第一个培训阶段将值提供给占位符(可能类似于按名称查找占位符,然后将值提供给它)或按名称删除某些占位符,以便我可以创建新的

编辑:

澄清我目前的情况。我实现了这个自定义Keras Callback(我将保留指标的计算):

class Metric(keras.callbacks.Callback):


def __init__(self):
    self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, name="prec")
    tf.summary.scalar("val_precision", self.val_prec_ph)

    self.merged = tf.summary.merge_all()
    self.writer = tf.summary.FileWriter(self.log_dir)

def on_train_begin(self, logs={}):

    self.precision = []

def on_train_end(self, logs={}):

   //do some calculations

    self.precision.append(calculation)

    summary = self.session.run(self.merged,
                               feed_dict={self.val_prec_ph: self.precision[-1]})

    self.writer.add_summary(summary)
    self.writer.flush()

这基本上是我做占位符的框架。由于连续运行tensorflow将执行以下操作: 第一次培训将毫无问题地运行,并将占位符命名为“prec”。但是,在第二次运行中,tensorflow会将self.val_prec_ph占位符命名为“prec_”,这将导致“prec”占位符未被提供的错误,尽管它仍然存在。

因此,我要么直接写入“prec”占位符,要么在第一次运行后删除它,以便我没有重复项。

为什么我在结束训练过程结束时这样做yadda,yadda ...是另一个有另一个问题的故事。

1 个答案:

答案 0 :(得分:0)

以下是针对您的具体问题的可能解决方案,按名称搜索图表中的占位符(使用tf.Graph().get_tensor_by_name()),如果找不到,则创建它:

class Metric(keras.callbacks.Callback):

    def __init__(self, ph_name="prec"):
        try:
            self.val_prec_ph = tf.get_default_graph().get_tensor_by_name(
                ph_name + ':0')
            # Check this solution by @rvinas to cover possible suffix/scope errors:
            # https://stackoverflow.com/a/38935343/624547
        except KeyError:
            self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, 
                                              name=ph_name)

        tf.summary.scalar("val_precision", self.val_prec_ph)

        self.merged = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.log_dir)

    # ...