TensorFlow: How to add regularization to the model

Posted: 2018-12-19 06:42:42

Tags: python tensorflow deep-learning

I want to add regularization to my optimizer, like this:

tf.train.AdadeltaOptimizer(learning_rate=1).minimize(loss)

But I don't know how to define the "loss" for the code below.

The page I was following: https://blog.csdn.net/marsjhao/article/details/72630147

The code below is modified from Google's Machine Learning Crash Course: https://colab.research.google.com/notebooks/mlcc/improving_neural_net_performance.ipynb?utm_source=mlcc&utm_campaign=colab-external&utm_medium=referral&utm_content=improvingneuralnets-colab&hl=zh-tw#scrollTo=P8BLQ7T71JWd

Can anyone give me some advice or discuss this with me?


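import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and tf.train)
from sklearn import metrics

# construct_feature_columns() and my_input_fn() are helper functions defined
# in the MLCC notebook linked above.

def train_nn_classifier_model_new(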
    my_optimizer,
    steps,
    batch_size,
    hidden_units,
    training_examples,
    training_targets,
    validation_examples,
    validation_targets):

  periods = 10
  # Integer division so each train() call receives an int step count.
  steps_per_period = steps // periods

  # Create a DNNClassifier object.

  my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)
  dnn_classifier = tf.estimator.DNNClassifier(
      feature_columns=construct_feature_columns(training_examples),
      hidden_units=hidden_units,
      optimizer=my_optimizer
      )

  # Create input functions.
  training_input_fn = lambda: my_input_fn(training_examples, 
                                          training_targets["deal_or_not"], 
                                          batch_size=batch_size)
  predict_training_input_fn = lambda: my_input_fn(training_examples,        
                                         training_targets["deal_or_not"], 
                                         num_epochs=1, 
                                         shuffle=False)
  predict_validation_input_fn = lambda: my_input_fn(validation_examples, 
                                         validation_targets["deal_or_not"], 
                                         num_epochs=1, 
                                         shuffle=False)
  # Train the model, but do so inside a loop so that we can periodically assess
  # loss metrics.
  print("Training model...")
  print("LogLoss (on training data):")
  training_log_losses = []
  validation_log_losses = []
  for period in range(0, periods):
    # Train the model, starting from the prior state.
    dnn_classifier.train(
        input_fn=training_input_fn,
        steps=steps_per_period
    )
    # Take a break and compute predictions.    
    training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)
    training_probabilities = np.array([item['probabilities'] for item in training_probabilities])
    print(training_probabilities)

    validation_probabilities = dnn_classifier.predict(input_fn=predict_validation_input_fn)
    validation_probabilities = np.array([item['probabilities'] for item in validation_probabilities])

    training_log_loss = metrics.log_loss(training_targets, training_probabilities)
    validation_log_loss = metrics.log_loss(validation_targets, validation_probabilities)
    # Occasionally print the current loss.
    print("  period %02d : %0.2f" % (period, training_log_loss))
    # Add the loss metrics from this period to our list.
    training_log_losses.append(training_log_loss)
    validation_log_losses.append(validation_log_loss)
  print("Model training finished.")

  # Output a graph of loss metrics over periods.
  plt.ylabel("LogLoss")
  plt.xlabel("Periods")
  plt.title("LogLoss vs. Periods")
  plt.tight_layout()
  plt.plot(training_log_losses, label="training")
  plt.plot(validation_log_losses, label="validation")
  plt.legend()

  return dnn_classifier




result = train_nn_classifier_model_new(
    my_optimizer=tf.train.AdadeltaOptimizer(learning_rate=1),
    steps=30000,
    batch_size=250,
    hidden_units=[150, 150, 150, 150],
    training_examples=training_examples,
    training_targets=training_targets,
    validation_examples=validation_examples,
    validation_targets=validation_targets
    )

1 answer

Answer 0 (score: 1)

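Regularization is added to the loss function. Your optimizer, AdadeltaOptimizer, has no regularization parameters. If you want the optimizer itself to apply regularization, use tf.train.ProximalAdagradOptimizer instead: it has l2_regularization_strength and l1_regularization_strength parameters whose values you can set. These parameters are part of the original algorithm.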
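To show what "regularization inside the optimizer" means, here is a minimal NumPy sketch of a proximal-Adagrad-style update, based on our reading of the update rule documented for TensorFlow's ProximalAdagrad (this is not TensorFlow's code; `proximal_adagrad_step` and its argument names are hypothetical). After the ordinary Adagrad gradient step, the L1 strength soft-thresholds each weight toward zero and the L2 strength shrinks it multiplicatively, so the penalty is applied by the optimizer rather than written into the loss.

```python
import numpy as np


def proximal_adagrad_step(var, grad, accum, learning_rate,
                          l1_strength=0.0, l2_strength=0.0):
    """One proximal-Adagrad-style update (a sketch, not TensorFlow's kernel).

    Returns the updated (var, accum) pair; the inputs are not modified.
    """
    # Adagrad: accumulate squared gradients, giving a per-weight step size.
    accum = accum + grad ** 2
    lr = learning_rate / np.sqrt(accum)
    # Plain gradient step on the unregularized loss.
    prox_v = var - lr * grad
    # Proximal step: L1 soft-thresholding, then L2 multiplicative shrinkage.
    shrunk = np.maximum(np.abs(prox_v) - lr * l1_strength, 0.0)
    var = np.sign(prox_v) * shrunk / (1.0 + lr * l2_strength)
    return var, accum


w = np.array([0.5, -0.02, 1.0])
g = np.array([0.1, 0.0, -0.2])
acc = np.full_like(w, 0.1)  # TF's default initial_accumulator_value is 0.1

w_plain, _ = proximal_adagrad_step(w, g, acc, learning_rate=0.1)
w_reg, _ = proximal_adagrad_step(w, g, acc, learning_rate=0.1,
                                 l1_strength=0.5, l2_strength=0.5)
print(w_plain)
print(w_reg)
```

With both strengths at zero the step reduces to plain Adagrad; with them set, every weight shrinks and the small weight (-0.02) is driven to exactly zero by the L1 term. In the question's code the corresponding change would be passing something like `my_optimizer=tf.train.ProximalAdagradOptimizer(learning_rate=0.1, l2_regularization_strength=0.001)`; these values are placeholders, not tuned settings.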

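Otherwise, you have to add the regularization term to a custom loss function yourself, but DNNClassifier does not let you supply a custom loss function, so you would have to build the network manually. For how to add regularization to the loss, check here.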
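As a sketch of "adding regularization to the loss yourself" once you build the model by hand, here is a NumPy example for the simplest case, a single-layer logistic model (an assumption for brevity; the question's model is a four-layer DNN). The penalty λ‖w‖² is simply added to the cross-entropy, and its gradient 2λw is added to the weight gradient. In TensorFlow 1.x the same idea is usually written by adding a scaled `tf.nn.l2_loss(W)` for each weight matrix to the `loss` you pass to `minimize(loss)`. The function names and values below are hypothetical.

```python
import numpy as np


def regularized_log_loss(w, b, X, y, l2_strength):
    """Binary cross-entropy plus an L2 penalty on the weights (not the bias).

    Returns (loss, grad_w, grad_b).
    """
    z = X @ w + b
    p = 1.0 / (1.0 + np.exp(-z))            # predicted probabilities
    eps = 1e-12                              # avoid log(0)
    data_loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
    penalty = l2_strength * np.sum(w ** 2)  # the regularization term
    # Gradients: the data term plus d/dw of l2_strength * ||w||^2.
    grad_w = X.T @ (p - y) / len(y) + 2.0 * l2_strength * w
    grad_b = np.mean(p - y)
    return data_loss + penalty, grad_w, grad_b


def train(X, y, l2_strength, steps=2000, learning_rate=0.5):
    """Plain gradient descent on the (optionally regularized) loss."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(steps):
        _, gw, gb = regularized_log_loss(w, b, X, y, l2_strength)
        w -= learning_rate * gw
        b -= learning_rate * gb
    return w, b


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(float)  # toy, linearly separable labels

w_plain, _ = train(X, y, l2_strength=0.0)
w_reg, _ = train(X, y, l2_strength=0.05)
print(np.linalg.norm(w_plain), np.linalg.norm(w_reg))
```

With `l2_strength=0.05` the learned weight vector ends up with a smaller norm than the unregularized one, which is the intended effect; the bias is left unpenalized, as is common practice.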