Tensorflow Estimator API:如何从输入函数传递参数

时间:2018-01-21 23:29:05

标签: tensorflow tensorflow-estimator

我试图将类权重添加为我的模型的超参数,但是为了计算权重,我需要读取输入数据,这发生在input_fn中,然后传递给estimator.fit()input_fn的输出只是要素,标签应具有相同的形状num_examples * num_features。我的问题 - 有没有办法将数据从input_fn传播到model_fn的超参数映射?或者作为替代方案 - 也许有一个input_fn数据集的包装器,它允许过度采样少数/下采样多数以及批处理 - 在这种情况下,我不需要任何参数来传播。

1 个答案:

答案 0 :(得分:1)

功能和标签都可以是张量词典(不仅仅是一个张量)。张量可以是你想要的任何形状,虽然它常见于num_examples * ...

如果您不使用任何预定义的估算器,最简单的方法是添加另一个功能,计算权重,计算模型中的权重然后使用它们(将损失乘以或将其作为参数)。

您还可以访问input_fn中的超级参数,这样您就可以计算其中的权重并将其作为单独的列添加。

如果您使用预设估算器,请查看文档。我看到他们中的大多数都支持weight_column_name。在这种情况下,只需给它在功能字典中为权重值使用的名称。

或者,如果所有其他方法都失败了,您可以在将数据提供给张量流之前以您希望的方式对数据进行采样。