火炬:分区张量

时间:2017-04-15 20:06:16

标签: lua torch

我想将我的数据集(10,000张50x50 RGB图像)分割成两个数据集。类似的东西:

X = torch.rand(10000, 3, 50, 50)
inds = torch.randperm(X:size(1))[{ { 1, nTrain } }]:long()
X_selected = X:index(1, inds)
X_remaining = X:delete(1, inds)

无论我搜索什么,我都会收到Torch的GitHub文档。我怎么能这样做?

1 个答案:

答案 0 :(得分:1)

你可以试试这种方式

X = torch.rand(10000, 3, 50, 50)
inds = torch.randperm(X:size(1)):long()
train_inds = inds:narrow(1, 1, nTrain)
valid_inds = inds:narrow(1, nTrain + 1, X:size(1) - nTrain)
X_train = X:index(1, train_inds)
X_valid = X:index(1, valid_inds)