TensorFlow保存未评估的变量?

时间:2018-12-01 07:23:59

标签: python tensorflow

注意:我已阅读Save and Restore的TensorFlow指南

我有一个函数,需要一个常量SparseTensor作为输入,并且需要一段时间才能运行:

sparse = coo_matrix(dense)

sparse_indicies = list(zip(
    sparse.row.astype(np.int64).tolist(), 
    sparse.col.astype(np.int64).tolist()
))

input_tensor = tf.SparseTensor(
    indices     = sparse_indicies,
    values      = (sparse.data).astype(np.float32),
    dense_shape = sparse.shape
)

我只希望能够保存input_tensor并在以后需要时加载它。我没有急于执行。

这是否意味着我必须这样做:

file = 'tensor_{}_{}.ckpt'.format(*sparse.shape)


# save
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, file)

然后

# reload
tf.reset_default_graph()

junk, n_rows, n_cols = file.split('.')[0].split('_')
please_work = tf.get_variable("input_tensor", shape=[n_rows, n_cols])


saver = tf.train.Saver()

with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "i_just_want_one_tensor.ckpt")

0 个答案:

没有答案