我无法实现使用Tensorflow的Estimator API进行摘要。
Estimator类非常有用,原因很多:我已经实现了自己的类,它们非常相似,但我试图切换到这个类。
以下是代码示例:
import tensorflow as tf
import tensorflow.contrib.layers as layers
import tensorflow.contrib.learn as learn
import numpy as np
# To reproduce the error: docker run --rm -w /algo -v $(pwd):/algo tensorflow/tensorflow bash -c "python sample.py"
def model_fn(x, y, mode):
logits = layers.fully_connected(x, 12, scope="dense-1")
logits = layers.fully_connected(logits, 56, scope="dense-2")
logits = layers.fully_connected(logits, 4, scope="dense-3")
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y), name="xentropy")
return {"predictions":logits}, loss, tf.train.AdamOptimizer(0.001).minimize(loss)
def input_fun():
""" To be completed for a 4 classes classification problem """
feature = tf.constant(np.random.rand(100,10))
labels = tf.constant(np.random.random_integers(0,3, size=(100,)))
return feature, labels
estimator = learn.Estimator(model_fn=model_fn, )
trainingConfig = tf.contrib.learn.RunConfig(save_checkpoints_secs=60)
estimator = learn.Estimator(model_fn=model_fn, model_dir="./tmp", config=trainingConfig)
# Works
estimator.fit(input_fn=input_fun, steps=2)
# The following code does not work
# Can't initialize saver
# saver = tf.train.Saver(max_to_keep=10) # Error: No variables to save
# The following fails because I am missing a saver... :(
hooks=[
tf.train.LoggingTensorHook(["xentropy"], every_n_iter=100),
tf.train.CheckpointSaverHook("./tmp", save_steps=1000, checkpoint_basename='model.ckpt'),
tf.train.StepCounterHook(every_n_steps=100, output_dir="./tmp"),
tf.train.SummarySaverHook(save_steps=100, output_dir="./tmp"),
]
estimator.fit(input_fn=input_fun, steps=2, monitors=hooks)
正如您所看到的,我可以创建一个Estimator并使用它,但我可以实现为拟合过程添加钩子。
日志记录挂钩工作得很好但其他人需要张量和保护程序,我无法提供。
张量在模型函数中定义,因此我无法将它们传递给 SummaryHook ,并且 Saver 无法初始化,因为没有张量到保存...
我的问题有解决方案吗? (我猜是的,但在tensorflow文档中缺少此部分的文档)
提前致谢。
PS:我已经看过DNNClassifier API,但我想使用Convolutional Nets和其他人的估算器API。我需要为任何估算器创建摘要。
答案 0 :(得分:10)
预期用例是您让Estimator为您保存摘要。 RunConfig中有用于配置摘要写入的选项。在constructing the Estimator时传递RunConfigs。
答案 1 :(得分:0)
在tf.summary.scalar("loss", loss)
中只有model_fn
,然后运行没有summary_hook
的代码。损耗被记录并显示在张量板上。
另请参阅: