我想加速一个平均多个TensorFlow检查点的工具,但为了简单起见,我想我只需要加载一个检查点,可能修改一些变量并将其保存到磁盘。
current implementation将所有变量加载到numpy数组(3秒),为每个变量(20s)准备tf变量(tf.get_variable()
),占位符和assign_ops,执行初始化所有变量的会话( 16s),运行所有分配(81s),最后它将检查点存储到磁盘(24s)。总时间 144秒。
我的替代实现使用tf.get_variable(name, shape=numpy_array.shape, initializer=tf.constant_initializer(numpy_array))
而没有占位符或assign_ops,因此它将总时间减少到 57秒。但是,*.meta
文件还存储所有常量初始值设定项(因此它与主检查点数据文件一样大,这不是我想要的),当我在更大的检查点上应用它时,由于{ {3}}
如果加载所有变量需要3秒钟,我相信存储它们应该花费少于141秒甚至少于54秒。有没有办法将dict {var_name1: numpy_array1,...}
写入tf检查点文件(重用另一个检查点的元图)而无需运行tf会话?我试图跟踪2GB tf limit中的链接但没有成功。