Reduce tiled decode peak memory (#13050)

This commit is contained in:
Jukka Seppänen 2026-03-19 19:29:34 +02:00 committed by GitHub
parent ab14541ef7
commit fd0261d2bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1) pbar.update(1)
continue continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos))) upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device) ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
for d in range(2, dims + 2): for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2])) feather = round(get_scale(d - 2, overlap[d - 2]))
@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None: if pbar is not None:
pbar.update(1) pbar.update(1)
output[b:b+1] = out/out_div out.div_(out_div)
return output return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):