如何在tf.learn中使用input_fn进行批量训练?

时间:2017-04-06 03:53:21

标签: tensorflow

model_dir = "no_regulation"
print(model_dir)
m = tf.contrib.learn.LinearClassifier(
    feature_columns=feature_columns,
    optimizer=tf.train.FtrlOptimizer(
      learning_rate=3,
      l1_regularization_strength=0,
      l2_regularization_strength=0),
    n_classes = n_classes,
    model_dir=model_dir)

def train_input_fn():
  print("Here!")
  return input_fn(train.sample(50000), label_column = "course_index", categorical_columns = CATEGORICAL_COLUMNS)

如果我执行以下操作,则每10个步骤批量处理50000个样本,

for i in range(40):
    for j in range(20):
        m.fit(input_fn=train_input_fn, steps = 10)
    m.evaluate(input_fn=eval_input_fn1, steps = 1, name="test1")
    m.evaluate(input_fn=eval_input_fn2, steps = 1, name="test2")

这合理吗?如果我做m.fit(input_fn = train_input_fn,steps = 1),每次适合调用都会创建一个检查点,这会大大减慢训练速度。我应该禁用检查点吗?如果是这样,怎么样?

1 个答案:

答案 0 :(得分:1)

我找到的方法是使用m.partial_fit代替fitpartial_fit不会触发CheckpointSaverHook

似乎evaluate确实如此。