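I am trying to load tensors from a checkpoint, but I found that the pywrap_tensorflow interface cannot load the FTRL state of a partitioned variable.
I am using Python 3 and TensorFlow 1.12.0.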
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
partitioner = partitioned_variables.min_max_variable_partitioner(max_partitions=2, min_slice_size=5)
with variable_scope.variable_scope('linear', partitioner=partitioner):
    var = tf.get_variable('my_var', [10, 1], partitioner=partitioner)
    # If the following two lines are moved outside variable_scope('linear'), the error goes away, but that is not what I need.
    # I am asking because I hit this error when loading checkpoints for tf.estimator.LinearClassifier; the code posted here is a simplification of the source code that leads to it.
    # See lines 730 and 746 of https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/canned/linear.py
    opt = tf.train.FtrlOptimizer(0.01)
    optmin = opt.minimize(tf.reduce_sum(tf.square(var)))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, 'dist_ckpt/model.ckpt', write_meta_graph=False, write_state=False)
    reader = pywrap_tensorflow.NewCheckpointReader('dist_ckpt/model.ckpt')
    var = reader.get_variable_to_shape_map()
    for key in var:
        if 'Ftrl' in key:
            print(key)
            print(reader.get_tensor(key))
I get the following error message when executing reader.get_tensor(key):
tensorflow.python.framework.errors_impl.InvalidArgumentError: Does not have sufficient slices for partitioned tensor linear/my_var/part_0/Ftrl_1 to restore in slice_spec: -:-
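I realize the error occurs because I cannot load only one part of a tensor. However, I do not know how to load the whole tensor (i.e., the FTRL state of linear/my_var).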
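
To narrow down which checkpoint keys are affected, a small diagnostic along the following lines might help. It is only a sketch: probe_checkpoint_keys is a hypothetical helper name, and it assumes nothing about the reader beyond the two methods already used above (get_variable_to_shape_map and get_tensor). It does not fix the problem; it just separates the keys that can be read from those that raise.

```python
def probe_checkpoint_keys(reader, substring='Ftrl'):
    """Split matching checkpoint keys into readable and unreadable ones.

    `reader` is assumed to expose get_variable_to_shape_map() and
    get_tensor(key), like the NewCheckpointReader object in the question.
    Returns two dicts keyed by checkpoint key:
      readable:   key -> shape
      unreadable: key -> (shape, error message)
    """
    readable, unreadable = {}, {}
    shape_map = reader.get_variable_to_shape_map()
    for key in sorted(shape_map):
        if substring not in key:
            continue
        try:
            reader.get_tensor(key)
            readable[key] = shape_map[key]
        except Exception as err:  # broad on purpose; TensorFlow raises InvalidArgumentError here
            unreadable[key] = (shape_map[key], str(err))
    return readable, unreadable
```

With the reader from the snippet above, calling readable, unreadable = probe_checkpoint_keys(reader) and printing both dicts would show the stored shape of each FTRL slot next to the error it produces, if any.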