Tensorflow:是否可以将TF记录序列示例存储为float16

时间:2016-10-21 20:16:36

标签: tensorflow protocol-buffers

是否可以将 TensorFlow 中的序列示例(SequenceExample)以 float16 存储,而不是常规的 float32?

16 位精度对我们来说已经足够,而且它能减小数据文件的体积,大约可节省 200 GB。

1 个答案:

答案 0 :(得分:0)

我认为下面的代码片段可以做到这一点:把 float16 数组的原始字节存入一个 BytesList 特征,读取时再用 tf.decode_raw 将其解码回 float16。

import tensorflow as tf
import numpy as np

# Generate some sample data. The explicit '<f2' dtype (little-endian float16)
# makes the serialized bytes match what tf.decode_raw expects by default
# (little_endian=True), regardless of the host's native byte order.
data_np = np.array(np.random.rand(10), dtype='<f2')

with tf.python_io.TFRecordWriter('/tmp/data.tfrecord') as writer:
    # tf.train.FloatList stores 32-bit floats, so to keep 16-bit precision on
    # disk the raw float16 buffer (2 bytes per element) is stored as a single
    # BytesList feature instead.
    data = {'raw': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[data_np.tobytes()]))}
    # Wrap the feature dict in an Example and append it to the TFRecord file.
    example = tf.train.Example(features=tf.train.Features(feature=data))
    writer.write(example.SerializeToString())

def _parse_tfrecord(example_proto):
    """Decode one serialized ``tf.train.Example`` into a float16 tensor.

    Args:
        example_proto: scalar string tensor holding one serialized Example,
            as produced by ``tf.data.TFRecordDataset``.

    Returns:
        A ``tf.float16`` tensor reinterpreted from the bytes stored under the
        ``'raw'`` key.
    """
    # Each example carries exactly one scalar string: the raw float16 buffer.
    schema = {'raw': tf.FixedLenFeature((), tf.string)}
    raw_bytes = tf.parse_single_example(example_proto, schema)['raw']
    # Reinterpret the byte string as float16 values (decode_raw's default
    # byte order is little-endian).
    return tf.decode_raw(raw_bytes, tf.float16)

def tfrecord_input_fn(filename='/tmp/data.tfrecord'):
    """Build a one-shot input pipeline yielding decoded float16 records.

    Args:
        filename: path (or list of paths) of the TFRecord file(s) to read.
            Defaults to '/tmp/data.tfrecord', preserving the original
            zero-argument behavior.

    Returns:
        The ``get_next()`` tensor of a one-shot iterator; each evaluation in
        a session yields the next record decoded by ``_parse_tfrecord``.
    """
    dataset = tf.data.TFRecordDataset(filename)
    # Turn each serialized Example into its float16 payload.
    dataset = dataset.map(_parse_tfrecord)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()

# Build the tensor that yields the next record from the TFRecord file.
it = tfrecord_input_fn()
# Evaluate it once; the context manager closes the session afterwards so its
# resources are released.
with tf.Session() as sess:
    recovered_data = sess.run(it)
# np.array_equal reports a single True/False for the whole round trip,
# instead of the element-wise boolean array produced by `==`.
print(np.array_equal(recovered_data, data_np))