注意:我已阅读Save and Restore的TensorFlow指南
我有一个函数,需要一个常量SparseTensor
作为输入,并且需要一段时间才能运行:
sparse = coo_matrix(dense)
sparse_indicies = list(zip(
sparse.row.astype(np.int64).tolist(),
sparse.col.astype(np.int64).tolist()
))
input_tensor = tf.SparseTensor(
indices = sparse_indicies,
values = (sparse.data).astype(np.float32),
dense_shape = sparse.shape
)
我只希望能够保存input_tensor
并在以后需要时加载它。我没有急于执行。
这是否意味着我必须这样做:
file = 'tensor_{}_{}.ckpt'.format(*sparse.shape)
# save
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
save_path = saver.save(sess, file)
然后
# reload
tf.reset_default_graph()
junk, n_rows, n_cols = file.split('.')[0].split('_')
please_work = tf.get_variable("input_tensor", shape=[n_rows, n_cols])
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "i_just_want_one_tensor.ckpt")