我刚开始学习TensorFlow。我想从hdfs中的csv文件读取3x3矩阵并将其与自身相乘。
该文件如下所示:
1,2,3
4,5,6
7,8,9
到目前为止,我可以在TensorFlow tutorial:
的帮助下提出以下代码def read_and_decode(filename_queue):
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Type information and column names based on the decoded CSV.
record_defaults = [[0.0], [0.0], [0.0]]
f1,f2,f3 = tf.decode_csv(value, record_defaults=record_defaults)
# Turn the features back into a tensor.
features = tf.pack([
f1,
f2,
f3])
return features
def input_pipeline(filename_queue, batch_size, num_threads):
    """Batch decoded CSV rows pulled from filename_queue.

    NOTE(review): with num_threads > 1, tf.train.batch does not preserve
    file order -- presumably the cause of the out-of-order batches the
    question describes.
    """
    row = read_and_decode(filename_queue)
    min_after_dequeue = 10000
    queue_capacity = min_after_dequeue + 3 * batch_size
    return tf.train.batch(
        [row],
        batch_size=batch_size,
        capacity=queue_capacity,
        num_threads=num_threads,
        allow_smaller_final_batch=True)
def get_all_records(FILE):
    """Drain the input pipeline and print every decoded row of FILE.

    FILE: path/URL understood by TensorFlow's file system layer
    (here an hdfs:// URL).
    """
    with tf.Session() as sess:
        filename_queue = tf.train.string_input_producer(
            [FILE], num_epochs=1, shuffle=False)
        batch_size = 1
        num_threads = 4
        #batch = input_pipeline(filename_queue, batch_size, num_threads)
        batch = read_and_decode(filename_queue)
        # num_epochs creates a local epoch-counter variable, so local
        # variables must be initialized before starting the queue runners.
        init_op = tf.local_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while True:
                example = sess.run([batch])
                print(example)
        # Fixed: "except E, e:" is Python 2-only syntax (a SyntaxError on
        # Python 3); the portable form is "except E as e:".
        except tf.errors.OutOfRangeError as e:
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
# Entry point: read the 3x3 CSV matrix from HDFS and print its rows.
get_all_records('hdfs://default/test.csv')
这将以正确的顺序打印矩阵的每一行。但是,当我通过应用input_pipeline()使用批处理时,结果将不是正确的顺序。
我们也可以以Matrix Market格式(Matrix Market format)读取该文件。这将消除对读取顺序的约束。
所以我的问题是如何以可扩展的方式将结果行(或批处理)放入矩阵(或批处理矩阵)(即矩阵非常大),以便我可以应用矩阵乘法,如:
result = tf.matmul(Matrix,Matrix)
result = tf.batch_matmul(batched_Matrix,batched_Matrix)
作为问题的扩展:哪一个是最快的解决方案,特别是在分布式执行方面?
感谢您的帮助, 菲利克斯
答案 0(得分:0):
def read_and_decode(filename_queue):
    """Decode a single CSV line from the queue into a list of three
    scalar float32 tensors."""
    line_reader = tf.TextLineReader()
    _, csv_row = line_reader.read(filename_queue)
    # Type information and column names based on the decoded CSV.
    defaults = [[0.0], [0.0], [0.0]]
    c1, c2, c3 = tf.decode_csv(csv_row, record_defaults=defaults)
    return [c1, c2, c3]
def cond(sequence_len, step):
    """tf.while_loop condition: keep iterating while step < sequence_len."""
    # Tensor "<" overloads to tf.less, so this is equivalent to
    # tf.less(step, sequence_len).
    return step < sequence_len
def body(sequence_len, step, filename_queue):
    """tf.while_loop body: decode one CSV row and scatter it into row
    `step` of the shared 3x3 variable "begin", then advance the counter."""
    # Fixed: the original passed tensor_shape.TensorShape([3, 3]) but
    # tensor_shape is never imported (NameError at graph-build time);
    # tf.get_variable accepts a plain list as the shape.
    begin = tf.get_variable(
        "begin", [3, 3], dtype=tf.float32,
        initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(
        begin, step, read_and_decode(filename_queue), use_locking=None)
    # Later iterations (and the caller) must re-fetch the same variable.
    tf.get_variable_scope().reuse_variables()
    # Make sure the scatter actually executes before the counter advances.
    with tf.control_dependencies([begin]):
        return (sequence_len, step + 1)
def get_all_records(FILE):
    """Populate a 3x3 variable row-by-row from FILE via tf.while_loop,
    then print the matrix multiplied by itself."""
    with tf.Session() as sess:
        filename_queue = tf.train.string_input_producer(
            [FILE], num_epochs=1, shuffle=False)
        b = lambda sl, st: body(sl, st, filename_queue)
        step = tf.constant(0)
        sequence_len = tf.constant(3)
        _, step = tf.while_loop(cond,
                                b,
                                [sequence_len, step],
                                parallel_iterations=10,
                                back_prop=True,
                                swap_memory=False,
                                name=None)
        # Fixed: plain [3, 3] instead of tensor_shape.TensorShape([3, 3]);
        # tensor_shape was never imported in this snippet.
        begin = tf.get_variable("begin", [3, 3], dtype=tf.float32)
        # Only multiply after the while_loop has filled the variable.
        with tf.control_dependencies([step]):
            product = tf.matmul(begin, begin)
        init0 = tf.local_variables_initializer()
        sess.run(init0)
        init1 = tf.global_variables_initializer()
        sess.run(init1)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            print(sess.run([product]))
        # Fixed: Python 3 exception syntax ("as e", not ", e").
        except tf.errors.OutOfRangeError as e:
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
# Entry point: build the matrix from the HDFS CSV and print its square.
get_all_records('hdfs://default/data.csv')
这个想法来自:How does the tf.scatter_update() work inside the while_loop() 我想我可以用类似的方式实现批量版本。不过,我很高兴任何建议,使其更高效。