这些功能将图像张量分割为相等大小的重叠图块。目前,行/列中的最终图块的x / y重叠值可能会有所不同。我需要所有x重叠值都相同,并且所有y重叠值都相同,但是我尝试解决的所有问题似乎都没有。
import torch
def calc_tiles(d, td):
num_tiles = d // td + 1
overlap = (d - td ) // (num_tiles - 1)
tile_idx = [x * overlap for x in range(0, num_tiles)]
return tile_idx, num_tiles, overlap
def split_tensor_equal(tensor, tile_size=256, offset_x=0, offset_y=0):
tensor = tensor.clone()
tile_h, tile_w = tile_size +offset_y, tile_size +offset_x
h, w = tensor.size(2), tensor.size(3)
rows, ovlp = [0,0], [0,0]
x_idx, rows[0], ovlp[0] = calc_tiles(w, tile_w)
y_idx, rows[1], ovlp[1] = calc_tiles(h, tile_h)
tile_list = []
x_max, y_max = 0, 0
for y in y_idx:
y_max += tile_h
y_min = y_max-tile_h
for x in x_idx:
x_max += tile_w
x_min = x_max-tile_w
if x_max > w:
x_max = w
x_min = w-tile_w
if y_max > h:
y_max = h
y_min = h-tile_h
#print(y_min,y_max, x_min,x_max)
tile = tensor[:, :, y_min:y_max, x_min:x_max]
tile_list.append(tile)
x_max = tile_w
return tile_list, rows, ovlp
test_input = torch.randn(3, 1024, 768).unsqueeze(0)
split_tensor_equal(test_input, tile_size=560)