加载/应用TFX波束变换图时,TFX变换等级不匹配

时间:2018-10-09 21:44:49

标签: tensorflow apache-beam tensorflow-transform

我已经成功地将TFTransformOutput拟合到一些数据(在这种情况下,是UC和TFX示例中常见的UCI人口普查数据集。)我尝试使用带有transform_raw_features(raw_features)方法的转换器。错误:

  

ValueError:节点“ transform / transform / inputs / workclass_copy”具有一个   _output_shapes属性与输出#0的GraphDef不一致:形状必须等于等级,但必须为0和1

在查看源代码时,似乎错误是由于执行此操作时_partially_apply_saved_transform_impl方法中的saved_transform_io引起的。

saver = tf_saver.import_meta_graph(meta_graph_def, import_scope=import_scope,
input_map=input_map)

我检查了由TFX TFTransform和Beam生成的meta_graph_def,并注意到该图确实具有一系列具有输入/输出等级差异的复制变量。但是,这是我无法控制的。

错误消息中的列是“ workclass”,它是一个简单的分类列。我可能做错了什么?调试此问题的最佳方法是什么?至此,我已经对TF源代码进行了深入研究,但错误似乎源于TFTransform图的编写方式,不确定是否需要更改/修复这些杠杆。

enter image description here

这是使用TF Transform v0.9和相应的TF v1.9

  

回溯(最近通话最近):文件   “ /home/sahmed/workspace/ml_playground/TFX-TFT/trainers.py”,第449行,   在parse_csv中       transform_stuff = xformer.transform_raw_features(raw_features)文件   “ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow_transform/output_wrapper.py”,   第122行,位于transform_raw_features中       self.transform_savedmodel_dir,raw_features))文件“ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow_transform/saved/saved_transform_io.py”,   第360行,在partial_apply_saved_transform_internal中       saved_model_dir,逻辑输入图,张量替换图)文件“ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow_transform/saved/saved_saved_transform_io.py”,   第218行,在_partially_apply_saved_transform_impl中       input_map = input_map)文件“ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow/python/training/saver.py”,   1960行,在import_meta_graph中       ** kwargs)文件“ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py”,   import_scoped_meta_graph中的第744行       producer_op_list = producer_op_list)文件“ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py”,   第432行,位于new_func中       返回func(* args,** kwargs)文件“ /home/sahmed/miniconda3/envs/kml2/lib/python2.7/site-packages/tensorflow/python/framework/importer.py”,   import_graph_def中的第422行       引发ValueError(str(e))ValueError:节点'transform / transform / inputs / workclass_copy'具有_output_shapes   属性与输出#0的GraphDef不一致:形状必须为   等级相等,但为0和1

1 个答案:

答案 0 :(得分:1)

问题可能是工作类张量的形状与transform_raw_features期望的不兼容。

TFTransformOutput.transform_raw_features()希望这些功能具有与给tft.AnalyzeDataset()的元数据中描述的特征相同的特征,类似于在本示例中的实现方式: https://github.com/tensorflow/transform/blob/master/examples/simple_example.py#L63

您能否看一下管道中使用的元数据,并确定它与TFTransformOutput.transform_raw_features()中提供的数据兼容?