张量流预测

时间:2019-02-27 08:59:52

标签: tensorflow tensorflow-datasets

训练完我的“参数”(w1,w2,Conv网络中过滤器的权重)后,将其保存为parameter = sess.run(参数)

我拍摄一张图像img = [1,64,64,3],并将其传递给mypredict(x,parameters)函数进行预测,但会产生错误。功能如下。关于出什么问题的任何建议。

def forward_propagation(X,参数):

W1 = parameters['W1']
W2 = parameters['W2']


Z1 = tf.nn.conv2d(X,W1,strides=[1,1,1,1],padding='SAME')

A1 = tf.nn.relu(Z1)

P1 = tf.nn.max_pool(A1,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')

Z2 = tf.nn.conv2d(P1,W2,strides=[1,1,1,1],padding='SAME')

A2 = tf.nn.relu(Z2)

P2 = tf.nn.max_pool(A2,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')

P2 = tf.contrib.layers.flatten(P2)

Z3 = tf.contrib.layers.fully_connected(P2,num_outputs=6,activation_fn=None)

return Z3

def mypredict(X,par):

W1 = tf.convert_to_tensor(par["W1"])
W2 = tf.convert_to_tensor(par["W2"])
params = {"W1": W1,
          "W2": W2}

x = tf.placeholder("float", [1,64,64,3])

z3 = forward_propagation_for_predict(x, params)

p = tf.argmax(z3)

sess = tf.Session()
prediction = sess.run(p, feed_dict = {x:X})

return prediction

我使用相同的函数“ forward_propagation”来训练权重,但是当我传递单个图像时,它不起作用。

错误:


FailedPreconditionError Traceback(最近一次呼叫最近) _do_call中的/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py(self,fn,* args)    1138试试: -> 1139返回fn(* args)    除errors.OpError为e以外的其他1140: _run_fn中的

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py(session,feed_dict,fetch_list,target_list,选项,run_metadata)    1120 feed_dict,fetch_list,target_list, -> 1121状态,run_metadata)    1122

退出中的

/opt/conda/lib/python3.6/contextlib.py(自身,类型,值,回溯)      88试试: ---> 89下一个(self.gen)      90,除了StopIteration:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py中的升高_异常_on_not_ok_status()     第465章(一更) -> 466 pywrap_tensorflow.TF_GetCode(状态))     终于467:

FailedPreconditionError:尝试使用未初始化的值fully_connected_1 / biases      [[节点:fully_connected_1 / biases / read = IdentityT = DT_FLOAT,_class = [“ loc:@ fully_connected_1 / biases”],_device =“ / job:localhost / replica:0 / task:0 / cpu:0”]] < / p>

在处理上述异常期间,发生了另一个异常:

FailedPreconditionError Traceback(最近一次呼叫最近)  在()中 ----> 1个pred = mypredict(t,pp)       2

mypredict中的

(X,par)      49      50 sess = tf.Session() ---> 51预测= sess.run(p,feed_dict = {x:X})      52      53回报预测

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py在运行中(self,fetches,feed_dict,options,run_metadata)     787尝试:     (788)第788章 -> 789 run_metadata_ptr)     790(如果运行)     791 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py在_run中(自身,句柄,提取,feed_dict,选项,run_metadata)     995(如果final_fetches或final_targets:     996个结果= self._do_run(handle,final_targets,final_fetches, -> 997 feed_dict_string,选项,run_metadata)     998其他:     999个结果= []

_do_run中的

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py(自身,句柄,target_list,fetch_list,feed_dict,选项,run_metadata)    1130如果handle为None:    1131返回self._do_call(_run_fn,self._session,feed_dict,fetch_list, -> 1132 target_list,选项,run_metadata)    1133其他:    1134 return self._do_call(_prun_fn,self._session,handle,feed_dict,

_do_call中的

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py(self,fn,* args)    1150,除了KeyError:    1151通过 -> 1152提高类型(e)(node_def,op,message)    1153    1154 def _extend_graph(self):

FailedPreconditionError:尝试使用未初始化的值fully_connected_1 / biases

1 个答案:

答案 0 :(得分:0)

还必须从完全连接的层加载参数。

但是,我建议还是使用TensorFlow's Saver and Restore functions

为了保存,下面是一个玩具示例:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000) # saving model after 1000 steps

存储以下文件:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

因此,要进行还原,您可以先重新创建网络,然后加载参数:

with tf.Session() as sess:
 recreated_net = tf.train.import_meta_graph('my_test_model-1000.meta')
 recreated_net.restore(sess, tf.train.latest_checkpoint('./'))