Converting a TensorFlow tutorial to use my own data

Date: 2017-02-17 16:12:30

Tags: python tensorflow

This is a follow-up to my previous question, Converting from Pandas dataframe to TensorFlow tensor object.

I'm now taking the next step and need some more help. I'm trying to replace this line of code

batch = mnist.train.next_batch(100)

with my own data. I found this answer on Stack Overflow: Where does next_batch in the TensorFlow tutorial batch_xs, batch_ys = mnist.train.next_batch(100) come from? But I don't understand:

1) Why .next_batch() doesn't work on my tensors. Did I create them incorrectly?

2) How to implement the pseudocode given in the answer to that question.

I currently have two tensor objects, one holding the parameters I want to use to train the model (dataVar_tensor) and one holding the correct results (depth_tensor). I obviously need to preserve the relationship between them, so that the right responses stay paired with the right parameters.

Could you please take some time to help me understand what's going on here and how to replace this line of code?

Many thanks

1 answer:

Answer 0 (score: 2)

I stripped out the irrelevant stuff so as to preserve the formatting and indentation. Hopefully it should be clear now. The following code reads a CSV file in batches of N rows (N is specified in a constant at the top). Each row contains a date (the first cell), then a list of floats (480 cells), and a one-hot vector (3 cells). The code then simply prints those batches of dates, floats, and one-hot vectors as it reads them. The place where it prints them is normally where you would actually run your model, feeding these in place of the placeholder variables.
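For illustration, one row of such a file might look like this (hypothetical values, with most of the 480 feature cells elided):

2017-02-17,0.25,0.51,0.73, ... ,0.10,0,1,0

That is: one date cell, then 480 float cells, then a 3-cell one-hot label.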

Keep in mind that this code reads each row as strings and then converts the specific cells within that row into floats, simply because the first cell is easier to read as a string. If all of your data is numeric, just set the defaults to float/int instead of 'a' and remove the code that converts strings into floats. It's not needed otherwise!
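For example, if every cell (including the first) were numeric, the defaults and the read function below might simplify to something like this (just a sketch, assuming the same TD/TS/TL layout as in the full code):

rDefaults = [[0.0] for row in range(TD + TS + TL)] # decode every cell directly as float32

def read_from_csv(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=0)
    _, csv_row = reader.read(filename_queue)
    data = tf.decode_csv(csv_row, record_defaults=rDefaults) # cells come out as floats already
    dateLbl = tf.slice(data, [0], [TD]) # no tf.string_to_number needed anywhere
    features = tf.slice(data, [TD], [TS])
    label = tf.slice(data, [TD + TS], [TL])
    return dateLbl, features, label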

I put some comments in there to clarify what it's doing. Let me know if anything is unclear.

import tensorflow as tf

fileName = 'YOUR_FILE.csv'

try_epochs = 1
batch_size = 3

TD = 1 # this is my date-label for each row, for internal purposes
TS = 480 # this is the list of features, 480 in this case
TL = 3 # this is the one-hot vector of 3 cells representing the label

# set defaults to something (TF requires defaults for the number of cells you are going to read)
rDefaults = [['a'] for row in range(TD + TS + TL)]

# function that reads the input file, line-by-line
def read_from_csv(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=False) # my file has no header row
    _, csv_row = reader.read(filename_queue) # read one line
    data = tf.decode_csv(csv_row, record_defaults=rDefaults) # use defaults for this line (in case of missing data)
    dateLbl = tf.slice(data, [0], [TD]) # the first cell is my 'date-label', for internal purposes
    features = tf.string_to_number(tf.slice(data, [TD], [TS]), tf.float32) # the next 480 cells are the features
    label = tf.string_to_number(tf.slice(data, [TD+TS], [TL]), tf.float32) # the remaining 3 cells are the one-hot label
    return dateLbl, features, label

# function that packs each read line into batches of specified size
def input_pipeline(fName, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(
        [fName],
        num_epochs=num_epochs,
        shuffle=True)  # this refers to multiple files, not line items within files
    dateLbl, features, label = read_from_csv(filename_queue)
    min_after_dequeue = 10000 # minimum number of items to keep in the queue after a dequeue (ensures good shuffling)
    capacity = min_after_dequeue + 3 * batch_size # maximum number of items the queue will hold in memory
    # this packs the above lines into a batch of size you specify:
    dateLbl_batch, feature_batch, label_batch = tf.train.shuffle_batch(
        [dateLbl, features, label], 
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return dateLbl_batch, feature_batch, label_batch

# these are the date label, features, and label:
dateLbl, features, labels = input_pipeline(fileName, batch_size, try_epochs)

with tf.Session() as sess:

    # initialize global and local variables (local ones are needed for the num_epochs counter)
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop():
            # load date-label, features, and label:
            dateLbl_batch, feature_batch, label_batch = sess.run([dateLbl, features, labels])      

            print(dateLbl_batch)
            print(feature_batch)
            print(label_batch)
            print('----------')

    except tf.errors.OutOfRangeError:
        print("Done looping through the file")

    finally:
        coord.request_stop()

    coord.join(threads)
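To tie this back to the original question: the print statements inside the while-loop are where you would actually feed each batch into your model, just like the MNIST tutorial feeds mnist.train.next_batch(100) through feed_dict. A minimal sketch, assuming your model defines placeholders x and y_ and a training op train_step (none of which appear in the code above):

x = tf.placeholder(tf.float32, [None, TS]) # 480 features per row
y_ = tf.placeholder(tf.float32, [None, TL]) # 3-cell one-hot label
# ... build your model and define train_step here ...

# then, inside the while-loop, instead of the print statements:
feature_batch, label_batch = sess.run([features, labels])
sess.run(train_step, feed_dict={x: feature_batch, y_: label_batch})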