我最近正在学习机器学习,并尝试使用Tensorflow实现一个简单的神经网络。
我使用MNIST作为数据集,我想使用Tensorflow的Dataset API加载和批处理我的数据。
这是我的代码:
train_data = tf.data.Dataset.from_tensor_slices((X_train, y_train)) train_data = train_data.shuffle(500) train_data = train_data.batch(50) train_data = train_data.repeat() td_iter = train_data.make_one_shot_iterator() features, labels = td_iter.get_next() with tf.Session() as sess: sess.run(init) for epoch in range(n_epochs): for iteration in range(n_batches): X_batch, y_batch = sess.run([features, labels]) sess.run(training_op, feed_dict={X:X_batch, y:y_batch}) acc_train = accuracy.eval(feed_dict={X:X_batch, y:y_batch}) acc_test = accuracy.eval(feed_dict={X:X_test, y:y_test}) print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
我能够高精度地训练模型,但是训练时它会消耗我的所有内存(8GB)。
更具体地说,它在完成第一个时期之前会消耗大量内存(并且打印第一条输出行需要花费相当长的时间),但是如果开始打印某些内容,则内存消耗会减少。
我尝试简化代码以找出问题所在:
with tf.Session() as sess: sess.run(init) sess.run([features, labels])
上面的代码仍然会耗尽我的全部内存。
我认为我的代码一定有误,您能帮我吗?
谢谢!