我正在实施自定义keras回调,并且我正在同一个模型上执行两个连续的训练阶段。
在回调中,我创建了几个占位符,以便在培训结束时提供一些度量值以进行评估。对于第一个训练阶段,这很好,因为占位符不存在,但是,在第二个训练阶段,这将导致错误,因为tensorflow将创建第二组占位符但具有索引名称。
因此,我正在寻找一种解决方案,可以在第一个培训阶段将值提供给占位符(可能类似于按名称查找占位符,然后将值提供给它)或按名称删除某些占位符,以便我可以创建新的
编辑:
澄清我目前的情况。我实现了这个自定义Keras Callback(我将保留指标的计算):
class Metric(keras.callbacks.Callback):
def __init__(self):
self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, name="prec")
tf.summary.scalar("val_precision", self.val_prec_ph)
self.merged = tf.summary.merge_all()
self.writer = tf.summary.FileWriter(self.log_dir)
def on_train_begin(self, logs={}):
self.precision = []
def on_train_end(self, logs={}):
//do some calculations
self.precision.append(calculation)
summary = self.session.run(self.merged,
feed_dict={self.val_prec_ph: self.precision[-1]})
self.writer.add_summary(summary)
self.writer.flush()
这基本上是我做占位符的框架。由于连续运行tensorflow将执行以下操作: 第一次培训将毫无问题地运行,并将占位符命名为“prec”。但是,在第二次运行中,tensorflow会将self.val_prec_ph占位符命名为“prec_”,这将导致“prec”占位符未被提供的错误,尽管它仍然存在。
因此,我要么直接写入“prec”占位符,要么在第一次运行后删除它,以便我没有重复项。
为什么我在结束训练过程结束时这样做yadda,yadda ...是另一个有另一个问题的故事。
答案 0 :(得分:0)
以下是针对您的具体问题的可能解决方案,按名称搜索图表中的占位符(使用tf.Graph().get_tensor_by_name()
),如果找不到,则创建它:
class Metric(keras.callbacks.Callback):
def __init__(self, ph_name="prec"):
try:
self.val_prec_ph = tf.get_default_graph().get_tensor_by_name(
ph_name + ':0')
# Check this solution by @rvinas to cover possible suffix/scope errors:
# https://stackoverflow.com/a/38935343/624547
except KeyError:
self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64,
name=ph_name)
tf.summary.scalar("val_precision", self.val_prec_ph)
self.merged = tf.summary.merge_all()
self.writer = tf.summary.FileWriter(self.log_dir)
# ...