TensorFlow 1.10+: preprocess a TFRecordDataset with a subprocess?

Time: 2019-05-31 13:50:23

Tags: python python-3.x tensorflow tensorflow-estimator

Note: I now have a mini series of posts on TFRecords on Stack Overflow, with accompanying Colab notebooks. Relevant to this post:

So, suppose you have some delimited files (e.g. TSV) where each line is converted to one record. An example of such a file might look like this:

// my_file.tsv
...
banana  2  true  false  false  3
...

For the TFRecord, each field has been converted to the appropriate value type (int / float / byte string):

// my_file.tsv
...
b'banana'  2.0  1  0  0  3.0
...

Furthermore, the contents of a row (record) are the arguments to a command-line process:

# this is all a bad on-the-spot example
some-bash-command --fruit=banana --number=2 --ripe=true --for-smoothie=false --for-ice-cream=false --days-old=3

This command might turn this compact way of storing the data into what is actually needed as input.

A Python interface to this command exists, built on subprocess.Popen, e.g.:

import subprocess
import sys

# clean_command, POPEN_DEFAULTS and error_message are helpers not shown in this post.

def process(command: list, stdin: str, popen_options: dict = None):
    '''
    Arguments:
        command (list): a list of strings indicating the command and its
            arguments to spawn as a subprocess.

        stdin (str): passed as stdin to the subprocess. Assumed to be utf-8
            encoded.

        popen_options (dict): used to configure the subprocess.Popen command

    Returns:
        stdout, stderr
    '''
    command = clean_command(command)
    popen_config = POPEN_DEFAULTS.copy()
    # None means "no overrides"; this avoids a mutable default argument.
    popen_config.update(popen_options or {})
    try:
        proc = subprocess.Popen(args=command, **popen_config)
        stdout_data, stderr_data = proc.communicate(stdin)
    except OSError as err:
        error_message(command, err)
        sys.exit(1)
    if proc.returncode != 0:
        error_message(command, 'pid code {}'.format(proc.returncode), stdout_data, stderr_data)
    return stdout_data, stderr_data
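Since the helpers used above (clean_command, POPEN_DEFAULTS, error_message) are not shown, here is a stripped-down, self-contained stand-in for the same Popen/communicate pattern. The name `run_with_stdin` and the demo command are assumptions for illustration only; a Python one-liner plays the role of `some-bash-command`.

```python
import subprocess
import sys


def run_with_stdin(command, stdin_text):
    """Run `command`, feed it `stdin_text`, and return (stdout, stderr, returncode)."""
    proc = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # communicate() writes to stdin, closes it, and waits for the process to exit.
    stdout_data, stderr_data = proc.communicate(stdin_text.encode("utf-8"))
    return stdout_data.decode("utf-8"), stderr_data.decode("utf-8"), proc.returncode


# Stand-in for the real command: upper-case whatever arrives on stdin.
demo_command = [
    sys.executable,
    "-c",
    "import sys; sys.stdout.write(sys.stdin.read().upper())",
]
out, err, code = run_with_stdin(demo_command, "banana")
print(out, code)
```

This only works on plain Python strings, which is exactly the problem in the question: inside a tf.data map function the record is a tensor, not a string.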

where

command = [
    'some-bash-command',
    '--fruit=banana',
    '--number=2',
    '--ripe=true',
    # ...
]
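A list like this could be assembled from a decoded record along the following lines. This is only a sketch: `build_command` and the field names are hypothetical, and it assumes each field maps to a `--flag=value` argument with underscores turned into hyphens.

```python
def build_command(executable, record):
    """Turn a dict of field values into a list of --flag=value arguments."""
    args = [executable]
    for key, value in record.items():
        flag = key.replace("_", "-")
        # Check bool before anything else: bool is a subclass of int in Python.
        if isinstance(value, bool):
            value = "true" if value else "false"
        elif isinstance(value, bytes):
            value = value.decode("utf-8")
        args.append("--{}={}".format(flag, value))
    return args


record = {"fruit": b"banana", "number": 2, "ripe": True}
print(build_command("some-bash-command", record))
# ['some-bash-command', '--fruit=banana', '--number=2', '--ripe=true']
```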

How can this process be called inside the parse_fn used for reading TFRecords?

# import stuff, etc

def from_record(record):
    # see previous posts on recovering TFRecords
    return as_tf

# command as described above

def parse_fn(record):
    parsed = from_record(record)

    # I want this to work but it doesn't
    values_i_want = process(command, stdin, popen_options)

    return values_i_want, parsed['labels']


sess = tf.InteractiveSession()
DATASET_FILENAMES = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(DATASET_FILENAMES).map(lambda r:parse_fn(r)).repeat().batch(2)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

Any ideas?

0 answers:

No answers