Changing graph during TensorFlow Transform execution

Time: 2019-06-18 23:06:46

Tags: python tensorflow apache-beam tensorflow-transform

I am running the following code, which uses tensorflow-transform:

import tensorflow as tf
import apache_beam as beam
import tensorflow_transform as tft
from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.coders import example_proto_coder
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
import tempfile


import ast
import six
# Beam pipelines must receive a set of config options controlling how they run.
from apache_beam.options.pipeline_options import PipelineOptions

assert six.PY2

options = {
    'runner': 'DirectRunner'
}

pipeline_options = PipelineOptions(**options)

RAW_DATA_SCHEMA = {
    'customer_id': dataset_schema.ColumnSchema(tf.string, [], dataset_schema.ListColumnRepresentation())
}
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(dataset_schema.Schema(RAW_DATA_SCHEMA))


def preprocess_fn(dictrow):
    return {
        'customer_id': tft.string_to_int(dictrow['customer_id'], vocab_filename='vocab_result')
    }

working_dir = tempfile.mkdtemp(dir='/tmp/')

with beam.Pipeline(options=pipeline_options) as pipeline:
    with beam_impl.Context(tempfile.mkdtemp()):
        raw_data = (
            pipeline
            | 'create' >> beam.Create([
                {'customer_id': ['customer_0']},
                {'customer_id': ['customer1', 'customer2']},
                {'customer_id': ['customer_0']}
            ])
        )

        raw_dataset = (raw_data, RAW_DATA_METADATA)
        transformed_dataset, transform_fn = (
            raw_dataset | beam_impl.AnalyzeAndTransformDataset(preprocess_fn))

        transformed_data, transformed_metadata = transformed_dataset

        OUTPUT_SCHEMA = {
            'customer_id': dataset_schema.ColumnSchema(tf.int64, [], dataset_schema.ListColumnRepresentation())
        }

        _ = transformed_data | 'writing' >> beam.io.tfrecordio.WriteToTFRecord(
            working_dir + '/tf', coder=example_proto_coder.ExampleProtoCoder(dataset_schema.Schema(OUTPUT_SCHEMA)))

        pipeline.run().wait_until_finish()

But this gives me the following error:

tftrec/local/lib/python2.7/site-packages/tensorflow_transform/beam/impl.pyc in process(self, batch, saved_model_dir)
    438       # This should remain constant throughout the lifetime of this DoFn,
    439       # regardless of whether or not self._graph_state was cached.
--> 440       assert self._graph_state.saved_model_dir == saved_model_dir
    441
    442       yield self._handle_batch(batch)

RuntimeError: AssertionError [while running 'AnalyzeAndTransformDataset/TransformDataset/Transform']

I would like to know what the problem might be. I tried printing self._graph_state.saved_model_dir, and it does indeed change between invocations, but I am not sure why that happens.

Running on Python 2 with tensorflow==1.13.1, tensorflow-transform==0.13.0, and apache_beam[gcp]==2.11.

Testing against tensorflow-transform version 0.8 seems more stable, although it still throws the error occasionally.

Removing the final step that writes the TFRecord files appears to make the code stable.
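One plausible explanation (an assumption on my part, not confirmed by the traceback alone): the pipeline may be executing twice. `with beam.Pipeline(...) as pipeline:` runs the pipeline automatically when the block exits, so the explicit `pipeline.run().wait_until_finish()` inside the block would trigger a second execution, during which AnalyzeAndTransformDataset could materialize a different saved_model_dir and trip the assertion. The context-manager double-run behaviour can be sketched with a stand-in class (FakePipeline is a hypothetical mock, not a Beam API):

```python
class FakePipeline(object):
    """Minimal stand-in mimicking beam.Pipeline's context-manager behaviour:
    __exit__ runs the pipeline, so an explicit run() inside the block
    results in two executions."""

    def __init__(self):
        self.run_count = 0

    def run(self):
        # Count every execution of the pipeline.
        self.run_count += 1
        return self

    def wait_until_finish(self):
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # On a clean exit, the context manager runs the pipeline itself.
        if exc_type is None:
            self.run().wait_until_finish()
        return False


p = FakePipeline()
with p as pipeline:
    pipeline.run().wait_until_finish()  # explicit run inside the block

print(p.run_count)  # prints 2: once explicitly, once on block exit
```

If this is the cause, dropping the explicit `pipeline.run().wait_until_finish()` line and letting the `with` block run the pipeline once on exit should be enough.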

0 Answers:

No answers yet.