I have been trying to learn some TensorFlow, using https://github.com/tkarras/progressive_growing_of_gans as my source. The good news is that I ran import_example.py successfully and got correct output image samples from the trained model karras2018iclr-celebahq-1024x1024.pkl. Next, I used dataset_tool.py (create_from_images) to build my own dataset (*.tfrecords) from 400 images at 1024x1024. I changed the paths in config.py to:
data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES/'
result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'
and added an entry for the new dataset.
But I got:
(ProgressiveGAN) C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018>python train.py
Initializing TensorFlow...
Running train.train_progressive_gan()...
Streaming data using dataset.TFRecordDataset...
Traceback (most recent call last):
  File "train.py", line 285, in <module>
    tfutil.call_func_by_name(**config.train)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\tfutil.py", line 236, in call_func_by_name
    return import_obj(func)(*args, **kwargs)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\train.py", line 151, in train_progressive_gan
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 234, in load_dataset
    dataset = tfutil.import_obj(class_name)(**adjusted_kwargs)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 67, in __init__
    assert os.path.isdir(self.tfrecord_dir)
AssertionError
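The bare AssertionError comes from TFRecordDataset.__init__ in dataset.py, which checks that the resolved tfrecord_dir is a directory and, in the repository's version, appears to also require at least one *.tfrecords file inside it. A minimal pre-flight sketch, assuming the path is built with os.path.join(data_dir, tfrecord_dir); the helper name diagnose_tfrecord_dir is hypothetical and not part of the repository:

```python
import glob
import os


def diagnose_tfrecord_dir(data_dir, tfrecord_dir):
    """Hypothetical helper: resolve the dataset folder the way load_dataset()
    appears to, then repeat the checks TFRecordDataset makes, returning
    human-readable problems instead of a bare AssertionError."""
    resolved = os.path.join(data_dir, tfrecord_dir)
    problems = []
    if not os.path.isdir(resolved):
        # This is the condition behind the AssertionError in the traceback.
        problems.append(f"not a directory: {resolved}")
    elif not glob.glob(os.path.join(resolved, "*.tfrecords")):
        # The folder exists, but dataset_tool.py output is not in it.
        problems.append(f"no *.tfrecords files in: {resolved}")
    return resolved, problems
```

Calling it with the same values train.py passes (config.data_dir and config.dataset's tfrecord_dir) prints the exact path being checked, which is usually enough to spot a doubled or missing folder.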
config.py code:
# Paths. (Excerpt: EasyDict is defined near the top of config.py.)
data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES'
result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'
# Official training configs, targeted mainly for CelebA-HQ.
# To run, comment/uncomment the lines as appropriate and launch train.py.
desc = 'pgan' # Description string included in result subdir name.
random_seed = 1000 # Global random seed.
dataset = EasyDict() # Options for dataset.load_dataset().
train = EasyDict(func='train.train_progressive_gan') # Options for main training func.
G = EasyDict(func='networks.G_paper') # Options for generator network.
D = EasyDict(func='networks.D_paper') # Options for discriminator network.
G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
G_loss = EasyDict(func='loss.G_wgan_acgan') # Options for generator loss.
D_loss = EasyDict(func='loss.D_wgangp_acgan') # Options for discriminator loss.
sched = EasyDict() # Options for train.TrainingSchedule.
grid = EasyDict(size='1080p', layout='random') # Options for train.setup_snapshot_image_grid().
# Dataset (choose one).
desc += '-MYIMAGES'; dataset = EasyDict(tfrecord_dir='MYIMAGES'); train.mirror_augment = True
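To see which folder the assertion is actually testing with these settings, the two values can be joined by hand. This sketch assumes load_dataset() prepends data_dir to tfrecord_dir with os.path.join, as the traceback's adjusted_kwargs suggests; the separator in the printed path differs between Windows and Linux:

```python
import os

# Values copied from the config.py excerpt above.
data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES'
tfrecord_dir = 'MYIMAGES'

# Assumption: load_dataset() builds the dataset path like this.
resolved = os.path.join(data_dir, tfrecord_dir)

# MYIMAGES appears twice: .../datasets/MYIMAGES/MYIMAGES
print(resolved)
print("exists:", os.path.isdir(resolved))
```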
I want to train on my own 1024x1024 images.
Answer:
I ran into this problem recently. The fix is very simple.
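All you need to do is put tfrecord_dir inside the datasets directory, i.e. datasets/tfrecord_dir. load_dataset() joins data_dir and tfrecord_dir, so with data_dir already ending in /datasets/MYIMAGES the code looks for .../datasets/MYIMAGES/MYIMAGES, which does not exist. Pointing data_dir at the parent folder ('C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets') and keeping tfrecord_dir='MYIMAGES' should let it find the dataset.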
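With the corrected data_dir, the .tfrecords files written by dataset_tool.py should sit directly in datasets/MYIMAGES. As far as I can tell from the repository's dataset_tool.py, it writes one file per level of detail, named after the folder, from 4x4 up to the full resolution (for 1024x1024: MYIMAGES-r02.tfrecords through MYIMAGES-r10.tfrecords). The helper below is hypothetical and only lists those expected paths so you can confirm they exist:

```python
import os


def expected_tfrecord_paths(data_dir, tfrecord_dir, resolution):
    """Hypothetical helper: list the files dataset_tool.py is expected to
    write for a square, power-of-two dataset, one per level of detail,
    named <folder>-rNN.tfrecords where NN = log2 of that level's size."""
    if resolution < 4 or resolution & (resolution - 1):
        raise ValueError("resolution must be a power of two >= 4")
    log2 = resolution.bit_length() - 1  # 1024 -> 10
    folder = os.path.join(data_dir, tfrecord_dir)
    name = os.path.basename(os.path.normpath(folder))
    return [os.path.join(folder, f"{name}-r{r:02d}.tfrecords")
            for r in range(2, log2 + 1)]


# Corrected setting: data_dir is the parent "datasets" folder.
paths = expected_tfrecord_paths(
    'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets', 'MYIMAGES', 1024)
for p in paths:
    print(p, "found" if os.path.isfile(p) else "missing")
```

If every file prints as missing, the dataset was most likely written somewhere else, and the tfrecord_dir argument originally passed to create_from_images is the place to look.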