恢复训练有素的tensorflow模型python

时间:2019-07-19 16:18:56

标签: python-3.x tensorflow machine-learning neural-network restore

我建立了一个模型,该模型包括三个部分,两个辅助模型和一个使用它们的主模型。 这两个辅助模型都已经过训练,现在我想在主模型中使用它们。

即使我给他们起了个名字,我也可以访问训练后的变量,但不能访问模型的各层(即tf.layers.dense)。

我从一个仅包含常量的tensorflow小型程序开始,然后添加了常规变量,并且在这两个程序中,我都可以访问值,然后添加了一层tf.layers.danse,在那里我可以无法访问。

import tensorflow as tf
import numpy as np

# First Model
g1 = tf.Graph()
with g1.as_default():
    features = 2
    x = tf.placeholder(tf.float32, [None, features], name='x')
    y_ = tf.placeholder(tf.float32, [None, 1], name='y_')
    z = tf.Variable(4., name='z')

    q = tf.layers.dense(x, units=1, name= 'q')
    loss = tf.reduce_mean(tf.pow(q - y_, 2), name='loss')
    update = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

    s1 = tf.train.Saver()

# Training First Model
with tf.Session(graph=g1) as sess:

    # initialize all of the variables in the session
    sess.run(tf.global_variables_initializer())

    data_x = np.array([[2, 4], [3, 9], [4, 16], [6, 36], [7, 49]])
    data_y = np.array([[70], [110], [165], [390], [550]])

    for i in range(1000):
        sess.run(update, feed_dict={x: data_x, y_: data_y})
        if i % 333 == 0:
            print('Iteration:', i, ' loss:', loss.eval(session=sess, feed_dict={x: data_x, y_: data_y}))

    s1.save(sess, 'g1')


# I deleted the second model because it is not necessary for the question

# The Main Model
g3 = tf.Graph()
with g3.as_default():
    tf.train.import_meta_graph('g1.meta', import_scope='g1')
    x1, y1, z1, loss1 = [g3.get_tensor_by_name('g1/%s:0' % name) for name in ('x', 'y_', 'z', 'loss')]

    g = loss1 + z1

# create separate loaders - we need to load variables from different files
with g3.as_default():
    s33 = tf.train.Saver(var_list={'z': z1})

data_x = np.array([[2, 4], [3, 9], [4, 16], [6, 36], [7, 49]])
data_y = np.array([[70], [110], [165], [390], [550]])

feed_dict = {x1: data_x, y1: data_y}
print('create data')

with tf.Session(graph=g3) as sess:
    s33.restore(sess, './g1')

    # check if values were actually restored, not re-initialized
    g_value = sess.run([g], feed_dict=feed_dict)
    print("g = ", g_value)

我希望程序打印出变量g的值。由于g的值= loss1 + z1,所以我预计loss1的值将被实际打印出来(当z1 = 4时为+ z1)。 所以我在下一行中输入feed_dict

     g_value = sess.run ([g], feed_dict = feed_dict)

为了使程序可以计算取决于x和y_的loss1值。

但是出现以下错误:

回溯(最近通话最近):   

中的文件“ /home/xxx/pytoh/Merge_two_models_in_TensorFlow/Main.py”,第60行
g_value = sess.run([g], feed_dict=feed_dict)

文件“ /home/xxx/pytoh/Merge_two_models_in_TensorFlow/project/python3.5/site-packages/tensorflow/python/client/session.py”,第950行,

在运行run_metadata_ptr)

文件“ /home/xxx/pytoh/Merge_two_models_in_TensorFlow/project/python3.5/site-packages/tensorflow/python/client/session.py”,行1173,位于_run

feed_dict_tensor, options, run_metadata)

文件“ /home/xxx/pytoh/Merge_two_models_in_TensorFlow/project/python3.5/site-packages/tensorflow/python/client/session.py”,行1350,位于_do_run

run_metadata)

文件“ /home/xxx/pytoh/Merge_two_models_in_TensorFlow/project/python3.5/site-packages/tensorflow/python/client/session.py”,行1370,在_do_call中

raise type(e)(node_def, op, message)

tensorflow.python.framework.errors_impl.FailedPreconditionError:尝试使用未初始化的值g1 / q / bias      [[节点g1 / q / bias / read(在/Main.py:41处定义)]]

“ g1 / q / bias / read”的原始堆栈跟踪:   

中的文件“ /Main.py”,第41行

tf.train.import_meta_graph('g1.meta',import_scope ='g1')

在import_meta_graph中的文件“ /project/python3.5/site-packages/tensorflow/python/training/saver.py”,行1449     ** kwargs)[0]

文件“ /project/python3.5/site-packages/tensorflow/python/training/saver.py”,行1473,在_import_meta_graph_with_return_elements中     ** kwargs))

文件“ /project/python3.5/site-packages/tensorflow/python/framework/meta_graph.py”,行857,在import_scoped_meta_graph_with_return_elements中     return_elements = return_elements)

文件“ /project/python3.5/site-packages/tensorflow/python/util/deprecation.py”,第507行,位于new_func中     返回func(* args,** kwargs)

文件“ /project/python3.5/site-packages/tensorflow/python/framework/importer.py”,行443,在import_graph_def中     _ProcessNewOps(graph)

文件_ProcessNewOps中的文件“ /project/python3.5/site-packages/tensorflow/python/framework/importer.py”,第236行     用于图表中的new_op。_add_new_tf_operations(compute_devices = False):#pylint:disable = protected-access

文件“ /project/python3.5/site-packages/tensorflow/python/framework/ops.py”,行3751,在_add_new_tf_operations中     用于c_api_util.new_tf_operations(self)中的c_op

文件“ /project/python3.5/site-packages/tensorflow/python/framework/ops.py”,行3751,在     用于c_api_util.new_tf_operations(self)中的c_op

文件_create_op_from_tf_operation中的文件“ /project/python3.5/site-packages/tensorflow/python/framework/ops.py”,行3641     ret = Operation(c_op,self)

文件“ /project/python3.5/site-packages/tensorflow/python/framework/ops.py”,第2005行, init     self._traceback = tf_stack.extract_stack()

0 个答案:

没有答案