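IndexError: index 1 is out of bounds for axis 0 with size 1 / validation epoch

Time: 2019-04-29 13:06:45

Tags: python pytorch

I am using a U-Net with PyTorch to segment blood vessels. When I train the network, it crashes at epoch 5, which is the first epoch where validation is supposed to run. The crash happens inside this function: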

import numpy as np

# stride_size(...) is defined elsewhere in the notebook (not shown here)


def image_concatenate(image, crop_num1, crop_num2, dim1, dim2):
    """Concatenate crops back into one image.
    Args :
        image : stack of output crops, shape (crop_num1 * crop_num2, crop_size, crop_size);
                each crop should be square
        crop_num2 (int) : number of crops in the horizontal direction (2)
        crop_num1 (int) : number of crops in the vertical direction (2)
        dim1 (int) : vertical size of the output (512)
        dim2 (int) : horizontal size of the output (512)
    Return :
        empty_array : float64 numpy array of shape (dim1, dim2) in which
                      overlapping crops are summed
    """
    crop_size = image.shape[1]  # size of one crop
    empty_array = np.zeros([dim1, dim2]).astype("float64")  # float64 to avoid overflow
    dim1_stride = stride_size(dim1, crop_num1, crop_size)  # vertical stride
    dim2_stride = stride_size(dim2, crop_num2, crop_size)  # horizontal stride
    index = 0
    for i in range(crop_num1):
        for j in range(crop_num2):
            # add crop number `index` to empty_array at its grid position
            empty_array[dim1_stride*i:dim1_stride*i + crop_size,
                        dim2_stride*j:dim2_stride*j + crop_size] += image[index]
            index += 1
    return empty_array
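To see what the function expects when it works, here is a minimal self-contained sketch. The question does not show `stride_size`, so the version below is a hypothetical stand-in (an assumption) that spreads the crops evenly so the first starts at 0 and the last ends at the image edge. The key point is that axis 0 of `image` must contain one crop per grid cell, i.e. `crop_num1 * crop_num2` crops.

```python
import numpy as np


def stride_size(image_len, crop_num, crop_size):
    """Hypothetical stride helper (assumption: the question's own is not shown).

    Spreads crop_num crops evenly so the first starts at 0 and the last
    ends at image_len.
    """
    if crop_num == 1:
        return 0
    return (image_len - crop_size) // (crop_num - 1)


def image_concatenate(image, crop_num1, crop_num2, dim1, dim2):
    """Sum a (crop_num1 * crop_num2, crop, crop) stack onto a dim1 x dim2 canvas."""
    crop_size = image.shape[1]
    canvas = np.zeros((dim1, dim2), dtype=np.float64)
    s1 = stride_size(dim1, crop_num1, crop_size)
    s2 = stride_size(dim2, crop_num2, crop_size)
    index = 0
    for i in range(crop_num1):
        for j in range(crop_num2):
            # image[index] must exist for every index < crop_num1 * crop_num2
            canvas[s1 * i:s1 * i + crop_size, s2 * j:s2 * j + crop_size] += image[index]
            index += 1
    return canvas


# Four 3x3 crops of ones tiled onto a 4x4 canvas; overlapping cells are summed.
crops = np.ones((4, 3, 3))
out = image_concatenate(crops, 2, 2, 4, 4)
print(out)
```

With a (4, 3, 3) stack, every `image[index]` lookup from 0 to 3 is valid; the centre cells of the result hold 4 because all four crops overlap there.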

This is the error raised during training:

Epoch 5 Train loss: 0.012208586327390733 Train acc 0.9599569108751085
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-69-d8ccfac27aaf> in <module>()
     43         # Validation every 5 epoch
     44         if (i+1) % 5 == 0:
---> 45             val_acc, val_loss = validate_model(model, SEM_val_load, criterion, i+1, True, image_save_path)
     46             print('Val loss:', val_loss, "val acc:", val_acc)
     47             values = [i+1, train_loss, train_acc, val_loss, val_acc]

2 frames
<ipython-input-61-7f3e547fab95> in image_concatenate(image, crop_num1, crop_num2, dim1, dim2)
    330             # add image to empty_array at specific position
    331             empty_array[dim1_stride*i:dim1_stride*i+ crop_size,
--> 332                         dim2_stride*j:dim2_stride*j+ crop_size] += image[index]
    333             index += 1
    334     return empty_array

IndexError: index 1 is out of bounds for axis 0 with size 1
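The message says that the array passed as `image` has size 1 along axis 0, so `image[0]` succeeds and `image[1]` fails on the second loop iteration. In other words, `image_concatenate` received a single crop, while the nested loop needs `crop_num1 * crop_num2` crops. The calling code in `validate_model` is not shown, so where that shape comes from is not known. One way to find out is to check the shape before concatenating. The helper below, `check_crop_stack`, is hypothetical (it is not part of the question's code). It is a sketch that fails early with the actual shape in the message:

```python
import numpy as np


def check_crop_stack(image, crop_num1, crop_num2):
    """Hypothetical helper: verify that `image` holds one crop per grid cell.

    image_concatenate reads image[0] ... image[crop_num1 * crop_num2 - 1],
    so axis 0 must have at least that many entries.
    """
    image = np.asarray(image)
    expected = crop_num1 * crop_num2
    if image.ndim != 3:
        raise ValueError(
            f"expected a (num_crops, H, W) array, got shape {image.shape}")
    if image.shape[0] < expected:
        raise ValueError(
            f"need {expected} crops along axis 0, got {image.shape[0]}; "
            f"shape {image.shape}")
    return image


# Reproduce the failure mode: a single crop where 2 x 2 = 4 are expected.
one_crop = np.ones((1, 3, 3))
try:
    one_crop[1]  # the same lookup that fails inside image_concatenate
except IndexError as err:
    print("IndexError:", err)
try:
    check_crop_stack(one_crop, 2, 2)
except ValueError as err:
    print("ValueError:", err)

# A stack with one crop per grid cell passes the check.
four_crops = np.ones((4, 3, 3))
print(check_crop_stack(four_crops, 2, 2).shape)
```

Calling such a check at the top of `image_concatenate`, or just printing `image.shape` there, shows whether the validation loop is passing one crop at a time instead of the full stack for an image.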
