Memory optimizations to allow bigger images

Jairo Correa 2023-11-21 08:11:40 -03:00
parent 6ff06fa796
commit 0de7950ff2
5 changed files with 62 additions and 15 deletions

View File

@@ -80,6 +80,7 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd"
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-cpu", action="store_true", help="To use the CPU for preview (slow).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -99,6 +100,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--memory-estimation-multiplier", type=float, default=-1, help="Multiplier for the memory estimation.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

View File

@@ -217,16 +217,21 @@ def attention_split(q, k, v, heads, mask=None):
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
modifier = 3
if args.memory_estimation_multiplier >= 0:
modifier = args.memory_estimation_multiplier
mem_required = tensor_size * modifier
steps = 1
max_steps = q.shape[1] - 1
while (q.shape[1] % max_steps) != 0:
max_steps -= 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
if steps > max_steps:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
@@ -259,8 +264,10 @@ def attention_split(q, k, v, heads, mask=None):
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
steps += 1
while (q.shape[1] % steps) != 0 and steps < max_steps:
steps += 1
if steps > max_steps:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
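
A self-contained sketch of the new step selection (the function name and example values are mine, not the commit's): max_steps is the largest proper divisor of the sequence length, and the first estimate of steps rounds the memory shortfall up to the next power of two, as in the hunk above.

import math

def initial_split(seq_len, mem_required, mem_free):
    # Largest divisor of seq_len that is smaller than seq_len: the finest split
    # that still slices the attention matrix into equal chunks.
    max_steps = seq_len - 1
    while seq_len % max_steps != 0:
        max_steps -= 1
    steps = 1
    if mem_required > mem_free:
        # Round the shortfall up to a power of two, as the diff does.
        steps = 2 ** math.ceil(math.log(mem_required / mem_free, 2))
    return steps, max_steps

# A 4096-token sequence needing ~3x the free memory starts at 4 slices, and can
# be split at most 2048 ways before the chunks stop dividing evenly.
print(initial_split(4096, 3.0, 1.0))   # (4, 2048)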

View File

@@ -7,6 +7,7 @@ from einops import rearrange
from typing import Optional, Any
from comfy import model_management
from comfy.cli_args import args
import comfy.ops
if model_management.xformers_enabled_vae():
@@ -165,9 +166,15 @@ def slice_attention(q, k, v):
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
if args.memory_estimation_multiplier >= 0:
modifier = args.memory_estimation_multiplier
mem_required = tensor_size * modifier
steps = 1
max_steps = q.shape[1] - 1
while (q.shape[1] % max_steps) != 0:
max_steps -= 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
@@ -186,8 +193,10 @@ def slice_attention(q, k, v):
break
except model_management.OOM_EXCEPTION as e:
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
steps += 1
while (q.shape[1] % steps) != 0 and steps < max_steps:
steps += 1
if steps > max_steps:
raise e
print("out of memory error, increasing steps and trying again", steps)

View File

@@ -217,10 +217,21 @@ class VAE:
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
tile_size = 64
while tile_size >= 8:
overlap = tile_size // 4
print(f"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding with tile size {tile_size} and overlap {overlap}.")
try:
pixel_samples = self.decode_tiled_(samples_in, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
break
except model_management.OOM_EXCEPTION as e:
pass
tile_size -= 8
self.first_stage_model = self.first_stage_model.to(self.offload_device)
if pixel_samples is None:
raise e
finally:
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
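
The tiled fallback is now a loop over shrinking tile sizes rather than a single fixed retry. Below is a rough, self-contained sketch of the pattern (the function and its arguments are illustrative, not the commit's code); the encode path in the next hunk does the same with pixel tiles from 512 down to 64 and an overlap of tile_size // 8.

def decode_with_shrinking_tiles(decode_tiled, samples_in):
    # Try latent tiles of 64, 56, 48, ... 8; overlap is a quarter of the tile,
    # mirroring tile_size // 4 in the diff.
    tile_size = 64
    while tile_size >= 8:
        try:
            return decode_tiled(samples_in, tile_x=tile_size, tile_y=tile_size,
                                overlap=tile_size // 4)
        except MemoryError:
            tile_size -= 8
    # Matches the diff's behavior of re-raising once every tile size has failed.
    raise MemoryError("tiled VAE decode failed even at the smallest tile size")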
@@ -245,10 +256,21 @@ class VAE:
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
tile_size = 512
while tile_size >= 64:
overlap = tile_size // 8
print(f"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding with tile size {tile_size} and overlap {overlap}.")
try:
samples = self.encode_tiled_(pixel_samples, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
break
except model_management.OOM_EXCEPTION as e:
pass
tile_size -= 64
self.first_stage_model = self.first_stage_model.to(self.offload_device)
if samples is None:
raise e
finally:
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):

View File

@@ -6,6 +6,7 @@ from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths
import comfy.utils
from comfy import model_management
MAX_PREVIEW_RESOLUTION = 512
@@ -18,11 +19,12 @@ class LatentPreviewer:
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
def __init__(self, taesd, device):
self.taesd = taesd
self.device = device
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1])[0].detach()
x_sample = self.taesd.decoder(x0[:1].to(self.device))[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
@@ -52,6 +54,8 @@ class Latent2RGBPreviewer(LatentPreviewer):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if args.preview_cpu:
device = torch.device("cpu")
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = None
@@ -71,7 +75,7 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
previewer = TAESDPreviewerImpl(taesd, device)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
@@ -94,7 +98,10 @@ def prepare_callback(model, steps, x0_output_dict=None):
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
try:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
except model_management.OOM_EXCEPTION as e:
pass
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
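
Finally, a small sketch of the guarded preview call in the sampler callback (the helper name is mine): an out-of-memory error while decoding the preview is swallowed, so the progress bar still advances, just without an image for that step.

def safe_preview(previewer, preview_format, x0):
    # Mirrors the diff: a preview OOM must not abort sampling.
    try:
        return previewer.decode_latent_to_preview_image(preview_format, x0)
    except MemoryError:
        return None   # the callback then passes None to pbar.update_absolute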