TypeError when using tf.cond()

Asked: 2019-06-18 16:50:15

Tags: python tensorflow

I want to create a branching model architecture so that the training operations are kept separate from the validation operations. I am trying to do this with tf.cond().

I call define_graph_op() to build the graph, so that when I do sess.run() only the relevant part of it is used, selected by feeding the placeholder with the right value - one value for training, one for validation.

I don't know what is causing the problem or how to fix it.
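For context, here is a minimal sketch of how I understand tf.cond() should behave (assuming TF 1.x graph mode; pred_ph and x_ph are just illustrative names, not part of my model): both branch callables return tensors of matching structure, and the value tf.cond() returns is itself a tensor that can be fetched with sess.run().

import tensorflow as tf

# Minimal sketch, assuming TF 1.x graph mode; pred_ph / x_ph are illustrative names.
pred_ph = tf.placeholder(dtype=tf.int32, shape=(), name='pred_ph')
x_ph = tf.placeholder(dtype=tf.float32, shape=(), name='x_ph')

# Both callables return a tensor of the same dtype/shape; tf.cond() itself
# returns a tensor, which is what gets passed to sess.run().
result = tf.cond(pred_ph > 0,
                 lambda: x_ph * 2.0,   # "training" branch
                 lambda: x_ph + 1.0)   # "validation" branch

with tf.Session() as sess:
    print(sess.run(result, feed_dict={pred_ph: 1, x_ph: 3.0}))  # 6.0
    print(sess.run(result, feed_dict={pred_ph: 0, x_ph: 3.0}))  # 4.0

In my actual code the two branches build the training and validation sub-graphs shown below.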

def define_graph_op(self):
    # Branch: if the placeholder is > 0, use the training operations;
    # if it is 0, use the validation operations.
    self.branch_graph = tf.placeholder(dtype=tf.int32, shape=())
    self.result = tf.cond(self.branch_graph > 0,
                          lambda: self.train_operations(),
                          lambda: self.valid_op())

def train_operations(self):
    # set placeholders
    self.keep_prob = tf.placeholder(dtype=tf.float32,name='keep_prob')

    self.X_train   = tf.placeholder(dtype=tf.float32, shape=(None,self.n_input),name='X_train')

    self.Y_train   = tf.placeholder(dtype=tf.int32,shape=(None, self.n_classes), name='Y_train')

    # network TRAIN set prediction                  --|TRAIN|--
    self.Y_train_predict = self.model_architecture(self.X_train,self.keep_prob) #logits for loss
    self.Y_train_soft = tf.nn.softmax(self.Y_train_predict) # softmax for accuracy
    # tf.summary.histogram('soft-act',self.Y_train_soft)

    # calculate LOSS between real label and predicted
    train_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.Y_train_predict, labels=self.Y_train,name='train_loss'))
    tf.summary.scalar('train_loss',train_loss,collections=['Training'])

    # softmax - accuracy
    y_pred = tf.argmax(self.Y_train_soft,axis=1,output_type=tf.int32) # arg max of the predicted output -softmax
    y_correct = tf.argmax(self.Y_train, axis=1, output_type=tf.int32) # arg-max of the actual input --placeholder
    # Cast a boolean tensor to float32
    self.train_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_correct), tf.float32))
    tf.summary.scalar('train_acc',self.train_accuracy,collections=['Training'])


    # define learning rate decay method
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # Define it--play with this
    learning_rate = 0.0001

    # the optimization algorithm
    optimizer = tf.contrib.optimizer_v2.AdamOptimizer(learning_rate,beta1=0.9, beta2=0.999, epsilon=1e-8,name='training_Adam') #tf.train.AdamOptimizer(learning_rate,beta1=0.9, beta2=0.999, epsilon=1e-8,name='training_Adam')
    self.trainable = tf.trainable_variables()  # may be the weights  ??
    self.update_ops = optimizer.minimize(train_loss, var_list=self.trainable, global_step=global_step)

    # # tf.summary.scalar('valid_loss',self.valid_loss)
    self.summary_op_train = tf.summary.merge_all(key='Training')

    return train_loss

def valid_op(self):
    # # --- Validation computations             ---|VALID|-----
    self.X_valid = tf.placeholder(dtype=tf.float32, shape=(None, self.n_input),name='X_valid')  # Define this
    self.Y_valid = tf.placeholder(dtype=tf.int32, shape=(None,self.n_classes),name='Y_valid')  # Define this
    # # logits layer without softmax
    self.Y_valid_predict = self.model_architecture(self.X_valid,self.keep_prob)
    self.Y_valid_soft = tf.nn.softmax(self.Y_valid_predict)
    # tf.summary.histogram('soft-act',self.y_valid_softmax)

    # Loss on validation
    valid_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.Y_valid_predict, labels=self.Y_valid,name='valid_loss'))
    tf.summary.scalar('valid_loss',valid_loss,collections=['VALID'])
    # # valid_accuracy
    y_pred_valid = tf.argmax(self.Y_valid_soft,axis=1,output_type=tf.int32)
    y_correct_valid = tf.argmax(self.Y_valid, axis=1, output_type=tf.int32)
    # # # Cast a boolean tensor to float32
    self.valid_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred_valid, y_correct_valid), tf.float32))
    tf.summary.scalar('valid_acc',self.valid_accuracy,collections=['VALID'])
    self.summary_op_valid = tf.summary.merge_all(key='VALID')

    return valid_loss


def train(self):
    train_l, _ = sess.run([self.result, self.update_ops],
                          feed_dict={self.X_train: Xbatch, self.Y_train: Ybatch,
                                     self.keep_prob: self.dropout, self.branch_graph: 1})

This raises:

TypeError: Fetch argument <function CNN.train_operations.<locals>.<lambda> at 0x7fc158259d08> has invalid type <class 'function'>, must be a string or Tensor.

0 Answers:

No answers