CNN model does not update after an epoch: TensorFlow

Time: 2017-11-16 11:37:31

Tags: validation tensorflow training-data

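I am building a CNN model for source separation. The training data is fed to the network in batches, and the validation loss is computed after each epoch. However, the validation loss stays exactly the same from one epoch to the next. I suspect the model is not being updated after each epoch, but I am not sure, because I can see that the training loss changes and the optimizer runs on every batch. Here is the training function: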

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.AdamOptimizer)

# Model, Data, Diff, TrainConfig and ValidConfig are the author's own modules (not shown).

def train():
    # Model
    model = Model()

    # Loss, Optimizer
    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    loss_fn = model.loss()
    # minimize() returns the op that applies one Adam update and increments global_step
    optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize, load state
        sess.run(tf.global_variables_initializer())

        # Input source
        data = Data(TrainConfig.DATA_PATH)

        loss = Diff()
        mixed_wav, src1_wav, src2_wav, numBatch = data.load_wavs()  # load all train wavs
        data = Data(ValidConfig.DATA_PATH)
        mixed_val_wav, src1_val_wav, src2_val_wav, numBatch2 = data.load_wavs()  # give path argument, change in data.py

        for step in range(global_step.eval(), TrainConfig.FINAL_STEP):  # epoch

            # code for shuffling the data

            for batch in range(0, numBatch):  # batch

                # data is formatted to batchwise
                # (not shown: produces mixed_batch, mixed_batch_spec, src1_batch, src2_batch)

                # Fetch the batch loss and run one optimizer step
                l, _ = sess.run([loss_fn, optimizer],
                                feed_dict={model.x_mixed: mixed_batch, model.x_mixed_spec: mixed_batch_spec, model.y_src1: src1_batch, model.y_src2: src2_batch})  # needed to give mixed and original spectra and also mixed signal to the model

                loss.update(l)
                print('Train losses: step={}\td_loss={:2.2f}\tloss={}'.format(batch, loss.diff * 100, loss.value))

            # after the epoch, validation data formatting
            # (not shown: produces mixed_val_batch, mixed_val_batch_spec, src1_val_batch, src2_val_batch)
            # calculation of validation loss
            lval = sess.run(loss_fn, feed_dict={model.x_mixed: mixed_val_batch, model.x_mixed_spec: mixed_val_batch_spec, model.y_src1: src1_val_batch, model.y_src2: src2_val_batch})

            print('validation loss epoch={}\tloss={}'.format(step, lval))
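Two quick checks could help narrow this down. The sketch below is an illustration under stated assumptions, not a known fix: it uses NumPy only, and the helper names `max_param_change` and `is_constant_output` are hypothetical. The first compares two snapshots of the trainable variables to see whether Adam changes the weights at all; the second tests whether a prediction array has collapsed to a single value (for example, all zeros). In that second case the loss on a fixed validation set would come out identical every epoch, while the training loss would still vary simply because each batch has different targets.

```python
import numpy as np


def max_param_change(before, after):
    """Largest absolute element-wise change between two parameter snapshots.

    `before` and `after` are lists of NumPy arrays in the same order, e.g. the
    result of sess.run(tf.trainable_variables()) taken at two different times.
    """
    return max((float(np.max(np.abs(a - b))) for a, b in zip(after, before)),
               default=0.0)


def is_constant_output(pred, atol=1e-8):
    """True if every element of `pred` equals its first element to within `atol`."""
    pred = np.asarray(pred)
    return bool(np.allclose(pred, pred.flat[0], rtol=0.0, atol=atol))
```

In a TF 1.x session, the snapshots could come from `sess.run(tf.trainable_variables())` taken before and after the inner batch loop, and the prediction from running the model's output tensor on the validation feed (the question does not show what that tensor is called). A change of exactly 0.0 would suggest the updates never reach the variables; a nonzero change together with a constant output would point at the network itself, for instance units that have stopped responding to the input.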

Here is the output:

Train losses: step=0 d_loss=0.00 loss=5536927.0
Train losses: step=1 d_loss=-17.21 loss=4583803.5
Train losses: step=2 d_loss=11.89 loss=5128806.5
Train losses: step=3 d_loss=29.56 loss=6644758.5
Train losses: step=569 d_loss=-28.71 loss=5554282.0
validation loss epoch=0 loss=67791024.0

Train losses: step=0 d_loss=4.14 loss=5784481.5
Train losses: step=1 d_loss=-13.78 loss=4987227.0
Train losses: step=568 d_loss=46.64 loss=6187295.0
Train losses: step=569 d_loss=-15.51 loss=5227377.0
validation loss epoch=1 loss=67791024.0

Train losses: step=0 d_loss=-1.49 loss=5149379.0
Train losses: step=1 d_loss=-17.81 loss=4232363.0
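The `Diff` class is not shown in the question, but the printed numbers suggest that `d_loss` is the percentage change from the previous batch's loss: for example, (4583803.5 - 5536927.0) / 5536927.0 ≈ -0.1721, printed as -17.21. A minimal stand-in consistent with the log (a hypothetical reconstruction, not the author's code):

```python
class Diff:
    """Hypothetical stand-in for the question's Diff helper (its source is not shown).

    `value` holds the latest loss; `diff` holds the relative change from the
    previous loss, and stays 0.0 until a second value has been seen.
    """

    def __init__(self):
        self.value = None
        self.diff = 0.0

    def update(self, value):
        if self.value is not None:
            self.diff = (value - self.value) / self.value
        self.value = value


# Replay the first four training losses from the epoch-0 log.
loss = Diff()
for l in [5536927.0, 4583803.5, 5128806.5, 6644758.5]:
    loss.update(l)
    print('d_loss={:2.2f}'.format(loss.diff * 100))
# prints d_loss=0.00, d_loss=-17.21, d_loss=11.89, d_loss=29.56 (one per line)
```

Because `d_loss` only compares consecutive batches, its large swings likely reflect differences between batches rather than progress over training; a per-epoch mean of the training loss would be a clearer signal of whether the model is learning.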

If anyone can help, that would be great. Thanks in advance.

0 answers