我的网络输入来自包含int32的文件。它们存储为.tfrecords,如下所示:
writer = tf.python_io.TFRecordWriter(output_file)
with tf.gfile.FastGFile(file_path, 'rb') as f:
data = f.read()
example = tf.train.Example(features=tf.train.Features(feature={
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data])) }))
writer.write(example.SerializeToString())
然后我像这样读取tfrecords文件:
with tf.name_scope(self.name):
filename_queue = tf.train.string_input_producer([path])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={ 'data': tf.FixedLenFeature([], tf.string) })
data = features['data']
在读取tfrecords之后,我得到了这样的字符串张量:
Tensor("X/ParseSingleExample/ParseSingleExample:0", shape=(), dtype=string)
我想首先将其转换为int32,因为这就是初始数据所代表的含义。之后,我需要结束一个浮点张量,有人可以指向我正确的方向吗?
tensorflow的PS Im新手,请告诉我是否可以提供更多有用的信息
答案 0 :(得分:3)
这应该有帮助
data = features['data']
decoded = tf.decode_raw(data, tf.int32)
这将输出dtype tf.int32
的张量。然后,您可以调整其形状并投射到tf.float32
decoded = tf.reshape(decoded, shape)
decoded = tf.cast(decoded, tf.float32)
如果要检查tf.Session
之外的tfrecords文件的内容
for str_rec in tf.python_io.tf_record_iterator('file.tfrecords'):
example = tf.train.Example()
example.ParseFromString(str_rec)
data_str = example.features.feature['data'].bytes_list.value[0])
decoded = np.fromstring(data_str, dtype)
要验证张量的内容,您可以按照此answer
中的说明在图中注入打印节点。# Add print operation
decoded = tf.Print(decoded, [decoded])