训练几步之后训练损失就变为零?

时间:2019-03-10 17:28:10

标签: python tensorflow

我用 TensorFlow 实现了一个 CNN 模型,但从第二步开始,训练损失(train loss)就为零或几乎为零,而验证损失则在接近零到 10 之间波动。我尝试过更改批次大小和学习率、添加 Dropout 层、删除部分层以减小网络规模,但问题依然存在。我检查过输入数据,没有发现问题。可能是哪里出了错?

这是我定义训练操作并训练模型的代码:

def define_train_operations(self, learning_rate=0.1):
    """Build the training and validation parts of the TF1 computation graph.

    Creates the input placeholders, runs ``self.model_architecture`` once for
    training (``reuse=False``) and once for validation (``reuse=True``), and
    stores the following attributes on ``self``: ``keep_prob``, ``X_train``,
    ``Y_train``, ``train_loss``, ``update_ops``, ``X_valid``, ``Y_valid`` and
    ``valid_loss``.

    Args:
        learning_rate: Step size passed to ``tf.train.AdamOptimizer``. The
            default (0.1) is the value that used to be hard-coded here, so
            existing callers behave exactly as before.
            NOTE(review): Adam's documented default is 0.001; 0.1 is 100x
            larger and is a plausible cause of the training loss collapsing
            to ~0 after a couple of steps. Try 1e-3 or 1e-4.
    """
    # Dropout keep probability, fed at run time (scalar float).
    self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')

    # Training images: (batch, height, width, channels).
    self.X_train = tf.placeholder(
        dtype=tf.float32,
        shape=(None, self.height, self.width, self.chan),
        name='X_train')

    # Training labels: one integer class id per example, as required by the
    # sparse cross-entropy op below.
    self.Y_train = tf.placeholder(dtype=tf.int32, shape=(None, ), name='Y_train')

    # Network output for the training batch; passed as `logits`, so it must
    # be unnormalised scores (no softmax applied inside the model).
    Y_train_predict = self.model_architecture(self.X_train, self.keep_prob, reuse=False)

    # Mean cross-entropy over the batch.
    # NOTE(review): the name 'train_loss' is attached to the per-example
    # cross-entropy op, not to the reduced mean; left as is so existing graph
    # names do not change.
    self.train_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=Y_train_predict, labels=self.Y_train, name='train_loss'))

    # Step counter incremented by the optimizer on every update. No learning
    # rate decay schedule is built from it here.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    optimizer = tf.train.AdamOptimizer(learning_rate)

    # All trainable variables created so far (the model weights).
    trainable = tf.trainable_variables()
    self.update_ops = optimizer.minimize(
        self.train_loss, var_list=trainable, global_step=global_step)

    # --- Validation computations
    self.X_valid = tf.placeholder(
        dtype=tf.float32,
        shape=(None, self.height, self.width, self.chan))
    self.Y_valid = tf.placeholder(dtype=tf.int32, shape=(None, ))

    # Second pass through the model with reuse=True; presumably this shares
    # the variables created by the training pass -- confirm in
    # model_architecture.
    Y_valid_predict = self.model_architecture(self.X_valid, self.keep_prob, reuse=True)

    # Mean cross-entropy on the validation batch.
    self.valid_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=Y_valid_predict, labels=self.Y_valid, name='valid_loss'))

# define train actions per epoch
def train_epoch(self,sess):
    """Run one training pass over the shuffled training set.

    Shuffles ``self.Xtrain_in`` / ``self.Ytrain_in`` with ``mf.shuffling``,
    then pulls batches of ``self.batch_size`` via ``mf.read_nxt_batch`` until
    it returns ``None`` for the batch, running ``self.update_ops`` on each.

    Args:
        sess: Session used to run ``self.train_loss`` and ``self.update_ops``.

    Returns:
        The mean of the per-batch losses that were accumulated, or 0 when no
        batch was processed. If a batch loss is NaN the loop stops and the
        mean of the batches before it is returned.
    """
    print("Train_epoch")
    loss_sum = 0
    total_batches = 0

    X, Y = mf.shuffling(self.Xtrain_in, self.Ytrain_in)
    # read_nxt_batch returns the batch plus the updated read position.
    Xbatch, Ybatch, indx = mf.read_nxt_batch(X, Y, self.batch_size, 0)

    while Xbatch is not None:
        # keep_prob 0.3 is the dropout keep probability used during training.
        mean_loss, _ = sess.run(
            [self.train_loss, self.update_ops],
            feed_dict={self.X_train: Xbatch, self.Y_train: Ybatch, self.keep_prob: 0.3})

        if math.isnan(mean_loss):
            print('train cost is NaN')
            break

        loss_sum += mean_loss
        total_batches += 1

        # Fetch the next batch only once we know the loop continues.
        Xbatch, Ybatch, indx = mf.read_nxt_batch(X, Y, self.batch_size, indx)

    if total_batches > 0:
        loss_sum /= total_batches

    return loss_sum

# validation actions per epoch
def valid_epoch(self,sess):
    """Evaluate the validation loss over the whole validation set.

    Shuffles ``self.Xvalid_in`` / ``self.Yvalid_in`` with ``mf.shuffling``,
    then pulls batches of ``self.batch_size`` via ``mf.read_nxt_batch`` until
    it returns ``None`` for the batch. No update op is run.

    Args:
        sess: Session used to run ``self.valid_loss``.

    Returns:
        The mean of the per-batch validation losses that were accumulated, or
        0 when no batch was processed. If a batch loss is NaN the loop stops
        and the mean of the batches before it is returned.
    """
    print("Valid_epoch")
    loss_sum = 0
    total_batches = 0

    X, Y = mf.shuffling(self.Xvalid_in, self.Yvalid_in)
    # read_nxt_batch returns the batch plus the updated read position.
    Xbatch, Ybatch, indx = mf.read_nxt_batch(X, Y, self.batch_size, 0)

    while Xbatch is not None:
        # keep_prob 1.0: dropout is disabled for evaluation.
        mean_loss = sess.run(
            self.valid_loss,
            feed_dict={self.X_valid: Xbatch, self.Y_valid: Ybatch, self.keep_prob: 1.0})

        if math.isnan(mean_loss):
            print('valid cost is NaN')
            break

        loss_sum += mean_loss
        total_batches += 1

        # Fetch the next batch only once we know the loop continues.
        Xbatch, Ybatch, indx = mf.read_nxt_batch(X, Y, self.batch_size, indx)

    if total_batches > 0:
        loss_sum /= total_batches

    return loss_sum


def train(self,sess,iter):
    """Train for up to 30 epochs with early stopping on the validation loss.

    Each epoch runs ``self.train_epoch`` and ``self.valid_epoch``, prints a
    summary line, and saves the trainable variables through
    ``save_variables`` whenever the validation loss reaches a new minimum.

    Args:
        sess: Session in which the graph built by
            ``define_train_operations`` is run.
        iter: Index of this training round. When greater than 0, variables
            are restored through ``restore_variables`` before training.
            (The name shadows the builtin; kept because callers may pass it
            by keyword.)
    """
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # replacement for measuring elapsed time.
    start_time = time.perf_counter()

    n_early_stop_epochs = 10  # epochs without improvement that are tolerated
    n_epochs = 30  # maximum number of epochs

    # Only trainable variables are checkpointed; at most 4 checkpoints kept.
    saver = tf.train.Saver(var_list = tf.trainable_variables(), max_to_keep = 4)

    # Initialise every variable FIRST. Running the initializer after the
    # restore (as was done before) overwrites the restored weights with
    # fresh initial values, so every round silently started from scratch.
    sess.run(tf.global_variables_initializer())

    # Then overlay the values saved by the previous training round.
    if iter > 0:
        restore_variables(sess)

    early_stop_counter = 0
    min_valid_loss = float("inf")

    for epoch in range(1, n_epochs + 1):
        epoch_start_time = time.perf_counter()

        train_loss = self.train_epoch(sess)
        valid_loss = self.valid_epoch(sess)

        epoch_end_time = time.perf_counter()

        info_str = 'Epoch=' + str(epoch) + ', Train: ' + str(train_loss) + ', Valid: '
        info_str += str(valid_loss) + ', Time=' + str(epoch_end_time - epoch_start_time)
        print(info_str)

        if valid_loss < min_valid_loss:
            # New best model: checkpoint it and reset the patience counter.
            print('Best epoch=' + str(epoch))
            save_variables(sess, saver, epoch, self.model_id)
            min_valid_loss = valid_loss
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        # Stop once more than n_early_stop_epochs consecutive epochs have
        # failed to improve on the best validation loss.
        if early_stop_counter > n_early_stop_epochs:
            print('stopping early')
            break

    end_time = time.perf_counter()
    print('Total time = ' + str(end_time - start_time))

这是执行训练的调用代码:

# change this according to your path
    path_to_train_set = "../Train_set"
    path_to_valid_set = "../Validation"

    model_id = get_model_id()
    n_tfiles=550  # number of training files read per round
    n_vfiles=round(0.25*n_tfiles)  # validation files per round (25% of train)
    # Total number of entries in the training directory; drives the number
    # of rounds below.
    total_inp_files = len(os.listdir(path_to_train_set))

    # Create the network
    network = CNN(model_id)
    iter=0
    # One round per chunk of n_tfiles input files, until all data are read.
    for i in range(1,total_inp_files,n_tfiles):

        mf.input(network,n_tfiles,n_vfiles)

        with tf.device('/gpu:0'):
            if(iter==0):
                # Define the train computation graph once, on the first round.
                network.define_train_operations()

            # allow_soft_placement lets TensorFlow fall back to another
            # device for ops that cannot be placed on the requested one; it
            # does not log placement (that is log_device_placement).
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement = True))
            try:
                print(iter)
                network.train(sess,iter)
                iter += 1
                flag = 0
            except KeyboardInterrupt:
                print()
                flag = 1
            finally:
                # NOTE(review): this always overwrites the `flag = 0` set on
                # success, so `flag` is 1 after every round -- confirm intent.
                flag = 1
                # Close the session here so it is released even when train()
                # raises something other than KeyboardInterrupt.
                sess.close()

0 个答案:

没有答案