Tensorflow:tf.train.Feature错误'预期之一:bytes'

时间:2018-05-31 16:34:01

标签: python numpy tensorflow

所有

我正在尝试将数据序列化到tensorflow中的tfrecords文件。我按照这里的指示: https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_tfrecord.py

如本说明书所示,我需要为每条记录构建example,每个example包含多个feature

但是,我在初始化功能时遇到问题。这是我的测试代码:

import tensorflow as tf
import numpy as np

feature_str=tf.train.Feature(bytes_list=tf.train.BytesList(value = np.array(["a" , "b"])))

feature_int=tf.train.Feature(bytes_list=tf.train.Int64List(value = np.array([32 , 24])))

feature_flo=tf.train.Feature(bytes_list=tf.train.FloatList(value = np.array([32.1 , 24.1 ])))

然而,我得到了以下错误,这是直觉的:

usr / local / lib / python2.7 / dist-packages / h5py / init .py:36:FutureWarning:将issubdtype的第二个参数从float转换为{{ 1}}已弃用。将来,它将被视为np.floating。   从._conv导入register_converters作为_register_converters Traceback(最近一次调用最后一次):   文件“test.py”,第7行,in     feature_int = tf.train.Feature(bytes_list = tf.train.Int64List(value = np.array([32,24]))) TypeError:MergeFrom()的参数必须是同一个类的实例:expected tensorflow.BytesList got tensorflow.Int64List。

我进一步尝试将np.float64 == np.dtype(float).type用于tf.train.BytesList并收到以下错误:

feature_int = tf.train.Feature(bytes_list = tf.train.BytesList(value = np.array([32,24]))) TypeError:32的类型为numpy.int64,但是应该是:bytes

在这个问题上有人能帮帮我吗?我真的很困惑。

谢谢!

1 个答案:

答案 0 :(得分:1)

您需要将bytes提供给bytes_list而不是string。以下给出了所有三种情况的例子:

致TFRecords:

output_file = 'out.tfrecord'
writer = tf.python_io.TFRecordWriter(output_file)

"""Build an Example proto
"""

feature = {}
feature['str'] = tf.train.Feature(
       bytes_list=tf.train.BytesList(value = [b"a",b"b"]))
feature['int'] = tf.train.Feature(int64_list=tf.train.Int64List(value = np.array([32 , 24])))
feature['flo'] = tf.train.Feature(float_list=tf.train.FloatList(value = np.array([32.1 , 24.1 ])))

features = tf.train.Features(feature=feature)
example = tf.train.Example(features=features)
serialized = example.SerializeToString()
writer.write(serialized)

writer.close() 

从TFRecords中读取:

  for serialized_example in tf.python_io.tf_record_iterator('out.tfrecord'):
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    print(example.features.feature['str'].bytes_list.value)
    print(np.array(example.features.feature['int'].int64_list.value))
    print(np.array(example.features.feature['flo'].float_list.value))

<强>输出:

[b'a', b'b']
[32 24]
[32.09999847 24.10000038]