我正在使用张量流集线器,张量流估计器和张量流数据建立模型分类。
我的训练函数正在返回数据集,model_fn
的定义如下:
def train_input_fn():
return dataset_input_fn(DATASET_TRAIN_PATH)
def model_fn(features, labels, mode, params):
logging.info("model_fn")
# module is imported from tf-hub
return head.create_estimator_spec (features, mode, ...)
非常类似于code by Damien。
代码环境为:Python 2,Google云数据实验室,tf.version
为1.12。
引发的错误是model_fn
不需要标签参数(可能由tf-data
数据集生成)。假设model_fn
返回数据集,input_fn
的签名应该是什么?
请提出任何建议。
非常感谢,
埃拉兰语