将张量值保存为二进制格式的文件的最佳方法是什么?

时间:2018-01-08 05:03:48

标签: python tensorflow floating-point

我正在尝试将张量值保存为二进制格式的文件。 特别是我试图将float32张量值保存为二进制格式(IEEE-754格式)。你能帮帮我吗?

import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0], [5.5, 4.3, 2.5]])

# how to save tensor x as binary format ?? 

1 个答案:

答案 0 :(得分:2)

建议的方法是检查您的模型。如Prefetch() object中所述,您可以创建一个Saving and Restoring programmer's guide对象,可以选择指定要保存哪些变量/可保存对象。然后,只要您想保存张量的值,就可以调用tf.train.Saver对象的save()方法:

saver = tf.train.Saver(...)

#...

saver.save(session, 'my-checkpoints', global_step = step)

..其中第二个参数(上例中的'my-checkpoints')是存储检查点二进制文件的目录的路径。

另一种方法是评估单个张量(将是NumPy ndarrays),然后将单个ndarray保存到NPY文件(通过tf.train.Saver)或将多个ndarray保存到单个NPZ存档(通过numpy.save()或{ {3}}):

np.save('x.npy', session.run(x), allow_pickle = False)