Replacing queue-based input pipelines with tf.data

Date: 2018-12-01 13:50:14

Tags: tensorflow tensorflow-datasets

I am working through Ganegedara's NLP with TensorFlow. The introduction to input pipelines includes a queue-based example that produces the following output:

========== Step 0 ==========
Evaluated data (x)
[[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]

========== Step 1 ==========
Evaluated data (x)
[[1.  0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1]
 [1.  0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]

========== Step 2 ==========
Evaluated data (x)
[[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [1.  0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1]
 [1.  0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1]]

========== Step 3 ==========
Evaluated data (x)
[[0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
 [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
 [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]]

========== Step 4 ==========
Evaluated data (x)
[[0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
 [1.  0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]

It also produces many warnings that queue-based input pipelines are deprecated and recommends using the tf.data module instead. Here is my attempt at using the tf.data module:

import tensorflow as tf
import numpy as np
import os

graph = tf.Graph()
session = tf.InteractiveSession(graph=graph)
filenames = ['test%d.txt'%i for i in range(1,4)]
record_defaults = [[-1.0]] * 10
features = tf.data.experimental.CsvDataset(filenames, record_defaults).batch(batch_size=3).shuffle(buffer_size=5)
x = features.make_one_shot_iterator().get_next()
x = tf.convert_to_tensor(x)
# Executing operations and evaluating nodes in the graph
tf.global_variables_initializer().run() # Initialize the variables
# Calculate h with x and print the results for 5 steps
for step in range(5):
    x_eval = session.run(x)
    print('========== Step %d =========='%step)
    print('Evaluated data (x)')
    print(x_eval)
    print('')
session.close()

which instead produces the following output:

========== Step 0 ==========
Evaluated data (x)
[[0.1 0.1 0.1]
 [0.2 0.2 0.2]
 [0.3 0.3 0.3]
 [0.4 0.4 0.4]
 [0.5 0.5 0.5]
 [0.6 0.6 0.6]
 [0.7 0.7 0.7]
 [0.8 0.8 0.8]
 [0.9 0.9 0.9]
 [1.  1.  1. ]]

========== Step 1 ==========
Evaluated data (x)
[[0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]
 [0.1 0.1 0.1]]

========== Step 2 ==========
Evaluated data (x)
[[1.  1.  1. ]
 [0.9 0.9 0.9]
 [0.8 0.8 0.8]
 [0.7 0.7 0.7]
 [0.6 0.6 0.6]
 [0.5 0.5 0.5]
 [0.4 0.4 0.4]
 [0.3 0.3 0.3]
 [0.2 0.2 0.2]
 [0.1 0.1 0.1]]

========== Step 3 ==========
Evaluated data (x)
[[0.1 0.1 0.1]
 [0.2 0.2 0.1]
 [0.3 0.3 0.1]
 [0.4 0.4 0.1]
 [0.5 0.5 0.1]
 [0.6 0.6 0.1]
 [0.7 0.7 0.1]
 [0.8 0.8 0.1]
 [0.9 0.9 0.1]
 [1.  1.  0.1]]

========== Step 4 ==========
Evaluated data (x)
[[0.1 1.  1. ]
 [0.1 0.9 0.9]
 [0.1 0.8 0.8]
 [0.1 0.7 0.7]
 [0.1 0.6 0.6]
 [0.1 0.5 0.5]
 [0.1 0.4 0.4]
 [0.1 0.3 0.3]
 [0.1 0.2 0.2]
 [0.1 0.1 0.1]]

It looks like the original code samples 3 rows at a time, while my tf.data attempt samples 3 columns. Why is this happening, and how can I fix my code so that it is equivalent to the original?
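
For reference, a quick structure check along these lines (just a sketch, assuming the TF 1.x output_types/output_shapes properties on tf.data.Dataset) is consistent with what I am seeing:

import tensorflow as tf

filenames = ['test%d.txt'%i for i in range(1,4)]
record_defaults = [[-1.0]] * 10
features = tf.data.experimental.CsvDataset(filenames, record_defaults).batch(batch_size=3)

# CsvDataset yields one element per CSV record, as a tuple of scalar column tensors,
# so after .batch(3) this prints a 10-tuple of float32 types and a 10-tuple of (3,)
# shapes rather than a single (3, 10) matrix.
print(features.output_types)
print(features.output_shapes)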

1 Answer:

Answer 0 (score: 0)

I eventually found the answer through someone else's code in a post inquiring about the poor performance of TextLineDataset and decode_csv.

Here is my code, which uses tf.data to do something similar to the code in Ganegedara's book:

import tensorflow as tf
import numpy as np
import os

graph = tf.Graph()
session = tf.InteractiveSession(graph=graph)
filenames = ['test%d.txt'%i for i in range(1,4)]

record_defaults = [[-1.0]] * 10

features = tf.data.TextLineDataset(filenames=filenames)

def parse_csv(line):
    # Default value for each of the 10 float columns (all columns are required)
    cols_types = [[-1.0]] * 10
    columns = tf.decode_csv(line, record_defaults=cols_types)
    # Pack the 10 scalar column tensors into a single length-10 feature vector
    return tf.stack(columns)

# Parse each line into a length-10 vector, then take 3 rows per batch; note that
# shuffling after batching shuffles whole batches rather than individual rows.
features = features.map(parse_csv).batch(batch_size=3).shuffle(buffer_size=5)

x = features.make_one_shot_iterator().get_next()
x = tf.convert_to_tensor(x)
W = tf.Variable(tf.random_uniform(shape=[10,5], minval=-0.1,maxval=0.1, dtype=tf.float32),name='W') 
b = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='b')
h = tf.nn.sigmoid(tf.matmul(x,W) + b) # Operation to be performed

tf.global_variables_initializer().run() # Initialize the variables

# Calculate h with x and print the results for 5 steps
for step in range(5):
    x_eval, h_eval = session.run([x,h]) 
    print('========== Step %d =========='%step)
    print('Evaluated data (x)')
    print(x_eval)
    print('Evaluated data (h)')
    print(h_eval)
    print('')
session.close()
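
For what it's worth, the row-wise batching can also be kept while using tf.data.experimental.CsvDataset, by stacking the per-record column tuple into a length-10 vector in a map before batching. This is only a sketch under the same TF 1.x assumptions as the code above, not something from the book:

import tensorflow as tf

filenames = ['test%d.txt'%i for i in range(1,4)]
record_defaults = [[-1.0]] * 10

# CsvDataset yields a tuple of 10 scalar columns per record; tf.stack turns that tuple
# into a single length-10 row vector, so batching afterwards gives shape (3, 10).
dataset = tf.data.experimental.CsvDataset(filenames, record_defaults)
dataset = dataset.map(lambda *cols: tf.stack(cols)).batch(batch_size=3).shuffle(buffer_size=5)

x = dataset.make_one_shot_iterator().get_next()

with tf.Session() as session:
    for step in range(5):
        print('========== Step %d ==========' % step)
        print(session.run(x))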