Update utils.py, to fix upscaling error of non contiguous images on mac

Changes Made:
Added samples = samples.contiguous() at the start.
Added tile = (min(tile[0], samples.shape[2]), min(tile[1], samples.shape[3])) for dynamic tile sizing.
Modified output[b:b+1] = function(s.contiguous()).to(output_device) in the single-tile case.
Modified ps = function(s_in.contiguous()).to(output_device) in the tiled loop.
This commit is contained in:
Lightje 2025-05-31 02:50:50 +02:00 committed by GitHub
parent 08b7cc7506
commit ae86371991
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -875,6 +875,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
dims = len(tile)
samples = samples.contiguous() # Ensure input tensor is contiguous
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
@ -936,7 +937,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
output[b:b+1] = function(s.contiguous()).to(output_device) # Ensure single tile is contiguous
if pbar is not None:
pbar.update(1)
continue
@ -956,7 +957,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
ps = function(s_in.contiguous()).to(output_device) # Ensure tiled segment is contiguous
mask = torch.ones_like(ps)
for d in range(2, dims + 2):
@ -982,7 +983,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
output[b:b+1] = out/out_div
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)