pytorch的binarize输入

时间:2017-11-26 04:45:18

标签: input binary loader pytorch

请问如何加载pytorch中的数据一旦加载就变成二进制文件? 就像Tensorflow可以通过以下方式完成此任务:

train_data  = mnist.input_data.read_data_sets(data_directory, one_hot=True)

pytorch如何实现one_hot=True效果。

我现在拥有的data_loader是:

torch.set_default_tensor_type('torch.FloatTensor')
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data/', train=True, download=True,
                       transform=transforms.Compose([
                         # transforms.RandomHorizontalFlip(),
                           transforms.ToTensor()])),
        batch_size=batch_size, shuffle=False)

我想让train_loader中的数据被二值化。 现在我要做的是:加载数据后,

for data,_ in train_loader:
    torch.round(data) 
    data = Variable(data)

使用torch.round()功能。这是对的吗?

1 个答案:

答案 0 :(得分:0)

单热编码思想用于分类。听起来你可能正在尝试创建一个自动编码器。

如果您正在创建自动编码器,则无需舍入,因为BCELoss可以处理0到1之间的值。请注意,在训练时最好不要应用sigmoid而是使用BCELossWithLogits,因为它提供数值稳定性。

以下是带有MNIST的autoencoder

的示例

如果您正在尝试进行分类,那么就不需要一个热矢量,只需输出等于类数的神经元数量,即MNIST输出10个神经元,然后将其传递给{{3} }以及具有相应预期类值的LongTensor

以下是MNIST上CrossEntropyLoss的示例