我是TensorFlow的新手,我正在使用多个CSV文件作为输入,根据Sachin Joglekar的博客(https://codesachin.wordpress.com/2015/11/28/self-organizing-maps-with-googles-tensorflow/)培训SOM模型。我按照https://www.tensorflow.org/programmers_guide/reading_data上的教程来读取小批量中的CSV文件队列。我的代码正在运行,但我想从阅读器打印出解码的CSV输入,以验证输入管道是否正常工作。由于CSV文件输入不是图表的一部分,我无法使用Tensor.eval(self.sess)打印它。当我尝试使用self.label.eval(session = tf.Session(graph = self.label.graph))打印出解码的记录标签时,我的脚本挂起并且不提供任何输出。有没有办法让我验证我的输入管道是否正常工作?以下是我的代码的相关摘要:
主要功能
def main(argv):
som = SOM(somDim1, somDim2, windowSizes[win], iterations, learningRate,
neighborhood, fileNameList, batchSize)
som.train(batchSize, fileNameList, windowSizes[win])
图形
def __init__(self, m, n, dim, iterations, alpha, sigma, fileNameList, batchSize):
##INITIALIZE GRAPH
self.graph = tf.Graph()
##POPULATE GRAPH WITH NECESSARY COMPONENTS
with self.graph.as_default():
##PLACEHOLDERS FOR TRAINING INPUTS
#These should be placeholders according to the TensorFlow framework,
#but we are declaring them as variables so that we can assign them directly
#to values read in from the CSV files.
batchInputLg = np.zeros((dim, batchSize))
labelFloat = np.zeros((3, batchSize))
self.label = tf.cast(labelFloat, "string")
self.batchInput = tf.cast(batchInputLg, "float32")
"""
...the rest of the graph...
"""
self.trainingOp = tf.assign(self.weightageVects, newWeightagesOp)
##INITIALIZE SESSION
self.sess = tf.Session()
##INITIALIZE VARIABLES
initOp = tf.global_variables_initializer()
self.sess.run(initOp)
输入管道功能
"""
Read in the features and metadata from the CSV files for each chromosome.
"""
def readFromCsv(self, fileNameQ, dim):
reader = tf.TextLineReader()
_, csvLine = reader.read(fileNameQ)
recordDefaults = [["\0"] for cl in range(dim - 1)]
recordStr = tf.decode_csv(csvLine, record_defaults=recordDefaults)
self.label = tf.stack(recordStr[0:2])
#self.label.eval(session = tf.Session(graph=self.label.graph))
self.features = tf.to_float(tf.stack(recordStr[3:dim - 1]))
return (self.features, self.label)
"""
Read in the features and metadata from the CSV files for each chromosome.
"""
def inputPipeline(self, batchSize, fileNameList, dim, num_epochs=None):
fileNameQ = tf.train.string_input_producer(fileNameList, shuffle = True)
minAfterDequeue = 10000
capacity = minAfterDequeue + 3 * batchSize
example, label = self.readFromCsv(fileNameQ, dim)
exampleBatchStr, labelBatch = tf.train.shuffle_batch([example, label], batch_size=batchSize, capacity=capacity, min_after_dequeue=minAfterDequeue)
exampleBatch = tf.cast(exampleBatchStr, "float")
return (exampleBatch, labelBatch)
培训功能
def train(self, batchSize, fileNameList, dim):
#Start the queue runners.
# Start input enqueue threads.
coordFile = tf.train.Coordinator()
self.coord = tf.train.Coordinator()
threadsFile = tf.train.start_queue_runners(sess=self.sess, coord=coordFile)
self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)
#Training iterations
self.iterationInput = 0
try:
for iter in range(self.iterations):
#Train with each vector one by one
self.iterationInput += 1
while not self.coord.should_stop():
#Fill in input data.
[self.batchInput, self.label] = self.inputPipeline(batchSize, fileNameList, dim)
self.sess.run(self.trainingOp)
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
# When done, ask the threads to stop.
self.coord.request_stop()
答案 0 :(得分:1)
我找到了解决方案。我没有在图表中初始化标签和批量输入张量并在train()函数中分配它们,我应该将赋值语句放在图形中,如下所示:
##TRAINING INPUTS
self.batchInput, self.label = self.inputPipeline(batchSize, fileNameList, dim)
然后,列车功能变为:
def train(self, batchSize, fileNameList, dim):
with self.sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
#Training iterations
self.iterationInput = 0
try:
for iter in range(self.iterations):
#Train with each vector one by one
self.iterationInput += 1
while not coord.should_stop():
#Fill in input data.
self.sess.run([self.batchInput, self.label])
self.sess.run(self.trainingOp)
print self.label.eval(session = self.sess)
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
# When done, ask the threads to stop.
coord.request_stop()
coord.join(threads)