在Tensorflow

时间:2017-10-02 19:46:08

标签: python python-3.x tensorflow

我正在尝试使用tensorflow序列化可变长度的训练数据,但我无法重建它,因为我无法想出一种方法来传递每个训练实例的长度。

序列化数据

import tensorflow as tf
import numpy as np

data = [["foo", "bar", "baz"], ["the", "quick", "brown", "fox"]]

def _bytes_feature(val):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[val]))

def _int64_feature(val):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))

def serialize_data(input_data, path):
    """ iterate over and serialize data. """
    writer = tf.python_io.TFRecordWriter(path)
    datums = len(input_data)
    for i in range(datums):
        data_len = len(input_data[i])
        raw_data = np.array(input_data[i]).tostring()
        this_example = tf.train.Example(
            features = tf.train.Features(feature={
                "raw_data": _bytes_feature(raw_data),
                "data_len": _int64_feature(data_len)
            }))

        writer.write(this_example.SerializeToString())
    writer.close()

if __name__ == "__main__":
    serialize_data(data, "./out.tfrecord")

我的解决方案是记录每个数据点的长度并将其打包到每个示例中,然后,在读取数据进行训练时,使用长度来重新整形原始数据。问题是当我重建data_len时它是tf.Tensor并且不能用于重塑原始数据。

错误

  

TypeError:int()参数必须是字符串,类似字节的对象或数字,而不是'Tensor'

导入数据(产生错误的代码)

dataset = tf.contrib.data.TFRecordDataset(["out.tfrecord"])

def extract_raw_data(my_example):
    features = {
        "raw_data": tf.FixedLenFeature([], tf.string),
        "data_len": tf.FixedLenFeature([], tf.string),
    }
    parsed_features = tf.parse_single_example(my_example, features)
    data = tf.decode_raw(parsed_features['raw_data'], tf.string)
    len_data = tf.decode_raw(parsed_features['data_len'], tf.int32)
    # data.set_shape() <-- use len_data here to reshape data
    return data

dataset = dataset.map(extract_raw_data)

我考虑过的一个解决方案是找到所有训练数据实例的最大长度并填充每个实例,然后简单地对重塑值进行硬编码(就像处理图像时那样)但是我想知道是否有办法传递每个训练实例的长度并重建它。

感谢。

0 个答案:

没有答案