我有一个使用Tensorflow的MNIST数据集的python代码。 会议如下:
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
for epoch in range(hm_epochs):
epoch_loss = 0
for _ in range(int(mnist.train.num_examples / batch_size)):
epoch_x, epoch_y = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
epoch_loss += c
print('Epoch: ', epoch, ' completed out of: ', hm_epochs, ' loss: ', epoch_loss)
correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
print('Accuracy:', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
该行:
epoch_x, epoch_y = mnist.train.next_batch(batch_size)
每次都会批量生产新的批次。
我的问题是,如果我有自己的CSV文件(列表列表),如何用新行替换这一行,这对我来说是新的批次? 我目前的代码如下:
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
for epoch in range(hm_epochs):
epoch_loss = 0
for _ in range(len(training_data_list) // batch_size):
epoch_x, epoch_y = training_data_list.nextbatch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
epoch_loss += c
print('Epoch: ', epoch, ' completed out of: ', hm_epochs, ' loss: ', epoch_loss)
correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
print('Accuracy:', accuracy.eval({x: inputs, y: targets}))
" nextbatch"是我定义的函数。但是我收到以下错误:
AttributeError: 'list' object has no attribute 'nextbatch'
我感谢任何建议:D
顺便说一下," training_data_list"来自:
stops = open('.../Desktop/stops_train.csv', 'r')
training_data_list = stops.readlines()
stops.close()
答案 0 :(得分:0)
您需要实现一个处理索引的对象。 您需要在该对象中实现nextbatch函数。 您可以在mnist中查看nextbatch的实现。