用于条件GAN的TFGAN和数据集框架

时间:2019-02-07 22:49:56

标签: python tensorflow tensorflow-datasets

我正在使用make_one_shot_iterator和get_next函数获取数据,以将数据输入到使用TFGAN创建的条件GAN中。但是,该批次采样了两次。是否有一些技巧可以避免这种情况?

我正在使用TFGAN库和数据集框架对条件GAN进行编程。 GAN模型以17类为条件。我的数据集由178个样本和11个要素组成。我正在使用TFRecords保存数据,并使用迭代器读取数据。我已经尝试了类似教程the Google Cloud documentation中所示的鉴别器和生成器函数。但是,批次大小为5时,经过17次训练迭代后会出现不兼容的形状误差。我的代码是:

dataset_provider.py的第13-25行。这些功能基于数据集框架来采样我的数据集

ALTER PROCEDURE AccidentByMonth
    (
        @ReportedDate int
    )
AS
    BEGIN
        SELECT *
        FROM   Incident
        WHERE  DATEPART(QUARTER, ReportedDate) = DATEPART(QUARTER, CONVERT(DATE,CONVERT(VARCHAR(10),@ReportedDate,101)));
    END;

我的main.py文件

def tfrecord_train_input_fn(tfrecord_path, batch_size=32):
    tfrecord_dataset = tf.data.TFRecordDataset(tfrecord_path)
    tfrecord_dataset = tfrecord_dataset.map(_parse_)
    tfrecord_dataset = tfrecord_dataset.shuffle(True).batch(batch_size)
    tfrecord_iterator = tfrecord_dataset.make_one_shot_iterator()
    data, target = tfrecord_iterator.get_next()
    return data, target


def dataset_provider(path, batch_size=32, subset='train'):
    tfrecord_path = path + subset + '_set.tfrecords'
    print(tfrecord_path)
    return tfrecord_train_input_fn(tfrecord_path, batch_size)

结果,我得到了下一个错误: InvalidArgumentError(请参见上方的追溯):不兼容的形状:[5,1024]与[3,1024]      [[节点生成器/添加(在/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/gan/python/features/python/condition_utils_impl.py:73中定义=添加[T = DT_FLOAT,_device =“ / job:localhost / replica:0 / task:0 / device:CPU:0”](发电机/ fully_connected / Relu,Generator / fully_connected_1 / BatchNorm / Reshape_1)]]

我认为对于每次训练迭代,数据集都会被采样两次。下一个推理可以证明这一点: 5个batch_size x 2个采样x 17个迭代_before_error = 178个样本中的170个 生成器模型中再增加1个采样(+ 5个样本= 175个样本),鉴别器模型的采样错误(3个样本不是batch_size)

这是每次训练迭代中两次采样的错误。可以避免这种情况吗?

0 个答案:

没有答案