tf.contrib.learn.Estimator类是否使用所有数据?

时间:2017-07-05 05:47:48

标签: python tensorflow

我有一个训练操作:

def train(self, batch_size, steps):
    x, y = self.generate(batch_size*steps)
    print('Training...')
    self.classifier.fit(x=x, y=y, batch_size=batch_size, steps=steps)

分类器在这里定义:

self.classifier = learn.Estimator(model_fn=self.model, model_dir=self.dir)

我的问题是 - 如果xy的尺寸大于batch_size,那么在steps移动时它们是否都会被使用?例如,如果batch_size为128,但xy都是128,000项,那么在steps达到1000步之前,是否会对所有项目进行培训?

我问这个是因为generate函数需要很长时间,而且我想知道如果实际情况只有第一个{{1}实际上是浪费了大部分时间使用它的例子。

注意:我知道batch_sizex参数已弃用,我应该使用y,因此问题适用于这两种方式,例如,如果训练操作是这样的:

input_fn

换句话说,def train(self, batch_size, steps): self.classifier.fit(input_fn=lambda: self.generate(batch_size*steps), steps=steps) 函数或生成x,y张量的函数应该在需要input_fn数据示例或batch_size*steps的情况下调用,因为只有那会被处理吗?

1 个答案:

答案 0 :(得分:1)

如果您的batch_size为128,如果您有128000项,那么当steps达到1000步时,所有项目都会接受培训。 estimator仅针对每个batch_size提取您在training step中描述的内容。

我写了一段代码,它读取输入(每个样本只有1个),每个训练步骤总结了它到那时看到的那些,它告诉你它已读取了多少数据样本

from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

tf.logging.set_verbosity(tf.logging.INFO)

def model_fn(features, labels, mode):

   _sum = tf.Variable(0, dtype=tf.int32)   

   if mode == learn.ModeKeys.TRAIN:
       # Update global_step
       global_step=tf.contrib.framework.get_global_step()
       global_step_op = tf.assign(global_step, global_step+1)

       # Sum of all the elements in a batch
       update_sum_op = tf.assign_add(_sum, tf.reduce_sum(features)) 
       update_op = tf.group(global_step_op, update_sum_op)
       loss = _sum

   predictions = {'out': tf.identity(_sum, 'sum')}

   return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions, loss=loss, train_op=update_op)


X = np.ones((1000, 1), dtype=np.int32)
y = np.ones((1000, 1), dtype=np.int32)

sess = tf.InteractiveSession()

feature_classifier = learn.SKCompat(learn.Estimator(model_fn=model_fn))
tensors_to_log = {'out':'sum'}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=1)
feature_classifier.fit(x=X, y=y, batch_size=123, steps=7, monitors=[logging_hook])

此处的总数据样本为1000,batch_size=123steps=7

每一步的输出是:

INFO:tensorflow:out = 123
INFO:tensorflow:out = 246 (0.004 sec)
INFO:tensorflow:out = 369 (0.003 sec)
INFO:tensorflow:out = 492 (0.003 sec)
INFO:tensorflow:out = 615 (0.003 sec)
INFO:tensorflow:out = 738 (0.003 sec)
INFO:tensorflow:out = 861 (0.003 sec)