Reading a (batched) matrix from a file in TensorFlow

Date: 2016-12-06 15:07:39

Tags: tensorflow

I just started learning TensorFlow. I would like to read a 3x3 matrix from a CSV file in HDFS and multiply it by itself.

The file looks like this:

1,2,3
4,5,6
7,8,9

So far, with the help of the TensorFlow tutorial, I have been able to come up with the following code:

import tensorflow as tf

def read_and_decode(filename_queue):
  reader = tf.TextLineReader()
  key, value = reader.read(filename_queue)

  # Type information and column names based on the decoded CSV.
  record_defaults = [[0.0], [0.0], [0.0]]
  f1, f2, f3 = tf.decode_csv(value, record_defaults=record_defaults)

  # Turn the features back into a tensor.
  features = tf.pack([f1, f2, f3])

  return features

def input_pipeline(filename_queue, batch_size, num_threads):
  example = read_and_decode(filename_queue)
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size

  example_batch = tf.train.batch(
      [example], batch_size=batch_size, capacity=capacity, 
      num_threads=num_threads, allow_smaller_final_batch=True)
  return example_batch


def get_all_records(FILE):
  with tf.Session() as sess:
    filename_queue = tf.train.string_input_producer([FILE], num_epochs=1, shuffle=False)
    batch_size = 1
    num_threads = 4
    #batch = input_pipeline(filename_queue, batch_size, num_threads)
    batch = read_and_decode(filename_queue)
    init_op = tf.local_variables_initializer()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
      while True:
        example = sess.run([batch])
        print(example)
    except tf.errors.OutOfRangeError as e:
      coord.request_stop(e)
    finally:
      coord.request_stop()

   coord.join(threads)

get_all_records('hdfs://default/test.csv')

This prints every row of the matrix in the correct order. However, when I use batching by applying input_pipeline() instead, the result is no longer in the correct order.

We could also read the file in Matrix Market format instead, which would remove the ordering constraint.

So my question is: how can I get the resulting rows (or batches of rows) into a matrix (or batch of matrices) in a scalable way (i.e. the matrix is very large), so that I can apply matrix multiplication like:

result = tf.matmul(Matrix, Matrix)
result = tf.batch_matmul(batched_Matrix, batched_Matrix)
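To illustrate the shapes involved (a plain NumPy sketch, not the TF input pipeline itself): the rows read from the CSV would need to be stacked into a (3, 3) matrix, and a batch of such matrices into a (batch, 3, 3) tensor, before the products above make sense.

```python
import numpy as np

# Rows as they would come off the CSV reader, one at a time.
rows = [np.array([1.0, 2.0, 3.0]),
        np.array([4.0, 5.0, 6.0]),
        np.array([7.0, 8.0, 9.0])]

matrix = np.stack(rows)               # shape (3, 3)
result = matrix @ matrix              # analogue of tf.matmul(Matrix, Matrix)

batched = np.stack([matrix, matrix])  # shape (2, 3, 3)
batched_result = batched @ batched    # analogue of tf.batch_matmul(...)

print(result[0, 0])                   # 30.0 = 1*1 + 2*4 + 3*7
print(batched_result.shape)           # (2, 3, 3)
```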

As an extension of the question: which solution is the fastest, especially with regard to distributed execution?

Thanks for your help, Felix

1 answer:

Answer 0 (score: 0):

After some research, I was finally able to implement a working prototype:

import tensorflow as tf
from tensorflow.python.framework import tensor_shape

def read_and_decode(filename_queue):
  reader = tf.TextLineReader()
  key, value = reader.read(filename_queue)

  # Type information and column names based on the decoded CSV.
  record_defaults = [[0.0], [0.0], [0.0]]
  f1, f2, f3 = tf.decode_csv(value, record_defaults=record_defaults)

  return [f1, f2, f3]

def cond(sequence_len, step):
    return tf.less(step,sequence_len)

def body(sequence_len, step, filename_queue): 
    begin = tf.get_variable("begin",tensor_shape.TensorShape([3, 3]),dtype=tf.float32,initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(begin, step, read_and_decode(filename_queue), use_locking=None)
    tf.get_variable_scope().reuse_variables()

    with tf.control_dependencies([begin]):
        return (sequence_len, step+1)

def get_all_records(FILE):
 with tf.Session() as sess:

    filename_queue = tf.train.string_input_producer([FILE], num_epochs=1, shuffle=False)

    b = lambda sl, st: body(sl,st,filename_queue)

    step = tf.constant(0)
    sequence_len  = tf.constant(3)
    _, step = tf.while_loop(cond,
                    b,
                    [sequence_len, step], 
                    parallel_iterations=10, 
                    back_prop=True, 
                    swap_memory=False, 
                    name=None)

    begin = tf.get_variable("begin",tensor_shape.TensorShape([3, 3]),dtype=tf.float32)

    with tf.control_dependencies([step]):
      product = tf.matmul(begin, begin)

    init0 = tf.local_variables_initializer()
    sess.run(init0)
    init1 = tf.global_variables_initializer()
    sess.run(init1)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
      print(sess.run([product]))
    except tf.errors.OutOfRangeError as e:
      coord.request_stop(e)
    finally:
      coord.request_stop()

    coord.join(threads)

get_all_records('hdfs://default/data.csv')

The idea came from: How does the tf.scatter_update() work inside the while_loop(). I think I could implement the batched version in a similar fashion. Still, I would be happy about any suggestions to make it more efficient.
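The batched version mentioned above can be sketched outside of TensorFlow as well (plain NumPy, assuming rows arrive in file order): every `n` consecutive rows form one n×n matrix, so a flat block of rows can simply be reshaped into a (N, n, n) batch and multiplied in one batched matmul.

```python
import numpy as np

def rows_to_batched_matmul(rows, n=3):
    # Assumes len(rows) is a multiple of n and rows are in file order.
    flat = np.asarray(rows, dtype=np.float32)   # shape (N*n, n)
    batched = flat.reshape(-1, n, n)            # shape (N, n, n)
    return np.matmul(batched, batched)          # one product per matrix

# Two 3x3 matrices' worth of rows, as they would come off the reader.
rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9],
        [9, 8, 7], [6, 5, 4], [3, 2, 1]]
products = rows_to_batched_matmul(rows)
print(products.shape)       # (2, 3, 3)
print(products[0, 0, 0])    # 30.0
```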