如何从Tensorflow模型还原GAN的生成器?

时间:2019-06-19 22:06:12

标签: python tensorflow neural-network generative-adversarial-network

我正在尝试使用Tensorflow模型(元数据和检查点)恢复训练有素的对抗网络生成器

我是tensorflow和python的新手,所以我不确定我在做什么是否有意义。已经尝试从元文件中导入元图并从检查点还原变量,但是我确定下一步该怎么做。我的目标是从最后一个检查点恢复训练有素的生成器,然后使用它从噪声输入中生成新数据。

以下是包含模型文件的驱动器的链接: https://drive.google.com/drive/folders/1MaELMC4aOroSQlMJ32J3_ff3wxiBT_Fq?usp=sharing

到目前为止,我已经尝试了以下方法,并且似乎正在加载图形:


# import the graph from the file
imported_graph = tf.train.import_meta_graph("../../models/model-9.meta")

# list all the tensors in the graph
for tensor in tf.get_default_graph().get_operations():
    print (tensor.name)

# run the session
with tf.Session() as sess:
    # restore the saved vairable
    imported_graph.restore(sess, "../../models/model-9")

但是,我不确定下一步该怎么做。使用此文件是否可以仅运行受过训练的生成器?我该如何访问?

1 个答案:

答案 0 :(得分:0)

在Tensorflow 2文档中,他们同时保存生成器和鉴别器。但是,他们没有说明如何仅还原生成器。

-1

然后使用

还原
int main()
{
    srand(time(0));
    int input;
    std::string pass, company, timeS;
    do
    {
        std::cout << "Enter password length: ";
        std::cin >> input;
        while((input < 8 && input>=0)|| input > 16)
        {
            if(!std::cin)
            {
                std::cin.clear();
                std::cin.ignore(100, '\n');
            }

            std::cout << "Password length must be between 8 and 16.\nEnter password length: ";
            std::cin >> input;
        }

        if (input>0){
          std::cout << "Enter company name: ";
          std::getline(std::cin, company);

          pass = passGen(input);
          time_t now = time(0);
          auto time = *std::localtime(&now);
          std::stringstream ss;
          ss << std::put_time(&time, "%Y %b %d %H:%M:%S %a");
          timeS = ss.str();
          std::cout << "You passoword: " << pass << std::endl;
          writeFile(pass, company, timeS);
        }
    }while(input != -1);

    return 0;
}

来自https://www.tensorflow.org/tutorials/generative/dcgan#save_checkpoints