具有共享网络的Tensorflow估计器

时间:2017-08-25 04:39:40

标签: python tensorflow

我正在使用新的estimator高级api构建张量流模型。我的模型如下图所示

this

事实上,由于该模型用于模拟游戏操作,因此模型比这更复杂。分类负责决定是否是采取行动的好时机。然后回归将给出有关操作的详细信息。它包含CNN和RNN的组合。

然而,由于复杂性和内存消耗,不可能同时训练和运行分类和回归作为两个网络。另外,当我创建我的估算器时:

# Create the Estimator
mnist_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")

我只能为估算器提供一个模型函数。是否可以一起训练和运行两个估算器?

1 个答案:

答案 0 :(得分:0)

将损失函数更改为回归和分类损失的线性组合。它将是一个估算器,只有一个损失,但有多个推论。