I'm using tensorflow-transform with the following code:
import tensorflow as tf
import apache_beam as beam
import tensorflow_transform as tft
from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.coders import example_proto_coder
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
import tempfile
import ast
import six
import apache_beam as beam
# Beam Pipelines must receive a set of config options to set how it should run.
from apache_beam.options.pipeline_options import PipelineOptions
assert six.PY2
options = {
    'runner': 'DirectRunner'
}
pipeline_options = PipelineOptions(**options)

RAW_DATA_SCHEMA = {
    'customer_id': dataset_schema.ColumnSchema(tf.string, [], dataset_schema.ListColumnRepresentation())
}
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(dataset_schema.Schema(RAW_DATA_SCHEMA))

def preprocess_fn(dictrow):
    return {
        'customer_id': tft.string_to_int(dictrow['customer_id'], vocab_filename='vocab_result')
    }
working_dir = tempfile.mkdtemp(dir='/tmp/')
with beam.Pipeline(options=pipeline_options) as pipeline:
    with beam_impl.Context(tempfile.mkdtemp()):
        raw_data = (
            pipeline
            | 'create' >> beam.Create([
                {'customer_id': ['customer_0']},
                {'customer_id': ['customer1', 'customer2']},
                {'customer_id': ['customer_0']}
            ])
        )
        raw_dataset = (raw_data, RAW_DATA_METADATA)
        transformed_dataset, transform_fn = (
            raw_dataset | beam_impl.AnalyzeAndTransformDataset(preprocess_fn))
        transformed_data, transformed_metadata = transformed_dataset
        OUTPUT_SCHEMA = {
            'customer_id': dataset_schema.ColumnSchema(tf.int64, [], dataset_schema.ListColumnRepresentation())
        }
        _ = transformed_data | 'writing' >> beam.io.tfrecordio.WriteToTFRecord(
            working_dir + '/tf',
            coder=example_proto_coder.ExampleProtoCoder(dataset_schema.Schema(OUTPUT_SCHEMA)))
        pipeline.run().wait_until_finish()
But this gives me the following error:
tftrec/local/lib/python2.7/site-packages/tensorflow_transform/beam/impl.pyc in process(self, batch, saved_model_dir)
    438       # This should remain constant throughout the lifetime of this DoFn,
    439       # regardless of whether self._graph_state was cached.
--> 440       assert self._graph_state.saved_model_dir == saved_model_dir
    441
    442       yield self._handle_batch(batch)

RuntimeError: AssertionError [while running 'AnalyzeAndTransformDataset/TransformDataset/Transform']
I'd like to know what might be causing this. I tried printing self._graph_state.saved_model_dir and it does indeed change between calls, but I'm not sure why that happens.
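For reference, this is roughly how I checked it: I temporarily added a print just above the failing assertion in the process() method shown in the traceback, in my local copy of tensorflow_transform/beam/impl.py. The print is my own debugging addition, not part of the library:

# Temporary debugging edit in my local tensorflow_transform/beam/impl.py,
# just above the assertion around line 440 in the 0.13.0 traceback.
# Only the print is mine; the assert is the library's existing code.
print('cached: %s  incoming: %s' % (self._graph_state.saved_model_dir, saved_model_dir))
assert self._graph_state.saved_model_dir == saved_model_dir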
This runs on Python 2 with tensorflow==1.13.1, tensorflow-transform==0.13.0 and apache_beam[gcp]==2.11.
Testing with tensorflow-transform 0.8 seems more stable, although it still throws the error occasionally. It also looks like removing the final step that writes the TFRecord files makes the code stable.
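Concretely, the variant that stays stable for me is the same pipeline with the WriteToTFRecord step simply omitted, roughly this sketch (same imports, schema and preprocess_fn as above):

# Same pipeline as above, minus the final TFRecord write.
with beam.Pipeline(options=pipeline_options) as pipeline:
    with beam_impl.Context(tempfile.mkdtemp()):
        raw_data = pipeline | 'create' >> beam.Create([
            {'customer_id': ['customer_0']},
            {'customer_id': ['customer1', 'customer2']},
            {'customer_id': ['customer_0']}
        ])
        raw_dataset = (raw_data, RAW_DATA_METADATA)
        transformed_dataset, transform_fn = (
            raw_dataset | beam_impl.AnalyzeAndTransformDataset(preprocess_fn))
        transformed_data, transformed_metadata = transformed_dataset
        pipeline.run().wait_until_finish()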