Tensorflow:排除op向后传递

时间:2017-10-10 00:48:15

标签: python-3.x tensorflow

我试图在tensorflow中创建简单的4层(包括输入层)神经网络来处理MNIST数据。它适用于以下图形定义,

Graph1

with tf.name_scope("layer1"):
    z1 = tf.matmul(x,w1) + b1
    a1 = tf.nn.sigmoid(z1)
with tf.name_scope("layer2"):
    z2 = tf.matmul(a1,w2) + b2
    a2 = tf.nn.sigmoid(z2)
with tf.name_scope("layer3"):
    z3 = tf.matmul(a2,w3) + b3

接下来,我希望在每次激活后执行一些操作,然后创建如下图表,

Graph2

with tf.name_scope("layer1"):
    x = transform(x)
    z1 = tf.matmul(x,w1) + b1
    a1 = tf.nn.sigmoid(z1)
with tf.name_scope("layer2"):
    a1 = transform(a1)
    z2 = tf.matmul(a1,w2) + b2
    a2 = tf.nn.sigmoid(z2)
with tf.name_scope("layer3"):
    a2 = transform(a2)
    z3 = tf.matmul(a2,w3) + b3

其中,transform是一个用tensorflow ops定义的python函数,用于生成与输入矩阵相同大小的输出矩阵,并进行一些数学变换。 Graph1训练得很好,我的表现非常好,但Graph2的问题是只有" layer3"变得可训练,即w3和b3和tensorflow从训练集中排除其他变量。我尝试将var_list添加到优化器但没有运气。 tf.stop_gradient适用于变量,但如何将它用于python函数?更具体地说,我想排除" tranform"函数本身用于向后传递,仅在前向传递中使用它。在此link上,您将找到在tensorboard中创建的图表。在tensorflow中有没有办法做到这一点? TIA !!

1 个答案:

答案 0 :(得分:0)

经过一些研究,我得出的结论是,在tensorflow中不可能这样做,因为,tensorflow内部构建了一个图形结构,除非在某个时刻有tf.stop_gradient调用,否则它将被完全向后遍历。所以必须要写出自己的前后传球。希望这有助于将来的人!!