我有一个训练操作:
def train(self, batch_size, steps):
x, y = self.generate(batch_size*steps)
print('Training...')
self.classifier.fit(x=x, y=y, batch_size=batch_size, steps=steps)
分类器在这里定义:
self.classifier = learn.Estimator(model_fn=self.model, model_dir=self.dir)
我的问题是 - 如果x
和y
的尺寸大于batch_size
,那么在steps
移动时它们是否都会被使用?例如,如果batch_size
为128,但x
和y
都是128,000项,那么在steps
达到1000步之前,是否会对所有项目进行培训?
我问这个是因为generate
函数需要很长时间,而且我想知道如果实际情况只有第一个{{1}实际上是浪费了大部分时间使用它的例子。
注意:我知道batch_size
和x
参数已弃用,我应该使用y
,因此问题适用于这两种方式,例如,如果训练操作是这样的:
input_fn
换句话说,def train(self, batch_size, steps):
self.classifier.fit(input_fn=lambda: self.generate(batch_size*steps), steps=steps)
函数或生成x,y张量的函数应该在需要input_fn
数据示例或batch_size*steps
的情况下调用,因为只有那会被处理吗?
答案 0 :(得分:1)
如果您的batch_size
为128,如果您有128000项,那么当steps
达到1000步时,所有项目都会接受培训。 estimator
仅针对每个batch_size
提取您在training step
中描述的内容。
我写了一段代码,它读取输入(每个样本只有1个),每个训练步骤总结了它到那时看到的那些,它告诉你它已读取了多少数据样本
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
tf.logging.set_verbosity(tf.logging.INFO)
def model_fn(features, labels, mode):
_sum = tf.Variable(0, dtype=tf.int32)
if mode == learn.ModeKeys.TRAIN:
# Update global_step
global_step=tf.contrib.framework.get_global_step()
global_step_op = tf.assign(global_step, global_step+1)
# Sum of all the elements in a batch
update_sum_op = tf.assign_add(_sum, tf.reduce_sum(features))
update_op = tf.group(global_step_op, update_sum_op)
loss = _sum
predictions = {'out': tf.identity(_sum, 'sum')}
return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions, loss=loss, train_op=update_op)
X = np.ones((1000, 1), dtype=np.int32)
y = np.ones((1000, 1), dtype=np.int32)
sess = tf.InteractiveSession()
feature_classifier = learn.SKCompat(learn.Estimator(model_fn=model_fn))
tensors_to_log = {'out':'sum'}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=1)
feature_classifier.fit(x=X, y=y, batch_size=123, steps=7, monitors=[logging_hook])
此处的总数据样本为1000,batch_size=123
和steps=7
。
每一步的输出是:
INFO:tensorflow:out = 123
INFO:tensorflow:out = 246 (0.004 sec)
INFO:tensorflow:out = 369 (0.003 sec)
INFO:tensorflow:out = 492 (0.003 sec)
INFO:tensorflow:out = 615 (0.003 sec)
INFO:tensorflow:out = 738 (0.003 sec)
INFO:tensorflow:out = 861 (0.003 sec)