我的tensorflow版本是1.8。
简单地说,当我尝试使用tf.contrib.learn.KMeansClustering.export_savedmodel()方法时,出现此错误。
doc表明此方法实际上有一个参数graph_rewrite_specs。
我的代码基本上是:
kmeans = tf.contrib.factorization.KMeansClustering(num_clusters=num_clusters,
model_dir='model',
use_mini_batch=True,
mini_batch_steps_per_iteration=1)
# training code here...
kmeans.export_savedmodel('saved',
serving_input_receiver_fn,
as_text=False,
graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,tag_constants.TRAINING), ()),))
我的代码有问题吗?
答案 0 :(得分:0)
您正在检查错误的文档,这些文档是DnnLinearRegressor,而您正在使用具有方法签名的KmeanClustering 的
export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
strip_default_attrs=False)
DnnLinearRegressor和KmeanClustering继承自Estimator,但Dnn覆盖了export_savedmodel,而KmeanClustering保留了从Estimator继承的原始函数签名