Does num_epochs cap string_input_producer() in TensorFlow's CSV file reader?

Date: 2017-08-17 01:39:00

Tags: csv tensorflow

I have a dummy CSV file (y = -x + 1):

x,y
1,0
2,-1
3,-2

I am trying to feed it to a linear regression model. Since I have so few examples, I want to iterate over the file 1000 times during training, so I set num_epochs=1000.

However, TensorFlow seems to cap this number. It works fine with num_epochs=5 or 10, but anything above 33 gets capped at 33 epochs. Is that a real limit, or am I doing something wrong?

# model = W*x+b
... 
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

# reading input from csv
filename_queue = tf.train.string_input_producer(["/tmp/testinput.csv"], num_epochs=1000)
reader = tf.TextLineReader(skip_header_lines=1)
...
col_x, col_label = tf.decode_csv(csv_row, record_defaults=record_defaults)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(tf.global_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  while True:
    try:
      input_x, input_y = sess.run([col_x, col_label])
      sess.run(train, feed_dict={x:input_x, y:input_y})
...

Side question: do I really need to do this:

input_x, input_y = sess.run([col_x, col_label])
sess.run(train, feed_dict={x:input_x, y:input_y})

I tried sess.run(train, feed_dict={x:col_x, y:col_y}) directly to avoid the indirection, but it did not work (they are graph nodes, whereas feed_dict expects regular data).
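One way around the per-step feed_dict is to build the model directly on the pipeline tensors, so the queue feeds the graph and a single sess.run(train) both reads a row and applies a gradient step. This is only a sketch under assumptions (the variable names W, b, pred, and loss are mine, not from the question; it uses TF 1.x APIs such as tf.train.string_input_producer, which were removed in TensorFlow 2):

```python
import tensorflow as tf

# Input pipeline: the decoded CSV columns are ordinary graph tensors.
filename_queue = tf.train.string_input_producer(
    ["/tmp/testinput.csv"], num_epochs=1000)
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
col_x, col_label = tf.decode_csv(csv_row, record_defaults=[[0.0], [0.0]])

# Model wired straight onto the reader output -- no placeholders, no feed_dict.
W = tf.Variable(0.0)
b = tf.Variable(0.0)
pred = W * col_x + b
loss = tf.square(pred - col_label)
train = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
```

With this wiring, each sess.run(train) in the training loop dequeues the next row itself, which also avoids the subtle bug of running sess.run([col_x, col_label]) and sess.run(train) as two separate calls (each call dequeues its own, different row).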

1 answer:

Answer 0 (score: 0)

The following snippet works perfectly with your input:

import tensorflow as tf


filename_queue = tf.train.string_input_producer(["/tmp/input.csv"], num_epochs=1000)
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
col_x, col_label = tf.decode_csv(csv_row, record_defaults=[[0], [0]])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(tf.global_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  num = 0
  try:
    while True:
      sess.run([col_x, col_label])
      num += 1
  except tf.errors.OutOfRangeError:
    # The queue raises OutOfRangeError once all num_epochs passes are consumed.
    print(num)
  finally:
    coord.request_stop()
    coord.join(threads)

which gives the following output:

edb@lapelidb:/tmp$ python csv.py
3000
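That count can be sanity-checked arithmetically (a small sketch of my own, not part of the original answer): the dummy file has 3 data rows once the header is skipped, so num_epochs=1000 passes should yield 3 × 1000 records before the queue raises OutOfRangeError.

```python
# Expected total records from the reader: data rows per epoch times epochs.
rows_per_epoch = 3   # the dummy CSV has 3 rows after skip_header_lines=1
num_epochs = 1000
total_records = rows_per_epoch * num_epochs
print(total_records)  # 3000, matching the output above
```

So the reader itself honors num_epochs=1000; the cap at 33 epochs is not coming from string_input_producer.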