I am using the standard input pipeline to read a CSV file in TensorFlow. I have a single input file that I want to read in mini-batches. My concern is that the file is being read multiple times, even though I believe I have set the number of epochs correctly. I suspect this is caused by the behavior of the tf.train.string_input_producer() function.
with self.graph.as_default():
    # num_epochs must be an integer; this works out to the number of batches in the file, plus one
    epochs = int(np.floor(fileSize / batchSize)) + 1
    self.fileNameQ = tf.train.string_input_producer(fileNameList, num_epochs=epochs)
    self.batchInput, self.label = self.inputPipeline(batchSize, dim)
Here, fileSize is determined by:
fileSize = sum(1 for line in open(file.name))
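For example, with hypothetical numbers: a file of 10,000 lines and a batch size of 32 gives epochs = floor(10000 / 32) + 1 = 313.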
I have defined my input pipeline functions as follows:
def readFromCsv(self, dim):
    reader = tf.TextLineReader()
    _, csvLine = reader.read(self.fileNameQ)
    # Parse every column as a string: the first three are the label, the remaining dim are features.
    recordDefaults = [["\0"] for _ in range(dim + 3)]
    recordStr = tf.decode_csv(csvLine, record_defaults=recordDefaults)
    self.label = tf.stack(recordStr[0:3])
    self.features = tf.stack(recordStr[3:dim + 3])
    return (self.features, self.label)
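For reference, here is a minimal standalone sketch of how I understand tf.decode_csv to split a line, using a hypothetical five-column record (i.e. dim = 2):

import tensorflow as tf

# Hypothetical record: 3 label columns followed by dim = 2 feature columns.
line = tf.constant("a,b,c,1.5,2.5")
defaults = [["\0"] for _ in range(5)]  # parse every column as a string
cols = tf.decode_csv(line, record_defaults=defaults)

with tf.Session() as sess:
    print(sess.run(tf.stack(cols[0:3])))  # label columns:   [b'a' b'b' b'c']
    print(sess.run(tf.stack(cols[3:5])))  # feature columns: [b'1.5' b'2.5']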
def inputPipeline(self, batchSize, dim):
    minAfterDequeue = 10000
    capacity = minAfterDequeue + 3 * batchSize
    example, label = self.readFromCsv(dim)
    exampleBatchStr, labelBatch = tf.train.batch([example, label], batch_size=batchSize, capacity=capacity)
    exampleBatch = tf.string_to_number(exampleBatchStr)
    return (tf.transpose(exampleBatch), tf.transpose(labelBatch))
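As a sanity check on the string-to-float conversion and the transposition, a small sketch with hypothetical values:

import tensorflow as tf

strBatch = tf.constant([["1.5", "2.5"], ["3.5", "4.5"]])  # shape [batchSize, dim]
numBatch = tf.string_to_number(strBatch)                  # out_type defaults to tf.float32

with tf.Session() as sess:
    # Transposed to shape [dim, batchSize]: [[1.5 3.5] [2.5 4.5]]
    print(sess.run(tf.transpose(numBatch)))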
I then run training, printing a dot for every 1,000 records processed:
def train(self, batchSize, dim):
    with self.sess:
        self.sess.run(tf.local_variables_initializer())
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        # Training iterations
        self.iterationInput = 0
        while self.iterationInput < self.iterations:
            # Train with each vector one by one
            self.iterationInput += 1
            print("iteration " + str(self.iterationInput) + " for window size " + str(dim))
            try:
                loopCount = 0
                while not coord.should_stop():
                    # Fill in input data.
                    self.sess.run([self.batchInput, self.label])
                    self.sess.run(self.trainingOp)
                    # For every 1,000 samples, print a dot.
                    if loopCount % 1000 == 0:
                        sys.stdout.flush()
                        sys.stdout.write('.')
                    loopCount += 1
            except tf.errors.OutOfRangeError:
                print("Done training -- epoch limit reached")
                coord.request_stop()
        # When done, join the threads.
        coord.join(threads)
When I run the program, it prints noticeably more dots than the number of records in the file (divided by 1,000). I know that tf.train.string_input_producer() initializes four threads. My worry is that each thread reads the file once, which increases the runtime when the hardware does not offer enough parallelism. Since the runtime is already long, I don't want to increase it further. Is there any way to stop the file from being read four times?
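For reference, this is the single-pass behavior I am aiming for. A minimal standalone sketch, assuming num_epochs=1 makes the producer raise OutOfRangeError after exactly one pass over a hypothetical data.csv:

import tensorflow as tf

fileNameQ = tf.train.string_input_producer(["data.csv"], num_epochs=1, shuffle=False)
reader = tf.TextLineReader()
_, line = reader.read(fileNameQ)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # num_epochs is tracked in a local variable
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    count = 0
    try:
        while not coord.should_stop():
            sess.run(line)
            count += 1
    except tf.errors.OutOfRangeError:
        print("read %d lines" % count)  # I expect this to equal the file's line count
    finally:
        coord.request_stop()
        coord.join(threads)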