是否可以将tensorflow中的序列示例存储为float16而不是常规float?
我们可以使用16位精度,它将减少我们使用的数据文件的大小,节省约200 GB。
答案 0（得分：0）
我认为下面的代码片段可以实现这一点。
import tensorflow as tf
import numpy as np
# Generate sample data at half precision (float16) — the whole point of the
# exercise is halving on-disk size versus float32.
data_np = np.array(np.random.rand(10), dtype=np.float16)

with tf.python_io.TFRecordWriter('/tmp/data.tfrecord') as writer:
    # Serialize the raw float16 buffer into a single BytesList feature.
    raw_bytes = data_np.tobytes()
    feature_map = {
        'raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[raw_bytes])),
    }
    # Wrap the feature dict in an Example and append it to the TFRecord file.
    example = tf.train.Example(
        features=tf.train.Features(feature=feature_map))
    writer.write(example.SerializeToString())
def _parse_tfrecord(example_proto):
    """Deserialize one serialized Example back into a float16 tensor."""
    # The record stores everything as one byte string under the 'raw' key.
    feature_spec = {'raw': tf.FixedLenFeature((), tf.string)}
    parsed = tf.parse_single_example(example_proto, feature_spec)
    # Reinterpret the raw bytes as a float16 vector — must match the dtype
    # used when the data was written.
    return tf.decode_raw(parsed['raw'], tf.float16)
def tfrecord_input_fn():
    """Build a next-element op over the parsed TFRecord dataset (TF 1.x)."""
    dataset = tf.data.TFRecordDataset('/tmp/data.tfrecord')
    # Decode every serialized Example with the parser above.
    parsed = dataset.map(_parse_tfrecord)
    # One-shot iterators need no explicit initialization in TF 1.x.
    return parsed.make_one_shot_iterator().get_next()
# Get the next-element op from the TFRecord input pipeline.
it = tfrecord_input_fn()
# Use a context manager so the session is always released — the original
# created tf.Session() and never called sess.close(), leaking the session's
# resources.
with tf.Session() as sess:
    recovered_data = sess.run(it)
# Element-wise comparison with the original array; prints an array of True
# when the float16 round-trip was lossless.
print(recovered_data == data_np)