How to fix "AssertionError" with Nvidia Progressive_GAN?

Time: 2019-07-25 22:27:40

Tags: python tensorflow machine-learning neural-network

I have been trying to learn a bit about TensorFlow. Source: https://github.com/tkarras/progressive_growing_of_gans. The good news is that I was able to run import_example.py successfully and get correct output image samples from the trained model karras2018iclr-celebahq-1024x1024.pkl. Now I have used dataset_tool.py (create_from_images) to create my own dataset (*.tfrecords) containing 400 images at 1024x1024. I modified config.py to:

    data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES/'
    result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'

and created a new dataset.

But I got:

(ProgressiveGAN) C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018>python train.py
Initializing TensorFlow...
Running train.train_progressive_gan()...
Streaming data using dataset.TFRecordDataset...
Traceback (most recent call last):
  File "train.py", line 285, in <module>
    tfutil.call_func_by_name(**config.train)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\tfutil.py", line 236, in call_func_by_name
    return import_obj(func)(*args, **kwargs)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\train.py", line 151, in train_progressive_gan
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 234, in load_dataset
    dataset = tfutil.import_obj(class_name)(**adjusted_kwargs)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 67, in __init__
    assert os.path.isdir(self.tfrecord_dir)
AssertionError
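The assertion is raised in dataset.py when the directory the loader computes does not exist. In the repository, load_dataset() builds that directory by joining config.data_dir with the dataset's tfrecord_dir. The following minimal sketch reproduces that path computation with the values from the config.py below; resolve_tfrecord_dir is a hypothetical helper written for illustration, not part of the repository.

```python
import os


def resolve_tfrecord_dir(data_dir, tfrecord_dir):
    # Hypothetical helper: mimics how load_dataset() is assumed to prefix
    # the dataset's tfrecord_dir with config.data_dir before it constructs
    # TFRecordDataset (which then asserts that the result is a directory).
    return os.path.join(data_dir, tfrecord_dir)


# Values taken from the question's config.py.
data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES'
tfrecord_dir = 'MYIMAGES'

resolved = resolve_tfrecord_dir(data_dir, tfrecord_dir)
print(resolved)                 # the folder name MYIMAGES appears twice at the end
print(os.path.isdir(resolved))  # False unless a nested MYIMAGES/MYIMAGES folder exists
```

Because MYIMAGES is already the last component of data_dir, the loader ends up looking for a nested MYIMAGES folder inside it. If that folder was never created, os.path.isdir() returns False and the assert fails.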

config.py code:

# Paths.

data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES'
result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'

# Official training configs, targeted mainly for CelebA-HQ.
# To run, comment/uncomment the lines as appropriate and launch train.py.

desc        = 'pgan'                                        # Description string included in result subdir name.
random_seed = 1000                                          # Global random seed.
dataset     = EasyDict()                                    # Options for dataset.load_dataset().
train       = EasyDict(func='train.train_progressive_gan')  # Options for main training func.
G           = EasyDict(func='networks.G_paper')             # Options for generator network.
D           = EasyDict(func='networks.D_paper')             # Options for discriminator network.
G_opt       = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
D_opt       = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
G_loss      = EasyDict(func='loss.G_wgan_acgan')            # Options for generator loss.
D_loss      = EasyDict(func='loss.D_wgangp_acgan')          # Options for discriminator loss.
sched       = EasyDict()                                    # Options for train.TrainingSchedule.
grid        = EasyDict(size='1080p', layout='random')       # Options for train.setup_snapshot_image_grid().

# Dataset (choose one).
desc += '-MYIMAGES';            dataset = EasyDict(tfrecord_dir='MYIMAGES'); train.mirror_augment = True

I want to train on my own 1024x1024 images.

1 answer:

Answer (score: 0):

I ran into this problem recently. It's a very simple fix.

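All you need to do is put your tfrecord_dir inside the datasets directory, i.e. datasets/<tfrecord_dir>. load_dataset() in dataset.py joins config.data_dir with the dataset's tfrecord_dir, so data_dir should point to the datasets folder itself (not to .../datasets/MYIMAGES). With tfrecord_dir='MYIMAGES', the loader then looks in .../datasets/MYIMAGES, which is presumably where the .tfrecords files are.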
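A minimal sketch of the corrected layout, assuming the .tfrecords files sit in a MYIMAGES folder directly under datasets: data_dir points at the datasets folder and tfrecord_dir names only the subfolder. check_dataset_dir is a hypothetical pre-flight helper (not part of the repository) that you could run before launching train.py. The demo builds a throwaway directory tree so it runs anywhere, and the .tfrecords file name in it is only a placeholder.

```python
import glob
import os
import tempfile


def check_dataset_dir(data_dir, tfrecord_dir):
    # Hypothetical pre-flight check: resolve the path the way load_dataset()
    # is assumed to (data_dir joined with tfrecord_dir) and make sure it is a
    # directory holding at least one .tfrecords file.
    path = os.path.join(data_dir, tfrecord_dir)
    if not os.path.isdir(path):
        raise FileNotFoundError('tfrecord_dir not found: %s' % path)
    tfr_files = sorted(glob.glob(os.path.join(path, '*.tfrecords')))
    if not tfr_files:
        raise FileNotFoundError('no .tfrecords files in %s' % path)
    return tfr_files


with tempfile.TemporaryDirectory() as root:
    # Stand-in for .../source/datasets with a MYIMAGES folder inside it.
    datasets = os.path.join(root, 'datasets')
    os.makedirs(os.path.join(datasets, 'MYIMAGES'))
    # Placeholder file; the real dataset_tool.py output may use other names.
    open(os.path.join(datasets, 'MYIMAGES', 'MYIMAGES-r10.tfrecords'), 'wb').close()

    # Corrected values: data_dir is the parent folder, tfrecord_dir the subfolder.
    found = check_dataset_dir(data_dir=datasets, tfrecord_dir='MYIMAGES')
    print([os.path.basename(f) for f in found])  # ['MYIMAGES-r10.tfrecords']

    # The question's original values point one level too deep and fail.
    original_ok = True
    try:
        check_dataset_dir(data_dir=os.path.join(datasets, 'MYIMAGES'),
                          tfrecord_dir='MYIMAGES')
    except FileNotFoundError as err:
        original_ok = False
        print('original config fails:', err)
```

In the question's setup, that corresponds to data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets' together with dataset = EasyDict(tfrecord_dir='MYIMAGES').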