I'm working through a simple example from a book, using a training data sample of 892 rows: the usual textbook Titanic survival example. I define:
def read_csv(batch_size, file_path, record_defaults):
    filename_queue = tf.train.string_input_producer([file_path])
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(filename_queue)
    # decode_csv will convert a Tensor from type string (the text line) in
    # a tuple of tensor columns with the specified defaults, which also
    # sets the data type for each column
    decoded = tf.decode_csv(value, record_defaults=record_defaults)
    # batch actually reads the file and loads "batch_size" rows in a single tensor
    return tf.train.shuffle_batch(decoded,
                                  batch_size=batch_size,
                                  capacity=batch_size * 50,
                                  min_after_dequeue=batch_size)

def inputs():
    passenger_id, survived, pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked = \
        read_csv(BATCH_SIZE, file_path, record_defaults)
    # convert categorical data
    is_first_class = tf.to_float(tf.equal(pclass, [1]))
    is_second_class = tf.to_float(tf.equal(pclass, [2]))
    is_third_class = tf.to_float(tf.equal(pclass, [3]))
    gender = tf.to_float(tf.equal(sex, ["female"]))
    # Finally we pack all the features in a single matrix;
    # We then transpose to have a matrix with one example per row and one feature per column.
    features = tf.transpose(tf.pack([is_first_class, is_second_class, is_third_class, gender, age]))
    print 'shape of features', features.get_shape()
    return features, survived
Now I try to run:
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    W = tf.Variable(tf.zeros([5, 1]), name="weights")
    b = tf.Variable(0., name="bias")
    tf.global_variables_initializer().run()
    print 'tf was run'
    X,Y = inputs()
    print 'inputs!'
    sess.run(Y)
I see:
'tf was run!'
'inputs!'
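but the run part hangs forever (or at least takes a very long time to return). I'm using Jupyter with a Python 2.7 kernel and tf version 0.12.
What am I missing?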
Answer 0 (score: 3)
In the line
return tf.train.shuffle_batch(decoded,
                              batch_size=batch_size,
                              capacity=batch_size * 50,
                              min_after_dequeue=batch_size)
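you define the operation that pulls values from the queue and builds the batches.
If you look at the full signature of the method, you'll see that one parameter refers to a number of threads.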
tf.train.shuffle_batch(
    tensors,
    batch_size,
    capacity,
    min_after_dequeue,
    num_threads=1,
    seed=None, enqueue_many=False, shapes=None,
    allow_smaller_final_batch=False, shared_name=None, name=None)
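I'm pointing this out because the operations you defined are executed by threads.
Those threads have to be started and stopped, and this function does not do that for you. As far as thread handling goes, all it does is register a queue runner that will use num_threads threads to fill the queue.
In practice, you have to start and stop the threads yourself inside the session, by starting the queue runners and using a coordinator to shut them down: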
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    W = tf.Variable(tf.zeros([5, 1]), name="weights")
    b = tf.Variable(0., name="bias")
    tf.global_variables_initializer().run()
    print 'tf was run'
    X,Y = inputs()
    # define a coordinator to start and stop the threads
    coord = tf.train.Coordinator()
    # wake up the threads
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print 'inputs!'
    sess.run(Y) #execute operation
    # When done, ask the threads to stop.
    coord.request_stop()
    # Wait for threads to finish.
    coord.join(threads)
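The reason for the hang can also be reproduced without TensorFlow. The following is only an analogy, written for Python 3 with the standard library (queue and threading), not the TensorFlow implementation: a consumer reading from a queue gets nothing until the producer thread has been started. The names make_pipeline and try_dequeue are hypothetical helpers invented for this sketch, and the threading.Event merely stands in for the role tf.train.Coordinator plays above.

```python
import queue
import threading


def make_pipeline(rows, capacity=4):
    """Build a bounded queue plus a producer thread that is NOT started yet."""
    q = queue.Queue(maxsize=capacity)
    stop = threading.Event()  # stand-in for the coordinator's stop request

    def produce():
        # Keep cycling over the rows until asked to stop.
        while not stop.is_set():
            for row in rows:
                # Retry the put with a short timeout so a full queue
                # never prevents the thread from noticing the stop request.
                while not stop.is_set():
                    try:
                        q.put(row, timeout=0.05)
                        break
                    except queue.Full:
                        continue
                if stop.is_set():
                    return

    worker = threading.Thread(target=produce)
    return q, stop, worker


def try_dequeue(q, timeout):
    """Return the next item, or None if nothing arrives within `timeout` seconds."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


q, stop, worker = make_pipeline(rows=[1, 2, 3])

# Nobody is filling the queue yet, so a plain blocking read would wait forever.
# The timeout exists only so that this demonstration returns.
before_start = try_dequeue(q, timeout=0.2)

worker.start()                       # analogous to tf.train.start_queue_runners
after_start = try_dequeue(q, timeout=2.0)

stop.set()                           # analogous to coord.request_stop()
worker.join()                        # analogous to coord.join(threads)

print(before_start, after_start)
```

The first read comes back empty because the producer exists but is not running, which is the situation sess.run(Y) is in when the queue runners were never started. The only difference is that sess.run has no timeout, so it blocks indefinitely.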