pywrap_tensorflow checkpoint reader fails for FTRL state of partitioned variables

Time: 2019-07-08 08:02:44

Tags: python tensorflow

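I am trying to load tensors from a checkpoint, but I found that the pywrap_tensorflow interface cannot load the FTRL state of partitioned variables.

I am using Python 3 and TensorFlow 1.12.0.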

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

partitioner = partitioned_variables.min_max_variable_partitioner(max_partitions=2, min_slice_size=5)
with variable_scope.variable_scope('linear', partitioner=partitioner):
    var = tf.get_variable('my_var', [10, 1], partitioner=partitioner)
    # If the following two lines are moved outside variable_scope('linear'), the error goes away, but that is not what I need.
    # I am asking because I hit this error when loading checkpoints written by tf.estimator.LinearClassifier; this snippet is a simplified version of the source code that leads to it.
    # See lines 730 and 746 of https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/canned/linear.py
    opt = tf.train.FtrlOptimizer(0.01)
    optmin = opt.minimize(tf.reduce_sum(tf.square(var)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, 'dist_ckpt/model.ckpt', write_meta_graph=False, write_state=False)

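reader = pywrap_tensorflow.NewCheckpointReader('dist_ckpt/model.ckpt')
# Map of checkpoint key -> saved shape (named separately so it does not shadow `var` above).
shape_map = reader.get_variable_to_shape_map()

for key in shape_map:
    if 'Ftrl' in key:
        print(key)
        print(reader.get_tensor(key))  # raises InvalidArgumentError for the FTRL slots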

Executing reader.get_tensor(key) gives the following error message:

  

tensorflow.python.framework.errors_impl.InvalidArgumentError: Does not have sufficient slices for partitioned tensor linear/my_var/part_0/Ftrl_1 to be restored for slice_spec: -:-

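I realize the error occurs because I cannot load just one part of the tensor. However, I do not know how to load the whole tensor (i.e. the FTRL state of linear/my_var).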
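To narrow down which keys are affected, it may help to probe every matching key and record the failures instead of stopping at the first exception. The sketch below is only a diagnostic, not a fix: `probe_checkpoint` and `FakeReader` are hypothetical names introduced here, and the helper relies only on the two reader methods already used above (`get_variable_to_shape_map` and `get_tensor`). The key names in the dry run are made up for illustration, apart from the one taken from the error message.

```python
def probe_checkpoint(reader, substring='Ftrl'):
    """Try to read every tensor whose key contains `substring`.

    `reader` is any object exposing get_variable_to_shape_map() and
    get_tensor(key), such as the checkpoint reader in the question.
    Returns (readable, failed): `readable` maps key -> saved shape,
    `failed` maps key -> error message.
    """
    readable, failed = {}, {}
    shape_map = reader.get_variable_to_shape_map()
    for key in sorted(shape_map):
        if substring not in key:
            continue
        try:
            reader.get_tensor(key)
        except Exception as err:  # in the question this is InvalidArgumentError
            failed[key] = str(err)
        else:
            readable[key] = list(shape_map[key])
    return readable, failed


class FakeReader(object):
    """Hypothetical stand-in for the real reader, for a TensorFlow-free dry run."""

    def __init__(self, shapes, unreadable):
        self._shapes = dict(shapes)
        self._unreadable = set(unreadable)

    def get_variable_to_shape_map(self):
        return dict(self._shapes)

    def get_tensor(self, key):
        if key in self._unreadable:
            raise ValueError('cannot read ' + key)
        rows, cols = self._shapes[key]
        return [[0.0] * cols for _ in range(rows)]


if __name__ == '__main__':
    # Made-up keys and shapes; only 'linear/my_var/part_0/Ftrl_1' appears in the error above.
    fake = FakeReader(
        {'linear/my_var/part_0': [5, 1],
         'linear/my_var/part_0/Ftrl': [5, 1],
         'linear/my_var/part_0/Ftrl_1': [5, 1]},
        unreadable=['linear/my_var/part_0/Ftrl_1'])
    readable, failed = probe_checkpoint(fake)
    print('readable:', readable)
    print('failed:', failed)
```

With the real checkpoint, the call would be `probe_checkpoint(pywrap_tensorflow.NewCheckpointReader('dist_ckpt/model.ckpt'))`. Comparing the saved shapes of the failing keys against the [10, 1] shape of my_var might show whether each key holds a slice or the full tensor.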

0 answers:

No answers