TensorFlow TFGAN GANEstimator 导出模型

时间:2018-01-25 17:10:11

标签: python tensorflow

我花了不少时间尝试用 GANEstimator 实现一个简单的模型导出与导入,过程令人沮丧,始终无法让它工作。基本上,这是一个使用 GANEstimator 的条件 GAN,最后 4 行代码负责导出。最后一行抛出:

ValueError: export_outputs must be a dict and not<class 'NoneType'>

如果有人能帮忙看一下,我将非常感激。如前所述,我只是想训练模型并将其导出,然后在另一个 Python 脚本中重新加载该模型(更确切地说是生成器),并以某种方式向它输入数据。这是我的代码:

import tensorflow as tf
#import tensorflow.contrib.eager as tfe
#tfe.enable_eager_execution()

# Short aliases for the tf.contrib submodules.
tfgan = tf.contrib.gan
slim = tf.contrib.slim
layers = tf.contrib.layers
ds = tf.contrib.distributions

import time
import datasets.download_and_convert_mnist as download_and_convert_mnist
from mnist import data_provider,util
import os
import matplotlib.pyplot as plt
import numpy as np
import scipy

from common import *
# Base name used for the per-model subdirectories and the saved image files.
MODEL_FILE_NAME = 'cond_garden_experimental'
# Rebind both paths to a per-model subdirectory. The right-hand-side names are
# not defined in this file before this point; presumably they come from
# `from common import *` -- TODO confirm.
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_PATH, MODEL_FILE_NAME)
IMGS_SAVE_PATH = os.path.join(IMGS_SAVE_PATH, MODEL_FILE_NAME)




# Constants and variables.
num_epochs = 2000  # not referenced anywhere else in the visible script
batch_size = 32  # batch size handed to the data provider and the train input fn
latent_dims = 64  # passed as `noise_dims` when building the train input fn

from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops

def _get_shape(tensor):
  """Return the shape of `tensor`, preferring a constant value.

  Args:
    tensor: The tensor whose shape is requested.

  Returns:
    The result of `tensor_util.constant_value` applied to the shape op when
    that result is not None; otherwise the shape op itself.
  """
  dynamic_shape = array_ops.shape(tensor)
  static_shape = tensor_util.constant_value(dynamic_shape)
  if static_shape is None:
    # No constant value could be extracted; fall back to the shape op.
    return dynamic_shape
  return static_shape

def condition_tensor(tensor, conditioning):
  """Add a linear projection of `conditioning` to `tensor`.

  The conditioning input is flattened and passed through `layers.linear` so
  that it has as many features as `tensor` has elements outside its first
  dimension; the projection is reshaped to `tensor`'s shape when the two
  shapes are not already compatible.

  Args:
    tensor: Tensor to condition. Every dimension after the first must be
      fully defined (checked with `assert_is_fully_defined`).
    conditioning: Tensor carrying the conditioning information.

  Returns:
    The sum of `tensor` and the projected conditioning.
  """
  feature_shape = tensor.shape[1:]
  feature_shape.assert_is_fully_defined()
  feature_count = feature_shape.num_elements()
  projected = layers.linear(layers.flatten(conditioning), feature_count)
  # Debug output of the projected conditioning shape.
  print(projected.shape)
  if projected.shape.is_compatible_with(tensor.shape):
    return tensor + projected
  return tensor + array_ops.reshape(projected, _get_shape(tensor))




def conditional_discriminator_fn(img, inputs, weight_decay=2.5e-5):
    """Build the conditional discriminator and return its output.

    Args:
        img: Image tensor fed to the first convolution.
        inputs: Indexable generator inputs; element 1 is used as the
            conditioning labels (the train input fn in this file supplies
            `(noise, labels)`).
        weight_decay: Scale passed to the L2 regularizers for weights and
            biases.

    Returns:
        The output of the final `layers.linear(..., 1)` layer.
    """
    labels = inputs[1]
    scope_defaults = dict(
        activation_fn=leaky_relu,
        normalizer_fn=None,
        weights_regularizer=layers.l2_regularizer(weight_decay),
        biases_regularizer=layers.l2_regularizer(weight_decay),
    )
    with slim.arg_scope(
            [layers.conv2d, layers.fully_connected], **scope_defaults):
        features = layers.conv2d(img, 64, [4, 4], stride=2)
        features = layers.conv2d(features, 128, [4, 4], stride=2)
        features = layers.flatten(features)
        # Mix the label information into the flattened image features.
        features = condition_tensor(features, labels)
        features = layers.fully_connected(
            features, 1024, normalizer_fn=layers.batch_norm)
        logits = layers.linear(features, 1)
    return logits

def leaky_relu(net):
    """Apply `tf.nn.leaky_relu` to `net` with a fixed alpha of 0.01."""
    return tf.nn.leaky_relu(net, alpha=0.01)

# Module-level slots that conditional_generator_fn fills with the noise and
# condition tensors it receives when it is called with tuple inputs.
global_noise = global_condition = None

def conditional_generator_fn(inputs, weight_decay=2.5e-5):
    """Build the conditional generator and return the generated image tensor.

    Args:
        inputs: Either a dict with the keys 'noise' and 'condition', or a
            two-element sequence `(noise, one_hot_labels)`. In the second
            case the two tensors are also stored in the module-level
            `global_noise` / `global_condition` variables.
        weight_decay: Scale passed to the L2 weights regularizer.

    Returns:
        The output of the final tanh-activated convolution.
    """
    global global_noise, global_condition

    if isinstance(inputs, dict):
        noise = inputs['noise']
        one_hot_labels = inputs['condition']
    else:
        noise, one_hot_labels = inputs
        # Remember the input tensors so the script can reuse them later.
        global_noise, global_condition = noise, one_hot_labels

    with slim.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu,
            normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        hidden = layers.fully_connected(noise, 1024)
        hidden = condition_tensor(hidden, one_hot_labels)
        hidden = layers.fully_connected(hidden, 7 * 7 * 128)
        feature_map = tf.reshape(hidden, [-1, 7, 7, 128])
        feature_map = layers.conv2d_transpose(
            feature_map, 64, [4, 4], stride=2)
        feature_map = layers.conv2d_transpose(
            feature_map, 32, [4, 4], stride=2)
        # tanh bounds the generated image to [-1, 1], intended to match the
        # range of the real images.
        generated = layers.conv2d(
            feature_map, 1, 4, normalizer_fn=None, activation_fn=tf.tanh)
        return generated


def _get_train_input_fn(batch_size, noise_dims, dataset_dir=None, num_threads=4):
  """Create the input function used for training.

  Args:
    batch_size: Batch size for both the provided data and the noise.
    noise_dims: Size of the second dimension of the noise tensor.
    dataset_dir: Dataset directory forwarded to `data_provider.provide_data`.
    num_threads: Thread count forwarded to `data_provider.provide_data`.

  Returns:
    A zero-argument function returning `((noise, labels), images)`.
  """
  def train_input_fn():
    # The data pipeline is pinned to the CPU; the noise op is created
    # outside of that device scope.
    with tf.device('/cpu:0'):
      real_images, one_hot_labels, _ = data_provider.provide_data(
          'train', batch_size, dataset_dir, num_threads=num_threads)
    latent = tf.random_normal([batch_size, noise_dims])
    generator_inputs = (latent, one_hot_labels)
    return (generator_inputs, real_images)

  return train_input_fn


def _get_predict_input_fn():
  """Create the input function used for prediction.

  Returns:
    A function taking `params` (a mapping with the keys 'noise' and
    'condition') and returning `((noise_tensor, labels_tensor), None)`.
  """
  def predict_input_fn(params):
    raw_noise, raw_condition = params['noise'], params['condition']
    noise_tensor = tf.convert_to_tensor(raw_noise)
    labels_tensor = tf.convert_to_tensor(raw_condition)
    # Debug output of the two input shapes.
    print(noise_tensor.shape, labels_tensor.shape)
    return ((noise_tensor, labels_tensor), None)

  return predict_input_fn


def visualize_training_generator(train_step_num, start_time, data_np):
    """Print training progress and save the generator output as a PNG file.

    The image is written to `IMGS_SAVE_PATH` under the name
    `MODEL_FILE_NAME + str(train_step_num) + '.png'`.

    Args:
        train_step_num: The training step number. A python integer.
        start_time: Time when training started. The output of `time.time()`. A
            python float.
        data_np: Data to plot. A numpy array, most likely from an evaluated
            TensorFlow tensor.
    """
    print('Training step: %i' % train_step_num)
    time_since_start = (time.time() - start_time) / 60.0
    print('Time since start: %f m' % time_since_start)
    # A coarse clock can report zero elapsed time; avoid ZeroDivisionError.
    if time_since_start != 0:
        steps_per_min = train_step_num / time_since_start
    else:
        steps_per_min = float('inf')
    print('Steps per min: %f' % steps_per_min)
    plt.axis('off')
    plt.imshow(np.squeeze(data_np), cmap='gray')
    plt.savefig(os.path.join(IMGS_SAVE_PATH,MODEL_FILE_NAME+str(train_step_num)+'.png'))
    # Close the current figure so repeated calls neither draw on top of the
    # previous image nor keep accumulating figure state in memory.
    plt.close()


if __name__ == '__main__':
    # Script entry point: prepare directories and data, train the GANEstimator
    # for a single step, then attempt to export it as a SavedModel.
    # `setup_clean_directory`, `MNIST_DATA_DIR` and `EXPORT_DIR_ROOT` are not
    # defined in this file; presumably they come from `from common import *`.
    setup_clean_directory(MODEL_SAVE_PATH)
    setup_clean_directory(IMGS_SAVE_PATH)
    # Prepare data.
    download_and_convert_mnist.run(MNIST_DATA_DIR)
    # NOTE(review): `images` and `one_hot_labels` are not used again in this
    # block; the estimator gets its data from `train_input_fn` below.
    with tf.device('/cpu:0'):
        images, one_hot_labels, _ = data_provider.provide_data('train', batch_size, MNIST_DATA_DIR)

    gan_estimator = tfgan.estimator.GANEstimator(
        MODEL_SAVE_PATH,
        generator_fn=conditional_generator_fn,
        discriminator_fn=conditional_discriminator_fn,
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        generator_optimizer=tf.train.AdamOptimizer(0.001, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(0.0001, 0.5),
        add_summaries=tfgan.estimator.SummaryType.IMAGES)
    train_input_fn = _get_train_input_fn(batch_size, noise_dims=latent_dims, dataset_dir=MNIST_DATA_DIR)
    # Only a single training step is run before the export attempt.
    gan_estimator.train(train_input_fn, max_steps=1)


    # Export attempt 1: build a raw serving input receiver from the tensors
    # that conditional_generator_fn stored in the module-level globals when it
    # was called with tuple inputs (presumably during train() above).
    from tensorflow.python.estimator.export import export
    feat_dict = {'noise':global_noise,'condition':global_condition}
    sirf = export.build_raw_serving_input_receiver_fn(feat_dict)
    # This is the call reported to fail with
    # "ValueError: export_outputs must be a dict and not<class 'NoneType'>".
    gan_estimator.export_savedmodel(EXPORT_DIR_ROOT, sirf)

1 个答案:

答案 0 :(得分:5)

2018-05-09:目前 GANEstimator 不会创建上述 export_outputs 字典;可以查看 tensorflow/contrib/gan/python/estimator/python/head_impl.py 第 162 行之后的代码来验证这一点。

您可以通过此GitHub Pull Request跟踪此问题的状态。

model_fn 需要返回一个带有 export_outputs 字典的 EstimatorSpec。类似这样:

# Illustrative fragment meant to live inside a model_fn: `mode`, `Modes`,
# `label_values`, `predicted_indices` and `probabilities` are not defined here.
if mode == Modes.PREDICT:
    predictions = {
        'classes': tf.gather(label_values, predicted_indices),
        'scores': tf.reduce_max(probabilities, axis=1)
    }
    # The export_outputs dict is what the error message complains is missing;
    # here it wraps the predictions in a PredictOutput under the 'prediction' key.
    export_outputs = {
        'prediction': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode, predictions=predictions, export_outputs=export_outputs)