Tensorflow:使用sparse_placeholder恢复模型

时间:2017-10-24 14:09:01

标签: python tensorflow

我需要保存并恢复一个使用 sparse_placeholder 的模型,但在恢复时收到了错误消息。

恢复包含 sparse_placeholder 的模型时,出现的错误如下:

KeyError: "The name 'w1:0' refers to a Tensor which does not exist. The operation, 'w1', does not exist in the graph."

下面的代码可以简单地复现该错误(使用普通张量的类似图则可以正常恢复):

import tensorflow as tf
import numpy as np


def train_sparse():
    """Build a tiny graph around a sparse placeholder, run it once and save it.

    The graph multiplies the sparse placeholder ``w1`` by the dense variable
    ``bias`` under the name scope ``op1``, evaluates it on one 2x2 sparse
    input, and saves a checkpoint with the ``my_test_model`` prefix
    (``my_test_model.meta`` is what ``test_sparse`` re-imports).
    """
    w1 = tf.sparse_placeholder(tf.float64, shape=None, name="w1")
    b1 = tf.Variable(np.ones((2, 2)) * 1.0, name="bias")
    operation = tf.sparse_tensor_dense_matmul(w1, b1, name="op1")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # A single non-zero entry (value 5 at index [1, 1]) in a 2x2 matrix.
        w1_value = tf.SparseTensorValue([[1, 1]], [5], [2, 2])
        # print() with one argument works on both Python 2 and Python 3,
        # unlike the Python-2-only print statement.
        print(sess.run(operation, {w1: w1_value}))
        saver.save(sess, 'my_test_model')


def test_sparse():
    """Restore the model saved by ``train_sparse`` and try to re-feed ``w1``.

    This is the failing reproduction from the question. The lookup of
    ``"w1:0"`` raises ``KeyError`` because the saved graph has no single
    tensor named ``w1``. The sparse placeholder is stored there as the
    placeholders ``w1/indices``, ``w1/values`` and ``w1/shape``.
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('my_test_model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        print(sess.run('bias:0'))
        graph = tf.get_default_graph()
        op_to_restore = graph.get_tensor_by_name("op1/SparseTensorDenseMatMul:0")
        # Raises KeyError: "The name 'w1:0' refers to a Tensor which does not
        # exist..." -- the component placeholders must be fed instead.
        w1 = graph.get_tensor_by_name("w1:0")
        w1_value = tf.SparseTensorValue([[1, 1]], [5], [2, 2])

        # print() with one argument works on both Python 2 and Python 3.
        print(sess.run(op_to_restore, {w1: w1_value}))

if __name__ == "__main__":
    # Save the model first, then restore it from the checkpoint just written.
    for step in (train_sparse, test_sparse):
        step()

有没有人知道如何修复它?

作为一种变通方法,我可以将数据作为普通(稠密)张量传入,再在图中将其转换为 sparse_tensor。但这需要额外的、不必要的转换。

1 个答案:

答案 0 :(得分:2)

TensorFlow 会分别保存稀疏占位符的索引、值和形状,并为它们添加相应的后缀。因此,名为 w1 的 SparsePlaceholder 在保存的图中会变成 3 个占位符:w1/indices、w1/values 和 w1/shape。

我修改了你的示例,使其更清晰:

import tensorflow as tf
import numpy as np


def train_sparse():
    """Build a tiny graph around a sparse placeholder, run it once and save it.

    The graph multiplies the sparse placeholder ``w1`` by the dense variable
    ``bias`` under the name scope ``op1``, evaluates it on one 2x2 sparse
    input, and saves a checkpoint with the ``my_test_model`` prefix
    (``my_test_model.meta`` is what ``test_sparse`` re-imports).
    """
    w1 = tf.sparse_placeholder(tf.float64, shape=None, name="w1")
    b1 = tf.Variable(np.ones((2, 2)) * 1.0, name="bias")
    operation = tf.sparse_tensor_dense_matmul(w1, b1, name="op1")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # A single non-zero entry (value 5 at index [1, 1]) in a 2x2 matrix.
        w1_value = tf.SparseTensorValue([[1, 1]], [5], [2, 2])
        # print() with one argument works on both Python 2 and Python 3,
        # unlike the Python-2-only print statement.
        print(sess.run(operation, {w1: w1_value}))
        saver.save(sess, 'my_test_model')

def test_sparse():
    """Restore the model saved by ``train_sparse`` and evaluate ``op1`` again.

    ``tf.sparse_placeholder(name="w1")`` is stored in the saved graph as three
    ordinary placeholders: ``w1/indices``, ``w1/values`` and ``w1/shape``.
    There is therefore no ``w1:0`` tensor to look up. Instead, each component
    is fetched by name and fed separately.
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('my_test_model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        print(sess.run('bias:0'))
        graph = tf.get_default_graph()
        op_to_restore = graph.get_tensor_by_name("op1/SparseTensorDenseMatMul:0")

        # NEW PART: look up the three component placeholders of the sparse input.
        w1_indices = graph.get_tensor_by_name("w1/indices:0")
        w1_values = graph.get_tensor_by_name("w1/values:0")
        w1_shape = graph.get_tensor_by_name("w1/shape:0")

        # Same data as in train_sparse: value 5 at index [1, 1] of a 2x2 matrix.
        w1_indices_value = [[1, 1]]
        w1_values_value = [5]
        w1_shape_value = [2, 2]

        # print() with one argument works on both Python 2 and Python 3.
        print(sess.run(op_to_restore, {w1_indices: w1_indices_value,
                                       w1_values: w1_values_value,
                                       w1_shape: w1_shape_value}))

if __name__ == "__main__":
    # Save the model first, then restore it from the checkpoint just written.
    for step in (train_sparse, test_sparse):
        step()