我正在训练无监督的CNN,为此,我定义了一个损失函数,其中CNN输入通过复杂的表达式映射到cnn输出。我正在使用tf.while_loop
更新gamma_tilde_tensor_local
和D_mat_tensor_local
的变量。我已经上传了我的代码和下面的错误。
请建议我如何更新while_loop
中的张量变量。
为避免混淆,我避免了复杂的循环。最初的大小为gamma_tilde_tensor_local
,大小为(batch_size, M,M,K,K)
。
def function(self):
gamma_tilde_tensor_local = tf.Variable(tf.zeros(shape = [self.mini_batchSize, self.num_users]),dtype=tf.float32)
D_mat_tensor_local = tf.Variable(tf.zeros(shape = [self.mini_batchSize, , self.num_users]), dtype=tf.float32)
condition_g1 = lambda k_iter1, gamma_tilde_tensor_local, D_mat_tensor_local : k_iter1 < self.num_users
def body_gm1( k_iter1, gamma_tilde_tensor_local, D_mat_tensor_local):
tf_x_k = tf.expand_dims(self.PHI_batch[:,:,k_iter1],axis=-1)
tf_x_kd = tf.expand_dims(self.PHI_batch[:,:,k_iter2],axis=-1)
phi_to_reduce = tf.matmul( tf.transpose(tf_x_kd, perm=[0, 2, 1]), tf_x_k)
phi_t = tf.squeeze(phi_to_reduce,[1,2])
gamma_till = tf.divide( tf.multiply( self.gammafun[:,1, 2 ], self.channel_Gain[:,1, k_iter1]), self.channel_Gain[:,1, 2])
gamma_tilde_tensor_local[: k_iter1] = tf.multiply(phi_t, gamma_till )
D_mat_tensor_local[:,k_iter1] = tf.sqrt( tf.multiply( self.gammafun[1, 2], self.channel_Gain[1,k_iter1] ))
return tf.add(k_iter1, 1), gamma_tilde_tensor_local, D_mat_tensor_local
k_iter1, gamma_tilde_tensor_local, D_mat_tensor_local = tf.while_loop(condition_g1, body_gm1, [0, gamma_tilde_tensor_local, D_mat_tensor_local])
在行上给出错误:
gamma_tilde_tensor_local[: k_iter1] = tf.multiply(phi_t, gamma_till )
。