使用输入管道访问图形中的输入变量

时间:2018-07-11 14:01:38

标签: python tensorflow tensorflow-datasets

假设有一个定义为此类的模型:

class SimpleAutoencoder(object):

    def __init__(self, x):

        self.x = x
        self.input_dim = 92
        self.latent_dim = 10

        self.build_model()


    def build_model(self):

        latent = tf.contrib.layers.fully_connected(self.x,
                                                   self.latent_dim,
                                                   scope='latent',
                                                   activation_fn=tf.nn.relu)

        self.x_hat = tf.contrib.layers.fully_connected(latent,
                                                  self.input_dim,
                                                  scope='output',
                                                  activation_fn=tf.nn.sigmoid)

        self.loss = tf.losses.mean_squared_error(self.x, self.x_hat)

        self.train_op = tf.AdamOptimizer().minimize(self.loss)

您可以使用输入管道来训练数据,以训练数据:

...
x = iterator.get_next()
model = SimpleAutoencoder(x)
...

## train and save it to disk

现在,在构建模型时使用self.x的占位符时,我可以给它命名,并在恢复模型以进行推断时轻松访问输入变量。但是对于输入管道,x不能是变量,常量或占位符,因此我不能给它起一个合适的名称。 如何将新数据注入x并通过图表输入?

即使培训有效,我仍认为我可能会以某种方式做错了,因为代码对我来说看起来确实很丑陋(将管道输出提供给init函数的部分)。

请帮助我解决这个问题!谢谢!

1 个答案:

答案 0 :(得分:1)

  • 您可以使用x来获得x.name的名字,
  • 或者,您可以使用xx = tf.identity(x, name='my_name')重命名为您喜欢的名称,

    (通过这两种解决方案,您可以使用张量的名称来输入值-即使x不是占位符:

    sess.run(my_ops, feed_dict{tensor_name: tensor_value})
    

  • 或者,您可以用占位符替换整个输入管道(对于相反的问题,解释为here-用Dataset输入代替占位符)