Copying an operation from one graph to another in TensorFlow

Date: 2017-10-14 08:20:05

Tags: tensorflow

import tensorflow as tf  
import numpy as np  
import os  
import data_helpers  
from tensorflow.contrib import learn

# Parameters
# ==================================================

# Data Parameters
tf.flags.DEFINE_string("eval_file", "./text/tokenizedSmallText.txt", "Data source for the positive data.")  

# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")  
tf.flags.DEFINE_string("checkpoint_dir", "./runs/ThuOct121628262017/checkpoints", "Checkpoint directory from training run")  
tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data")  

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")  
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")  



FLAGS = tf.flags.FLAGS  
FLAGS._parse_flags()  
print("\nParameters:")  
for attr, value in sorted(FLAGS.__flags.items()):  
    print("{}={}".format(attr.upper(), value))
print("")

x_raw = data_helpers.load_data_and_labels(FLAGS.eval_file)

# Map data into vocabulary
vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab")
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw)))

print("\nEvaluating...\n")

# Evaluation
# ==================================================
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph1 = tf.Graph()
graph2 = tf.Graph()

with graph1.as_default():
    session_conf = tf.ConfigProto(
      allow_soft_placement=FLAGS.allow_soft_placement,
      log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)

    # Load the saved meta graph and restore variables (must run with graph1 as the default graph)
    saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
    saver.restore(sess, checkpoint_file)

    # Look up the trained embedding matrix in graph1 and copy the op into graph2
    embedded_W = graph1.get_tensor_by_name("embedding/W:0")
    embedded_W2 = tf.contrib.copy_graph.copy_op_to_graph(embedded_W, graph2, [])


with graph2.as_default():
    session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement)
    sess2 = tf.Session(config=session_conf)
    with sess2.as_default():
        tf.global_variables_initializer().run(session=sess2)
        print(embedded_W2)
        print(sess2.run(embedded_W2))

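So this is my code.
What I want to do is copy an operation from graph1 to graph2, and the error below is what I get. graph1 holds an already trained model and I want to retrain it, so this problem came up when I went on to initialize the CNN weights and biases.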

I also tried not building a second graph2, that is, reusing graph1 from the first training run, but I need to add more words to Word_embedded_vectors (= graph.get_tensor_by_name("embedding/W:0")).
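As an illustration of what "adding more words" to the embedding involves, here is a minimal NumPy-only sketch: the trained rows are kept unchanged and freshly initialized rows are appended for the new vocabulary entries. The helper name `extend_embedding`, the shapes, the uniform initialization range, and the stand-in matrix are all assumptions made for this example, not values from the model above, and loading the result back into a TensorFlow variable is not shown.

```python
import numpy as np


def extend_embedding(old_W, num_new_words, init_range=1.0, seed=0):
    """Return a copy of old_W with num_new_words newly initialized rows appended.

    Hypothetical helper: the first old_W.shape[0] rows keep their trained
    values, the appended rows are drawn uniformly from [-init_range, init_range].
    """
    if num_new_words < 0:
        raise ValueError("num_new_words must be non-negative")
    rng = np.random.RandomState(seed)
    embedding_dim = old_W.shape[1]
    new_rows = rng.uniform(
        -init_range, init_range, size=(num_new_words, embedding_dim)
    ).astype(old_W.dtype)
    return np.concatenate([old_W, new_rows], axis=0)


# Stand-in for the trained value of "embedding/W:0"
# (assumed here: 5 known words, 4 embedding dimensions).
old_W = np.arange(20, dtype=np.float32).reshape(5, 4)

# Grow the vocabulary by 3 words.
new_W = extend_embedding(old_W, num_new_words=3)

print(new_W.shape)  # (8, 4)
```

Because the enlarged matrix has a different shape from the saved variable, it cannot simply be restored into the old `embedding/W`; a variable with the new shape would have to be created and given this array as its initial value.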

  1. Is there another way to retrain without using a new graph?
  2. If not, I would like to resolve the following error message.

    Traceback (most recent call last):
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1327, in _do_call
        return fn(*args)
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
        status, run_metadata)
      File "/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/contextlib.py", line 88, in __exit__
        next(self.gen)
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
        pywrap_tensorflow.TF_GetCode(status))
    tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value embedding/W
      [[Node: _retval_embedding/W_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"]]]

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/Users/jj/eclipse-workspace/qna_beta/e01/test.py", line 105, in <module>
        print(sess2.run(embedded_W2))
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
        run_metadata_ptr)
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1124, in _run
        feed_dict_tensor, options, run_metadata)
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
        options, run_metadata)
      File "/Users/jj/tensorflow3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value embedding/W
      [[Node: _retval_embedding/W_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"]]]

  3. The TensorFlow version is 1.3.0.

0 answers:

No answers