After several days of trying to apply tf.while_loop, I still cannot understand how it works (or rather, why it does not work). So far neither the documentation nor the various questions on StackOverflow have helped.
The main idea is to use a while_loop to train each column of the trueY tensor separately. The problem is that when I trace this code, I find that the while_loop body is only called once.
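As far as I can tell, the body is a plain Python function that TensorFlow calls exactly once to trace the loop into the graph; the ops it creates are then executed n times when the session runs. A minimal standalone sketch (TensorFlow 1.x, nothing to do with my model) of the behaviour I mean:

import tensorflow as tf

def cond(i, acc):
    return i < 5

def body(i, acc):
    print('body traced')                             # Python print: fires once, at graph-build time
    acc = tf.Print(acc, [i], message='iteration: ')  # graph op: fires on every loop iteration
    return i + 1, acc + 1

_, total = tf.while_loop(cond, body, (tf.constant(0), tf.constant(0)))

with tf.Session() as sess:
    print(sess.run(total))  # 'body traced' appeared once; 'iteration:' prints five times

So the single call I see when tracing my code is presumably this one trace, not a sign that the loop only runs once.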
I want to give the variables created inside the while_loop dynamically generated names, so that I can access them after the while_loop has been built (hence the gen_name function, which tries to generate a distinct name for the dense layer created on each iteration, in the hope of making tf.while_loop run n times this way).
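If the tracing behaviour above is right, then gen_name cannot work as intended: index is a tensor with no value during tracing, so gen_name sees exactly one concrete tensor name and only one dense layer is ever created. A sketch of the workaround I am considering (the '{}phi2y' name pattern, the input shape, and dim_y below are just placeholders for this example): build the n named layers in an ordinary Python loop before entering the graph loop:

import tensorflow as tf

dim_y = 10
pre_Y = tf.placeholder(tf.float32, shape=(None, 32))

# Plain Python loop: k is a concrete integer, so every layer gets its own name
log_pred_Y_cols = [tf.layers.dense(pre_Y, 2, name='{}phi2y'.format(k))
                   for k in range(dim_y)]

# Each layer stays accessible by name after construction:
w0 = tf.get_default_graph().get_tensor_by_name('0phi2y/kernel:0')

Is something like this the intended way to get n separately named layers, or can the names be made to vary inside tf.while_loop itself?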
Below is a sample of the code in question (it is not the complete code, and it has been modified to demonstrate the issue):
...................
config['dim_y'] = 10
Xl = tf.placeholder( self.dtype, shape=(batchsize, config['dim_x']) )
Yl = tf.placeholder( self.dtype, shape=(batchsize, config['dim_y']) )
Gl = tf.placeholder( self.dtype, shape=(batchsize, config['dim_g']) )
costl, cost_m, cost_b = self.__cost( Xl, Yl, Gl, False )
def __eval_cost( self, A, X, Y, G, reuse ):
    AGXY = tf.concat( [A, G, X, Y], -1 )
    Z, mu_phi3, ls_phi3 = build_nn( AGXY, ...., reuse )
    _cost = -tf.reduce_sum( ls_phi3, -1 )
    _cost += .5 * tf.reduce_sum( tf.pow( mu_phi3, 2 ), -1 )
    _cost += .5 * tf.reduce_sum( tf.exp( 2*ls_phi3 ), -1 )
    return _cost
def __cost( self, trueX, trueY, trueG, reuse ):
    ........
    columns = tf.unstack(trueY, axis=-1)
    AGX = tf.concat( [ AX, G ], -1 )
    pre_Y = self.build_nn( AGX, ....., reuse )
    index_loop = (tf.constant(0), _cost, _cost_bl)
    def condition(index, _cost, _cost_bl):
        return tf.less(index, self.config['dim_y'])
    def bodylabeled(index, _cost, _cost_bl):
        def gen_name(var_name):
            # split eg 'cost/while/strided_slice_5:0' => '5'
            # split eg 'cost/while/strided_slice:0' => 'slice'
            iter = var_name.split('/')[-1].split(':')[0].split('_')[-1]
            if iter == "slice":
                return '0phi2y'
            else:
                return '{}phi2y'.format(int(iter) % self.config['dim_y'])
        # depth, A and X come from the elided parts of the code
        y_i = tf.gather(columns, index)
        y = tf.expand_dims( tf.one_hot(tf.to_int32(y_i, name='ToInt32'), depth, dtype=self.dtype ), 0 )
        Y = tf.tile( y, [self.config['L'],1,1] )
        c = tf.constant(0, name='test')
        # This dense layer is created while the body is traced, i.e. once at
        # graph-construction time, so gen_name only ever runs a single time.
        log_pred_Y = tf.layers.dense( pre_Y, 2, name=gen_name(y_i.name), reuse=reuse )
        log_pred_Y = log_pred_Y - tf.reduce_logsumexp( log_pred_Y, -1, keep_dims=True )
        _cost += self.__eval_cost( A, X, Y, G, reuse=tf.AUTO_REUSE )
        _cost_bl += -tf.reduce_sum( tf.multiply( Y, log_pred_Y ), -1 )
        return tf.add(index, 1), _cost, _cost_bl
    _cost, _bl = tf.while_loop(
        condition, bodylabeled, index_loop,
        parallel_iterations=1,
        shape_invariants=(index_loop[0].get_shape(),
                          tf.TensorShape([None, 100]),
                          tf.TensorShape([None, 100])))[1:]
op = costl + cost_m + cost_b
with tf.Session(config=config) as sess:
    sess.run( tf.global_variables_initializer() )
    sess.run( tf.local_variables_initializer() )
    for batchl in batches:
        # X_data/Y_data/G_data are stand-in names for the actual data arrays;
        # in the original snippet they collided with the placeholder names Xl/Yl/Gl
        sess.run( op,
                  feed_dict={Xl: X_data[batchl,:],
                             Yl: Y_data[batchl,:].toarray(),
                             Gl: G_data[batchl,:].toarray(),
                             is_training: True } )
for n in tf.get_default_graph().as_graph_def().node:
    print(n.name)
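Rather than dumping every node, listing just the trainable variables might be a more direct way to check how many phi2y layers actually exist:

for v in tf.trainable_variables():
    print(v.name)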