我正在使用tf.Estimator
训练变体自动编码器。 model_fn
包含一个编码器和一个解码器,并以图像作为输入。在部署或评估阶段,我希望模型采用随机采样的潜在代码作为输入,并仅执行解码器部分。
我可以从估算器中解开编码器和解码器吗?
答案 0 :(得分:0)
如果您将variable_scope用于解码器,则可以在model_fn
中生成一个随机样本,然后通过使用与{同名的variable_scope
创建一个reuse=True
,将其再次通过解码器{1}}。
伪代码:
def encoder_fn(image, ...):
...
return latent
def decoder_fn(latent, ...):
...
return reconstruction
def model_fn(...):
...
with tf.variable_scope('encoder'):
predictions['latent'] = encoder_fn(features[...], ...)
with tf.variable_scope('decoder'):
predictions['reconstruction'] = decoder_fn(predictions['latent'], ...)
with tf.variable_scope('sample'):
random_latent = ...
with tf.variable_scope('decoder', reuse=True):
random_reconstruction = decoder_fn(random_latent, ...)
...
对于高级用户,另一个提示是,如果您将潜在代码添加到功能字典中,它将显示在导出的已保存模型的输入下。由于您可以只将保存的模型加载到会话中,这意味着您可以轻松地直接访问和输入潜在代码。不仅可以填充占位符,还可以填充更多的内容。确实没有足够宣传的功能。
示例
# Train side:
# 1. Create estimator
# 2. Train
# 3. Export savedmodel
def model_fn(...):
...
with tf.variable_scope('encoder'):
predictions['latent'] = encoder_fn(features[...], ...)
features['latent'] = predictions['latent']
with tf.variable_scope('decoder'):
predictions['reconstruction'] = decoder_fn(predictions['latent'], ...)
...
# Client side,
# 1. Launch session
# 2. Load savedmodel
# We can reconstruct images
reconstructed_images = session.run(outputs['reconstruction'], feed_dict={inputs['images'] : my_images))
# And generate directly from latent code
random_images = session.run(outputs['reconstruction'], feed_dict={inputs['latents'] : np.random.normal(...)})
# All in the same graph!