如何将Float数组/列表转换为TFRecord?

时间:2018-03-31 13:39:50

标签: python tensorflow

这是用于将数据转换为TFRecord的代码

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

 def _bytes_feature(value):
   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _floats_feature(value):
   return tf.train.Feature(float_list=tf.train.FloatList(value=value))

with tf.python_io.TFRecordWriter("train.tfrecords") as writer:
    for row in train_data:
        prices, label, pip = row[0],row[1],row[2]
        prices = np.asarray(prices).astype(np.float32)
        example = tf.train.Example(features=tf.train.Features(feature={
                                           'prices': _floats_feature(prices),
                                           'label': _int64_feature(label[0]),
                                           'pip': _floats_feature(pip)
    }))
        writer.write(example.SerializeToString())

功能价格是一个形状数组(1,288)。它转换成功!但是当使用解析函数和数据集API解码数据时。

def parse_func(serialized_data):
    keys_to_features = {'prices': tf.FixedLenFeature([], tf.float32),
                    'label': tf.FixedLenFeature([], tf.int64)}

    parsed_features = tf.parse_single_example(serialized_data, keys_to_features)
    return parsed_features['prices'],tf.one_hot(parsed_features['label'],2)

它给了我错误

  

C:\ tf_jenkins \ workspace \ rel-win \ M \ windows -gpu \ PY \ 36 \ tensorflow \ core \ framework \ op_kernel.cc:1202] OP_REQUIRES在example_parsing_ops.cc:240失败:无效参数:键:价格。无法解析序列化的示例。   2018-03-31 15:37:11.443073:WC:\ tf_jenkins \ workspace \ rel-win \ M \ windows -gpu \ PY \ 36 \ tensorflow \ core \ framework \ op_kernel.cc:1202] OP_REQUIRES在example_parsing_ops.cc失败:240:参数无效:关键:价格。无法解析序列化的示例。   2018-03-31 15:37:11.443313:W C:\ tf_jenkins \ workspace \ rel-win \ M \ windows-gpu \ raise type(e)(node_def,op,message)   PY \ 36 \ tensortensorflow.python.framework.errors_impl.InvalidArgumentError:关键:价格。无法解析序列化的示例。        [[Node:ParseSingleExample / ParseSingleExample = ParseSingleExample [Tdense = [DT_INT64,DT_FLOAT],dense_keys = [" label"," price"],dense_shapes = [[],[]], num_sparse = 0,sparse_keys = [],sparse_types = []](arg0,ParseSingleExample / Const,ParseSingleExample / Const_1)]]        [[Node:IteratorGetNext_1 = IteratorGetNextoutput_shapes = [[?],[?,2]],output_types = [DT_FLOAT,DT_FLOAT],_ device =" / job:localhost / replica:0 / task:0 / device:CPU :0"]] FL   ow \ core \ framework \ op_kernel.cc:1202] OP_REQUIRES在example_parsing_ops.cc:240失败:无效参数:密钥:价格。无法解析序列化的示例。

5 个答案:

答案 0 :(得分:8)

我发现了问题。不使用tf.FixedLenFeature来解析数组,而是使用tf.FixedLenSequenceFeature

答案 1 :(得分:1)

您不能将n维数组存储为浮点数,因为浮点数是简单列表。您必须执行pricesprices.tolist()展平为列表。如果需要从展平浮动特征中恢复n维数组,则可以执行prices = np.reshape(float_feature, original_shape)

答案 2 :(得分:1)

如果特征是固定的一维数组,则使用tf.FixedLenSequenceFeature根本不正确。如文档所述,tf.FixedLenSequenceFeature用于维数为2或更高的输入数据。 在此示例中,您需要展平价格数组,使其变为(288,),然后对部分解码,您需要提及数组维数。

编码:

example = tf.train.Example(features=tf.train.Features(feature={
                                       'prices': _floats_feature(prices.tolist()),
                                       'label': _int64_feature(label[0]),
                                       'pip': _floats_feature(pip)

解码:

keys_to_features = {'prices': tf.FixedLenFeature([288], tf.float32),
                'label': tf.FixedLenFeature([], tf.int64)}

答案 3 :(得分:0)

我在粗心地修改某些脚本时遇到了同样的问题,这是由于数据形状略有不同造成的。我必须更改形状以匹配预期的形状,例如(A, B)(1, A, B)。我使用np.ravel()进行展平。

答案 4 :(得分:0)

float32文件中读取TFrecord数据列表对我来说完全一样。

使用sess.run([time_tensor, frequency_tensor, frequency_weight_tensor])执行tf.FixedLenFeature时,我得到无法解析序列化的示例,尽管tf.FixedLenSequenceFeature似乎工作正常。

我读取文件的功能格式(有效的格式)如下: feature_format = { 'time': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True), 'frequencies': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True), 'frequency_weights': tf.FixedLenSequenceFeature([], tf.float32, allow_missing = True) }

编码部分是:

feature = { 'time': tf.train.Feature(float_list=tf.train.FloatList(value=[*some single value*]) ), 'frequencies': tf.train.Feature(float_list=tf.train.FloatList(value=*some_list*) ), 'frequency_weights': tf.train.Feature(float_list=tf.train.FloatList(value=*some_list*) ) }

这种情况发生在Debian机器上的TensorFlow 1.12没有GPU卸载的情况下(即只有与TensorFlow一起使用的CPU)

我这边有滥用吗?还是代码或文档中的错误?如果可以使任何人受益,我可以考虑贡献/上传任何修复程序。