我在张量流代码中完全重置了计算图吗?

时间:2016-04-28 06:50:42

标签: tensorflow

我正在尝试做一些实验。 在每个小批量之后,我试图重构计算图。 我有一种感觉,虽然有一些问题。当我为第一个小批量生成W1,W2,W3的初始值时,我得到了我期望的更新。然而,我没有得到我期望从第二个小批量开始的更新。是否有可能在每次迭代时检查计算图形是什么样的?

import tensorflow as tf
import numpy as np
bsize = 5
Xset = np.random.uniform(0,1,(60000,6*20)) * 50
Yset = Xset[:,0]

Wone = np.random.normal(0, .35, (6,6))
Wtwo = np.random.normal(0, .35, (6,6))
Wthree = np.random.normal(0, .35, (6,6))

Results = []

for q in range(1):
    for k in range(40):
        from tensorflow.python.framework import ops
        ops.reset_default_graph()
        tf.reset_default_graph()
        tf.InteractiveSession()
        x1 = tf.placeholder(tf.float32, shape=(bsize,6*20))
        y = tf.placeholder(tf.float32, shape=(bsize,1))
        x = tf.reshape(x1,[bsize,6,20])
        InitialState = tf.zeros((6,bsize))
        h = InitialState
        W1 = tf.Variable(tf.convert_to_tensor(Wone,dtype = tf.float32),name = "W1")
        W2 = tf.Variable(tf.convert_to_tensor(Wtwo,dtype = tf.float32),name = "W2")
        W3 = tf.Variable(tf.convert_to_tensor(Wthree,dtype = tf.float32),name = "W3")


#create list
        lis = []
        for q in range(10):
            pit = np.random.uniform(-1,1)
            #print pit

            if(pit<0) or q == 0 or pit==0 or pit > 0:
                lis.append(q)

        for p in lis:
            h = tf.matmul(W1,h) + tf.matmul(W2,tf.transpose(x[:,:,p]))
            h = tf.nn.relu(h)

        hstar = h
        output = tf.matmul(W3,hstar)
        output1 = output[0:1,:]

        loss = tf.reduce_sum(tf.sub(tf.transpose(output1) ,y)*tf.sub(tf.transpose(output1) ,y))

        opt = tf.train.AdamOptimizer()
        opt_operation = opt.minimize(loss)

        for h in range(1):
            with tf.Session() as sess:
                sess.run(tf.initialize_all_variables())
                a,b,RLoss,_ = sess.run([hstar,output,loss,opt_operation], feed_dict = {x1:Xset[(bsize*k):(bsize*k+bsize),:],y:Yset[bsize*k:k*bsize+bsize,None]})

                print RLoss, k

1 个答案:

答案 0 :(得分:0)

只需tf.reset_default_graph()即可。通过检查tf.get_default_graph().as_graph_def() tensorflow.GraphDef架构实施here

的原型,您可以查看图表的外观

特别是,要获取图表中的所有节点名称,您可以执行

[n.name for n in tf.get_default_graph().as_graph_def().node]