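Data type mismatch error when training in sess.run()

Time: 2018-09-30 13:39:11

Tags: tensorflow neural-network generative-adversarial-network

This is probably a very silly mistake somewhere, but running a monitored training session returns this error: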

InvalidArgumentError: Key: label. Data types don't match. Data type: string but expected type: float
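As a reading aid for the message above: each feature in a TFRecord example is stored as one of three kinds (bytes/string, float, or int64), and the parser raises this error when the kind stored under a key differs from the kind requested for it. The sketch below is a pure-Python stand-in for that comparison, not TensorFlow's implementation; the helper names and the sample record are hypothetical.

```python
# Pure-Python stand-in for the parser's dtype check (not TensorFlow code).
# The helper names and the sample record below are hypothetical.

def stored_kind(values):
    """Return the kind a list of Python values would be stored as."""
    if all(isinstance(v, (bytes, str)) for v in values):
        return 'string'
    if all(isinstance(v, float) for v in values):
        return 'float'
    if all(isinstance(v, int) for v in values):
        return 'int64'
    raise ValueError('mixed or unsupported value types')


def find_mismatches(record, expected):
    """Compare each key's stored kind with the kind the parser asks for."""
    problems = []
    for key, want in expected.items():
        have = stored_kind(record[key])
        if have != want:
            problems.append(
                'Key: {}. Data type: {} but expected type: {}'.format(key, have, want))
    return problems


# Assumed situation: the label was written as a string, the samples as floats.
record = {'label': [b'zero'], 'samples': [0.0, 0.1, -0.2]}
# The parse node in the traceback requests both keys as float.
expected = {'label': 'float', 'samples': 'float'}

mismatches = find_mismatches(record, expected)
print(mismatches)
```

If this is what is happening, either the records need to store the label as floats or the loader needs to request it as a string and convert it afterwards.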

Full traceback:

F:\Apps\Python3\python.exe C:/Users/tester/PycharmProjects/SpectrogramAnalysis/gpu/train_wavegan.py train ./train --data_dir ./out_dir
WARNING:tensorflow:From C:\Users\tester\PycharmProjects\SpectrogramAnalysis\gpu\loader.py:60: batch_and_drop_remainder (from tensorflow.contrib.data.python.ops.batching) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.batch(..., drop_remainder=True)`.
--------------------------------------------------------------------------------
Generator vars
[100, 16384] (1638400): G/z_project/dense/kernel:0
[16384] (16384): G/z_project/dense/bias:0
[1, 25, 512, 1024] (13107200): G/upconv_0/conv2d_transpose/kernel:0
[512] (512): G/upconv_0/conv2d_transpose/bias:0
[1, 25, 256, 512] (3276800): G/upconv_1/conv2d_transpose/kernel:0
[256] (256): G/upconv_1/conv2d_transpose/bias:0
[1, 25, 128, 256] (819200): G/upconv_2/conv2d_transpose/kernel:0
[128] (128): G/upconv_2/conv2d_transpose/bias:0
[1, 25, 64, 128] (204800): G/upconv_3/conv2d_transpose/kernel:0
[64] (64): G/upconv_3/conv2d_transpose/bias:0
[1, 25, 1, 64] (1600): G/upconv_4/conv2d_transpose/kernel:0
[1] (1): G/upconv_4/conv2d_transpose/bias:0
Total params: 19065345 (72.73 MB)
--------------------------------------------------------------------------------
Discriminator vars
[25, 1, 64] (1600): D/downconv_0/conv1d/kernel:0
[64] (64): D/downconv_0/conv1d/bias:0
[25, 64, 128] (204800): D/downconv_1/conv1d/kernel:0
[128] (128): D/downconv_1/conv1d/bias:0
[25, 128, 256] (819200): D/downconv_2/conv1d/kernel:0
[256] (256): D/downconv_2/conv1d/bias:0
[25, 256, 512] (3276800): D/downconv_3/conv1d/kernel:0
[512] (512): D/downconv_3/conv1d/bias:0
[25, 512, 1024] (13107200): D/downconv_4/conv1d/kernel:0
[1024] (1024): D/downconv_4/conv1d/bias:0
[16384, 1] (16384): D/output/dense/kernel:0
[1] (1): D/output/dense/bias:0
Total params: 17427969 (66.48 MB)
--------------------------------------------------------------------------------
2018-09-30 20:55:23.672476: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2018-09-30 20:55:42.424696: W T:\src\github\tensorflow\tensorflow\core\framework\op_kernel.cc:1275] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: label.  Data types don't match. Data type: string but expected type: float
Traceback (most recent call last):
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 1278, in _do_call
    return fn(*args)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 1263, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 1350, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: label.  Data types don't match. Data type: string but expected type: float
     [[Node: ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_FLOAT, DT_FLOAT], dense_keys=["label", "samples"], dense_shapes=[[?], [?,1]], num_sparse=0, sparse_keys=[], sparse_types=[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const)]]
     [[Node: loader/IteratorGetNext = IteratorGetNext[output_shapes=[[64,16374,1], [64,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](loader/OneShotIterator)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/Users/tester/PycharmProjects/SpectrogramAnalysis/gpu/train_wavegan.py", line 535, in <module>
    train(fps, args)
  File "C:/Users/tester/PycharmProjects/SpectrogramAnalysis/gpu/train_wavegan.py", line 193, in train
    sess.run(D_train_op)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\monitored_session.py", line 583, in run
    run_metadata=run_metadata)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\monitored_session.py", line 1059, in run
    run_metadata=run_metadata)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\monitored_session.py", line 1150, in run
    raise six.reraise(*original_exc_info)
  File "F:\Apps\Python3\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\monitored_session.py", line 1135, in run
    return self._sess.run(*args, **kwargs)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\monitored_session.py", line 1207, in run
    run_metadata=run_metadata)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\training\monitored_session.py", line 987, in run
    return self._sess.run(*args, **kwargs)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 877, in run
    run_metadata_ptr)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 1100, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 1272, in _do_run
    run_metadata)
  File "C:\Users\tester\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\client\session.py", line 1291, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: label.  Data types don't match. Data type: string but expected type: float
     [[Node: ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_FLOAT, DT_FLOAT], dense_keys=["label", "samples"], dense_shapes=[[?], [?,1]], num_sparse=0, sparse_keys=[], sparse_types=[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const)]]
     [[Node: loader/IteratorGetNext = IteratorGetNext[output_shapes=[[64,16374,1], [64,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](loader/OneShotIterator)]]

Process finished with exit code 1
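The IteratorGetNext node in the traceback reports output_shapes=[[64,16374,1], [64,10]], which suggests the loader expects each label as a 10-element float vector, while the records appear to store the label as a string. One possible fix on the data side is to convert each string label to a one-hot float list before writing it. The sketch below shows only that conversion in plain Python; the class names are an assumption (the real vocabulary depends on the dataset), and label_to_one_hot is a hypothetical helper, not part of loader.py.

```python
# Assumed class vocabulary; replace with the labels actually present in the data.
CLASSES = ['zero', 'one', 'two', 'three', 'four',
           'five', 'six', 'seven', 'eight', 'nine']


def label_to_one_hot(label, classes=CLASSES):
    """Map a string (or bytes) label to a one-hot list of floats."""
    if isinstance(label, bytes):
        label = label.decode('utf-8')
    index = classes.index(label)  # raises ValueError for an unknown label
    return [1.0 if i == index else 0.0 for i in range(len(classes))]


one_hot = label_to_one_hot(b'three')
print(one_hot)
```

A list like this could then be stored with tf.train.Feature(float_list=tf.train.FloatList(value=one_hot)) when the TFRecord files are created. The alternative is to keep the records unchanged, parse 'label' as tf.string in the loader, and do the conversion there.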

I used this script to train WaveGAN:

from __future__ import print_function
import os
import pickle  # used by preview() to cache the sampled latent vectors
import time

import numpy as np
import tensorflow as tf
from six.moves import xrange

import loader
from wavegan import WaveGANGenerator, WaveGANDiscriminator
from functools import reduce

"""
  Constants
"""
_FS = 16000
_WINDOW_LEN = 16374
_D_Z = 90
_D_Y = 10

"""
  Trains a WaveGAN
"""


def train(fps, args):
    with tf.name_scope('loader'):
        x, y = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window, labels=True)

    # Make inputs
    y_fill = tf.expand_dims(y, axis=2)
    z = tf.random_uniform([args.train_batch_size, _D_Z], -1., 1., dtype=tf.float32)

    # Concatenate labels
    x = tf.concat([x, y_fill], 1)
    z = tf.concat([z, y], 1)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_G_z,
            labels=real
        ))

        D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_G_z,
            labels=fake
        ))
        D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_x,
            labels=real
        ))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = tf.reduce_mean((D_x - 1.) ** 2)
        D_loss += tf.reduce_mean(D_G_z ** 2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
                    )
                )
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(
            learning_rate=2e-4,
            beta1=0.5)
        D_opt = tf.train.AdamOptimizer(
            learning_rate=2e-4,
            beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(
            learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(
            learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(
            learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(
            learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(
            learning_rate=1e-4,
            beta1=0.5,
            beta2=0.9)
        D_opt = tf.train.AdamOptimizer(
            learning_rate=1e-4,
            beta1=0.5,
            beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
                                global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while True:
            # Train discriminator
            for i in range(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)


"""
  Creates and saves a MetaGraphDef for simple inference
  Tensors:
    'samp_z_n' int32 []: Sample this many latent vectors
    'samp_z' float32 [samp_z_n, 90]: Resultant latent vectors
    'z:0' float32 [None, 100]: Input latent vectors
    'flat_pad:0' int32 []: Number of padding samples to use when flattening batch to a single audio file
    'G_z:0' float32 [None, 16384, 1]: Generated outputs
    'G_z_int16:0' int16 [None, 16384, 1]: Same as above but quantized to 16-bit PCM samples
    'G_z_flat:0' float32 [None, 1]: Outputs flattened into single audio file
    'G_z_flat_int16:0' int16 [None, 1]: Same as above but quantized to 16-bit PCM samples
  Example usage:
    import tensorflow as tf
    tf.reset_default_graph()

    saver = tf.train.import_meta_graph('infer.meta')
    graph = tf.get_default_graph()
    sess = tf.InteractiveSession()
    saver.restore(sess, 'model.ckpt-10000')

    z_n = graph.get_tensor_by_name('samp_z_n:0')
    _z = sess.run(graph.get_tensor_by_name('samp_z:0'), {z_n: 10})

    z = graph.get_tensor_by_name('z:0')
    _G_z = sess.run(graph.get_tensor_by_name('G_z:0'), {z: _z})
"""


def infer(args):
    infer_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    # Subgraph that generates latent vectors
    samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    samp_z = tf.random_uniform([samp_z_n, _D_Z], -1.0, 1.0, dtype=tf.float32, name='samp_z')

    # Input zo
    z = tf.placeholder(tf.float32, [None, _D_Z + _D_Y], name='z')
    flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

    # Execute generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
    G_z = tf.identity(G_z, name='G_z')

    # Flatten batch
    nch = int(G_z.get_shape()[-1])
    G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
    G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

    # Encode to int16
    def float_to_int16(x, name=None):
        x_int16 = x * 32767.
        x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
        x_int16 = tf.cast(x_int16, tf.int16, name=name)
        return x_int16

    G_z_int16 = float_to_int16(G_z, name='G_z_int16')
    G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

    # Create saver
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(G_vars + [global_step])

    # Export graph
    tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

    # Export MetaGraph
    infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
    tf.train.export_meta_graph(
        filename=infer_metagraph_fp,
        clear_devices=True,
        saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards)
    tf.reset_default_graph()


"""
  Generates a preview audio file every time a checkpoint is saved
"""


def preview(args):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from scipy.io.wavfile import write as wavwrite
    from scipy.signal import freqz

    preview_dir = os.path.join(args.train_dir, 'preview')
    if not os.path.isdir(preview_dir):
        os.makedirs(preview_dir)

    # Load graph
    infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
    graph = tf.get_default_graph()
    saver = tf.train.import_meta_graph(infer_metagraph_fp)

    # Generate or restore z_i and z_o
    z_fp = os.path.join(preview_dir, 'z.pkl')
    if os.path.exists(z_fp):
        with open(z_fp, 'rb') as f:
            _zs = pickle.load(f)
    else:
        # Sample z
        samp_feeds = {}
        samp_feeds[graph.get_tensor_by_name('samp_z_n:0')] = args.preview_n
        samp_fetches = {}
        samp_fetches['zs'] = graph.get_tensor_by_name('samp_z:0')
        with tf.Session() as sess:
            _samp_fetches = sess.run(samp_fetches, samp_feeds)
        _zs = _samp_fetches['zs']

        # Save z
        with open(z_fp, 'wb') as f:
            pickle.dump(_zs, f)

    # label to one hot vector
    sample_n = 20
    one_hot = np.zeros([sample_n, _D_Y])
    _zs = _zs[:sample_n]
    for i in range(10):
        one_hot[2 * i + 1][i] = 1
        one_hot[2 * i][i] = 1
    _zs = np.concatenate([_zs, one_hot], 1)

    # Set up graph for generating preview images
    feeds = {}
    feeds[graph.get_tensor_by_name('z:0')] = _zs
    feeds[graph.get_tensor_by_name('flat_pad:0')] = _WINDOW_LEN // 2
    fetches = {}
    fetches['step'] = tf.train.get_or_create_global_step()
    fetches['G_z'] = graph.get_tensor_by_name('G_z:0')
    fetches['G_z_flat_int16'] = graph.get_tensor_by_name('G_z_flat_int16:0')
    if args.wavegan_genr_pp:
        fetches['pp_filter'] = graph.get_tensor_by_name('G/pp_filt/conv1d/kernel:0')[:, 0, 0]

    # Summarize
    G_z = graph.get_tensor_by_name('G_z_flat:0')
    summaries = [
        tf.summary.audio('preview', tf.expand_dims(G_z, axis=0), _FS, max_outputs=1)
    ]
    fetches['summaries'] = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(preview_dir)

    # PP Summarize
    if args.wavegan_genr_pp:
        pp_fp = tf.placeholder(tf.string, [])
        pp_bin = tf.read_file(pp_fp)
        pp_png = tf.image.decode_png(pp_bin)
        pp_summary = tf.summary.image('pp_filt', tf.expand_dims(pp_png, axis=0))

    # Loop, waiting for checkpoints
    ckpt_fp = None
    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            print('Preview: {}'.format(latest_ckpt_fp))

            with tf.Session() as sess:
                saver.restore(sess, latest_ckpt_fp)

                _fetches = sess.run(fetches, feeds)

                _step = _fetches['step']

            gen_speech = _fetches['G_z_flat_int16']
            # Integer division so that start/end below are valid slice indices in Python 3
            gen_len = len(gen_speech) // sample_n

            for i in range(sample_n):
                label = int(i / 2)
                start = i * gen_len
                end = start + gen_len
                preview_fp = os.path.join(preview_dir, '{}_{}_{}.wav'.format(str(label), str(_step), str(i)))
                wavwrite(preview_fp, _FS, gen_speech[start:end])

            summary_writer.add_summary(_fetches['summaries'], _step)

            if args.wavegan_genr_pp:
                w, h = freqz(_fetches['pp_filter'])

                fig = plt.figure()
                plt.title('Digital filter frequency response')
                ax1 = fig.add_subplot(111)

                plt.plot(w, 20 * np.log10(abs(h)), 'b')
                plt.ylabel('Amplitude [dB]', color='b')
                plt.xlabel('Frequency [rad/sample]')

                ax2 = ax1.twinx()
                angles = np.unwrap(np.angle(h))
                plt.plot(w, angles, 'g')
                plt.ylabel('Angle (radians)', color='g')
                plt.grid()
                plt.axis('tight')

                _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
                plt.savefig(_pp_fp)

                with tf.Session() as sess:
                    _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
                    summary_writer.add_summary(_summary, _step)

            print('Done')

            ckpt_fp = latest_ckpt_fp

        time.sleep(1)


if __name__ == '__main__':
    import argparse
    import glob
    import sys

    parser = argparse.ArgumentParser()

    parser.add_argument('mode', type=str, choices=['train', 'preview', 'incept', 'infer'])
    parser.add_argument('train_dir', type=str,
                        help='Training directory')

    data_args = parser.add_argument_group('Data')
    data_args.add_argument('--data_dir', type=str,
                           help='Data directory')
    data_args.add_argument('--data_first_window', action='store_true', dest='data_first_window',
                           help='If set, only use the first window from each audio example')

    wavegan_args = parser.add_argument_group('WaveGAN')
    wavegan_args.add_argument('--wavegan_kernel_len', type=int,
                              help='Length of 1D filter kernels')
    wavegan_args.add_argument('--wavegan_dim', type=int,
                              help='Dimensionality multiplier for model of G and D')
    wavegan_args.add_argument('--wavegan_batchnorm', action='store_true', dest='wavegan_batchnorm',
                              help='Enable batchnorm')
    wavegan_args.add_argument('--wavegan_disc_nupdates', type=int,
                              help='Number of discriminator updates per generator update')
    wavegan_args.add_argument('--wavegan_loss', type=str, choices=['dcgan', 'lsgan', 'wgan', 'wgan-gp'],
                              help='Which GAN loss to use')
    wavegan_args.add_argument('--wavegan_genr_upsample', type=str, choices=['zeros', 'nn', 'lin', 'cub'],
                              help='Generator upsample strategy')
    wavegan_args.add_argument('--wavegan_genr_pp', action='store_true', dest='wavegan_genr_pp',
                              help='If set, use post-processing filter')
    wavegan_args.add_argument('--wavegan_genr_pp_len', type=int,
                              help='Length of post-processing filter for DCGAN')
    wavegan_args.add_argument('--wavegan_disc_phaseshuffle', type=int,
                              help='Radius of phase shuffle operation')

    train_args = parser.add_argument_group('Train')
    train_args.add_argument('--train_batch_size', type=int,
                            help='Batch size')
    train_args.add_argument('--train_save_secs', type=int,
                            help='How often to save model')
    train_args.add_argument('--train_summary_secs', type=int,
                            help='How often to report summaries')

    preview_args = parser.add_argument_group('Preview')
    preview_args.add_argument('--preview_n', type=int,
                              help='Number of samples to preview')

    incept_args = parser.add_argument_group('Incept')
    incept_args.add_argument('--incept_metagraph_fp', type=str,
                             help='Inference model for inception score')
    incept_args.add_argument('--incept_ckpt_fp', type=str,
                             help='Checkpoint for inference model')
    incept_args.add_argument('--incept_n', type=int,
                             help='Number of generated examples to test')
    incept_args.add_argument('--incept_k', type=int,
                             help='Number of groups to test')

    parser.set_defaults(
        data_dir=None,
        data_first_window=False,
        wavegan_kernel_len=25,
        wavegan_dim=64,
        wavegan_batchnorm=False,
        wavegan_disc_nupdates=5,
        wavegan_loss='wgan-gp',
        wavegan_genr_upsample='zeros',
        wavegan_genr_pp=False,
        wavegan_genr_pp_len=512,
        wavegan_disc_phaseshuffle=2,
        train_batch_size=64,
        train_save_secs=300,
        train_summary_secs=120,
        preview_n=32,
        incept_metagraph_fp='./eval/inception/infer.meta',
        incept_ckpt_fp='./eval/inception/best_acc-103005',
        incept_n=5000,
        incept_k=10)

    args = parser.parse_args()

    # Make train dir
    if not os.path.isdir(args.train_dir):
        os.makedirs(args.train_dir)

    # Save args
    with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
        f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

    # Make model kwarg dicts
    setattr(args, 'wavegan_g_kwargs', {
        'kernel_len': args.wavegan_kernel_len,
        'dim': args.wavegan_dim,
        'use_batchnorm': args.wavegan_batchnorm,
        'upsample': args.wavegan_genr_upsample
    })
    setattr(args, 'wavegan_d_kwargs', {
        'kernel_len': args.wavegan_kernel_len,
        'dim': args.wavegan_dim,
        'use_batchnorm': args.wavegan_batchnorm,
        'phaseshuffle_rad': args.wavegan_disc_phaseshuffle
    })

    # Assign appropriate split for mode
    if args.mode == 'train':
        split = 'train'
    else:
        split = None

    # Find fps for split
    if split is not None:
        fps = glob.glob(os.path.join(args.data_dir, split) + '*.tfrecord')

    if args.mode == 'train':
        infer(args)
        train(fps, args)
    elif args.mode == 'preview':
        preview(args)
    elif args.mode == 'infer':
        infer(args)
    else:
        raise NotImplementedError()

Thanks

0 answers:

No answers