我有一个关于开发神经网络的相当普遍的问题。它仍然与如何编写模型有关。
在训练模型时,通常的做法是定期计算成本(甚至可能是精确度),以便能够检查进度的趋势。成本可能不是在每个时代之后计算出来的,而是可能每个 - 比如说 - 第100个时期,并保存在某个地方以在最后绘制它的图形。特别是如果数据集非常大(根据我目前所见),那么这个成本只能在小批量而不是整个训练数据集上计算,并且模型的成本是通过取平均数来计算的。所有这些培训都设置了mini_batch成本。
在训练模型时,我想定期计算开销和测试集的成本(甚至精确度),以便稍后能够比较趋势(特别是在训练集和开发集之间)以获得更好的图像该模型是如何工作的。但我在一个while循环中读取csv文件中的列车数据如下:
...
...
try:
while not coord.should_stop():
_ , minibatch_cost = sess.run([optimizer, cost])
nr_of_minibatches += 1
cost_of_model += minibatch_cost
cost_of_model /= nr_of_minibatches
# Print and save the progress (cost, accuracy, etc) periodically
if print_progress == True and nr_of_minibatches % 5 == 0:
print ("Cost after minibatch %i: %f" % (nr_of_minibatches, cost_of_model))
costs.append(cost_of_model)
accuracy_train = accuracy.eval() #(feed_dict={ZL: ZL, Y_mini_batch: Y_mini_batch})
train_accuracies.append(accuracy_train)
print("accuracy_train = " + str(accuracy_train))
test_accuracy = accuracy.eval(feed_dict={X_Y_mini_batch: Y_test})
except tf.errors.OutOfRangeError:
print('Done training, epoch reached')
finally:
coord.request_stop()
coord.join(threads)
...
...
只要在火车集csv文件上没有完成纪元的#循环,这个while循环就会循环。如上所述,在每个第5小批量之后,我计算成本。我想为开发和测试数据集添加成本/精度计算。
问题1)从实际的角度来看,我假设开发和测试集数据应该驻留在单独的csv文件中。你同意我的意见吗?如果没有,(如果您认为所有训练/开发/测试集都可以在同一个文件中),那么我应该如何在tensorflow中实现它?我猜Scikit-learn有一个技巧,但我认为训练/开发/测试数据在tensorflow实现中应该是分开的。
问题2)如果train / dev / test数据应该在单独的csv文件中,那么实现应该如何?上面给出的while循环是循环通过列车csv文件。我对如何优雅地实现它有点困惑,以便在每个 - 比如说 - 第5列车时代之后我重新计算开发/测试集的成本/准确性在单独的csv文件中。
答案 0 :(得分:0)
您的try/except/finally
块有错误的缩进。 except
和finally
语句应与try
处于同一缩进级别。
就拆分数据而言,您可以采用任何一种方式执行操作,具体取决于您的特定用例。您可以将所有数据保存在一个大的CSV文件中,并让程序将数据拆分为单独的块以进行培训/开发/测试,或者您可以保留3个单独的CSV文件。唯一需要注意的是,如果您选择将所有数据保存在一个CSV文件中,那么将数据拆分为train / dev / test集的方法应该是可重现的,否则您最终将对您的开发人员进行培训/如果您不止一次训练模型,则意外设置。
如何拆分数据并非您使用的深度学习包所独有。您当然可以使用Scikit-learn的分割数据的方法,然后使用Tensorflow进行训练。没有理由为Tensorflow特定的数据拆分重新发明轮子。
我建议反对的一件事是在培训期间检查测试集的成本/准确性。这违背了测试集的目的。您可以使用开发设置来微调模型中的超参数,但是您应该只使用最后的测试集来查看您的成本/准确度在现实世界中的可能性。 #34 ;.你不想过度适应你的测试集,这就是将数据分成3组的全部要点。
如果您选择将文件分开,则意味着您的循环中将有一个遍历单独文件的循环,并计算单独的开发文件集上的成本/准确性。