如何使用发电机生成输入来训练TensorFlow网络?

时间:2016-09-05 06:57:38

标签: python tensorflow

TensorFlow docs描述了使用TFRecordReader,TextLineReader,QueueRunner等和队列读取数据的一系列方法。

我想做的事情要简单得多:我有一个python生成器函数,它产生无限的训练数据序列为(X,y)元组(两者都是numpy数组,第一个维度是批量大小)。我只想用这些数据作为输入来训练网络。

是否有一个简单的自包含示例,使用生成数据的生成器训练TensorFlow网络? (沿着MNIST或CIFAR的例子)

1 个答案:

答案 0 :(得分:27)

假设您有一个生成数据的函数:

 def generator(data): 
    ...
    yield (X, y)

现在您需要另一个描述模型架构的功能。它可以是处理X的任何函数,并且必须将y预测为输出(例如,神经网络)。

假设你的函数接受X和y作为输入,以某种方式从X计算y的预测并返回y和预测y之间的损失函数(例如,回归的交叉熵或MSE):

 def neural_network(X, y): 
    # computation of prediction for y using X
    ...
    return loss(y, y_pred)

要使模型有效,您需要为X和y定义占位符,然后运行会话:

 X = tf.placeholder(tf.float32, shape=(batch_size, x_dim))
 y = tf.placeholder(tf.float32, shape=(batch_size, y_dim))

占位符类似于"自由变量"您需要在feed_dict运行会话时指定:

 with tf.Session() as sess:
     # variables need to be initialized before any sess.run() calls
     tf.global_variables_initializer().run()

     for X_batch, y_batch in generator(data):
         feed_dict = {X: X_batch, y: y_batch} 
         _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict)
         # train_op here stands for optimization operation you have defined
         # and loss for loss function (return value of neural_network function)

希望你会发现它很有用。但是,请记住,这不是完全有效的实现,而是伪代码,因为您几乎没有指定任何细节。