在feed_dict和队列之间轻松切换,以输入TensorFlow模型

时间:2017-07-17 00:56:40

标签: python tensorflow

现在我有一个模型配置为使用feed_dict获取其输入。代码看起来像这样:

# model.py
class MyModel(object):
  def __init__(self, hyperparams):
    self.build_model(hyperparams)

  def build_model(self, hps):
    self.input_data = tf.placeholder(dtype=tf.float32, shape=[hps.batch_size, hps.nfeats])
    self.labels = tf.placeholder(dtype=tf.float32, shape=[hps.batch_size])
    # Define hidden layers, loss, training step, etc.

# train.py
model = MyModel(hps)
for _ in range(100):
  x, y = some_python_function() # Read a batch from disk, preprocess
  sess.run(model.train_step, feed_dict={model.input_data: x, model.labels: y})

出于性能原因,我想切换到使用队列进行培训。但是我希望保持使用feed_dict的能力,例如用于推理或测试。

有优雅的方法吗?我想做的是,在使用队列时,换掉'我的队列出列运算返回的张量的占位符变量。我认为tf.assign是这样做的方式,即:

single_x, single_y = tf.parse_single_example(...)
x, y = tf.train.batch([single_x, single_y], batch_size)
model = MyModel(hps)
sess.run([tf.assign(model.input_data, x), tf.assign(model.labels, y)])
for _ in range(100):
  sess.run(model.train_step)

但这会引发AttributeError: 'Tensor' object has no attribute 'assign'tf.assign的API文档将第一个参数描述为:" A mutable Tensor。应该来自Variable节点。可能是未初始化的。"这是否意味着我的占位符不可变?我可以这样做吗?或者我是以错误的方式接近这个?

最小可运行示例here

2 个答案:

答案 0 :(得分:2)

您可以将VariablesOperations的创建分开:

  • build_variables类的实例化时添加Model方法,
  • 更改build_model方法的界面,使其接受您的xy张量作为参数,从而根据它们构建模型operations

这样您就可以重用模型的变量和常量。缺点是,placeholder版本和任何其他版本的操作将重复。

import tensorflow as tf
import numpy as np

BATCH_SIZE = 2

class Model(object):

  def __init__(self):
    self.build_variables()

  def build_variables(self):
    self.w = tf.Variable(tf.random_normal([3, 1]))

  def build_model(self, x, y):
    self.x = x
    self.y = y
    self.output = tf.matmul(self.x, self.w)
    self.loss = tf.losses.absolute_difference(self.y, self.output)


model = Model()
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

def placeholder_run():
  x = tf.placeholder(dtype=tf.float32, shape=[BATCH_SIZE, 3])
  y = tf.placeholder(dtype=tf.float32, shape=[BATCH_SIZE, 1])
  model.build_model(x, y)

  for i in range(3):
    x = np.random.rand(BATCH_SIZE, 3)
    y = x.sum(axis=1, keepdims=True)
    loss = sess.run(model.loss, feed_dict={model.x:x, model.y:y})
    print(loss)

def nonph_run():
  x = tf.random_normal([BATCH_SIZE, 3])
  y = tf.reduce_sum(x, axis=1, keep_dims=True)
  model.build_model(x, y)
  for i in range(3):
    loss = sess.run(model.loss)
    print(loss)

if __name__ == '__main__':
    # Works
    placeholder_run()
    # Doesn't fail
    nonph_run()

答案 1 :(得分:0)

如果您可以控制图表并预先了解自己想要的内容,则可以在输入中使用开关。例如,

x_plh = tf.placeholder(tf.float32, myshape)
x_dsk = my_input_from_disk()
use_dsk = tf.placeholder(tf.bool, ())
x = tf.cond(use_dsk, lambda: x_dsk, lambda: x_plh)

如果你想要一个更灵活的解决方案,并采取一些先锋路线,你可以得到张量流的Dataset API。花点时间阅读文档,这是一个很好的阅读。单个Iterator可以使用不同的Dataset来生成多个初始化程序,这可能适合您的情况。