输入中间值并在tf.Estimator中评估子图

时间:2019-04-16 07:38:03

标签: python tensorflow tensorflow-estimator

我正在使用tf.Estimator训练变体自动编码器。 model_fn包含一个编码器和一个解码器,并以图像作为输入。在部署或评估阶段,我希望模型采用随机采样的潜在代码作为输入,并仅执行解码器部分。

我可以从估算器中解开编码器和解码器吗?

1 个答案:

答案 0 :(得分:0)

如果您将variable_scope用于解码器,则可以在model_fn中生成一个随机样本,然后通过使用与{同名的variable_scope创建一个reuse=True,将其再次通过解码器{1}}。

伪代码:

def encoder_fn(image, ...):
  ...
  return latent

def decoder_fn(latent, ...):
  ...
  return reconstruction

def model_fn(...):
  ...

  with tf.variable_scope('encoder'):
    predictions['latent'] = encoder_fn(features[...], ...)

  with tf.variable_scope('decoder'):
    predictions['reconstruction'] = decoder_fn(predictions['latent'], ...)

  with tf.variable_scope('sample'):
    random_latent = ...

  with tf.variable_scope('decoder', reuse=True):
    random_reconstruction = decoder_fn(random_latent, ...)

  ...

对于高级用户,另一个提示是,如果您将潜在代码添加到功能字典中,它将显示在导出的已保存模型的输入下。由于您可以只将保存的模型加载到会话中,这意味着您可以轻松地直接访问和输入潜在代码。不仅可以填充占位符,还可以填充更多的内容。确实没有足够宣传的功能。

示例

# Train side:
# 1. Create estimator
# 2. Train
# 3. Export savedmodel
def model_fn(...):
  ...

  with tf.variable_scope('encoder'):
    predictions['latent'] = encoder_fn(features[...], ...)
    features['latent'] = predictions['latent']

  with tf.variable_scope('decoder'):
    predictions['reconstruction'] = decoder_fn(predictions['latent'], ...)

  ...

# Client side,
# 1. Launch session
# 2. Load savedmodel

# We can reconstruct images
reconstructed_images = session.run(outputs['reconstruction'], feed_dict={inputs['images'] : my_images))

# And generate directly from latent code
random_images = session.run(outputs['reconstruction'], feed_dict={inputs['latents'] : np.random.normal(...)})

# All in the same graph!