Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 16:32:34 +08:00)

Memory optimizations to allow bigger images

commit 0de7950ff2, parent 6ff06fa796
@@ -80,6 +80,7 @@ class LatentPreviewMethod(enum.Enum):
     TAESD = "taesd"
 
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
+parser.add_argument("--preview-cpu", action="store_true", help="To use the CPU for preview (slow).")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -99,6 +100,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
 
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
 
+parser.add_argument("--memory-estimation-multiplier", type=float, default=-1, help="Multiplier for the memory estimation.")
 
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
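In upstream ComfyUI these argument definitions live in comfy/cli_args.py. Both new flags follow a sentinel-default pattern: --memory-estimation-multiplier defaults to -1, which the attention code further down treats as "keep the built-in heuristic", while any value of 0 or more replaces that heuristic. A minimal, self-contained sketch of the pattern (the flag names match the diff; the demo value of 2.0 is made up):

    import argparse

    parser = argparse.ArgumentParser()
    # -1 is a sentinel meaning "keep the built-in memory heuristic"
    parser.add_argument("--memory-estimation-multiplier", type=float, default=-1)
    parser.add_argument("--preview-cpu", action="store_true")

    args = parser.parse_args(["--memory-estimation-multiplier", "2.0"])  # demo value

    modifier = 3  # built-in default used by the split attention estimate
    if args.memory_estimation_multiplier >= 0:
        modifier = args.memory_estimation_multiplier
    print(modifier)  # 2.0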
@@ -217,16 +217,21 @@ def attention_split(q, k, v, heads, mask=None):
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
     modifier = 3
+    if args.memory_estimation_multiplier >= 0:
+        modifier = args.memory_estimation_multiplier
     mem_required = tensor_size * modifier
     steps = 1
 
+    max_steps = q.shape[1] - 1
+    while (q.shape[1] % max_steps) != 0:
+        max_steps -= 1
 
     if mem_required > mem_free_total:
         steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
         # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
         #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
-    if steps > 64:
+    if steps > max_steps:
         max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
         raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                            f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
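Two things change in the hunk above (attention_split lives in comfy/ldm/modules/attention.py in upstream ComfyUI): the memory-estimate modifier can now be overridden from the new CLI flag, and the hard-coded cap of 64 slices is replaced by max_steps, the largest slice count the tensor can actually be cut into. Split attention processes the query length q.shape[1] in equal chunks of q.shape[1] // steps, and in the surrounding upstream code a non-divisor step count falls back to a single full-length slice, which defeats the split; max_steps is therefore the largest proper divisor of the length. A small standalone sketch (the function name and the 77-token example are illustrative, not from the commit):

    def largest_valid_steps(seq_len: int) -> int:
        # Largest slice count that still divides the sequence length evenly,
        # i.e. the largest proper divisor of seq_len.
        max_steps = seq_len - 1
        while (seq_len % max_steps) != 0:
            max_steps -= 1
        return max_steps

    # 77 = 7 * 11, so a 77-token sequence can be cut into at most 11 chunks of 7
    print(largest_valid_steps(77))  # 11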
@@ -259,8 +264,10 @@ def attention_split(q, k, v, heads, mask=None):
                     cleared_cache = True
                     print("out of memory error, emptying cache and trying again")
                     continue
-                steps *= 2
-                if steps > 64:
+                steps += 1
+                while (q.shape[1] % steps) != 0 and steps < max_steps:
+                    steps += 1
+                if steps > max_steps:
                     raise e
                 print("out of memory error, increasing steps and trying again", steps)
             else:
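The out-of-memory retry path changes to match: instead of doubling the slice count and giving up above 64, it advances to the next count that divides the query length and only gives up past max_steps, so awkward sequence lengths no longer force an early failure. A sketch of that escalation with illustrative numbers:

    def next_valid_steps(seq_len: int, steps: int, max_steps: int) -> int:
        # Advance to the next slice count that divides seq_len evenly,
        # mirroring the OOM handler above (which used to simply double steps).
        steps += 1
        while (seq_len % steps) != 0 and steps < max_steps:
            steps += 1
        return steps

    # 4096 tokens currently split into 4 slices: the next divisor is 8
    print(next_valid_steps(4096, 4, 2048))  # 8

The same divisor-aware stepping is applied in the hunks below to slice_attention, the equivalent loop used by the VAE's attention blocks (comfy/ldm/modules/diffusionmodules/model.py upstream).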
@@ -7,6 +7,7 @@ from einops import rearrange
 from typing import Optional, Any
 
 from comfy import model_management
+from comfy.cli_args import args
 import comfy.ops
 
 if model_management.xformers_enabled_vae():
@@ -165,9 +166,15 @@ def slice_attention(q, k, v):
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
     modifier = 3 if q.element_size() == 2 else 2.5
+    if args.memory_estimation_multiplier >= 0:
+        modifier = args.memory_estimation_multiplier
     mem_required = tensor_size * modifier
     steps = 1
 
+    max_steps = q.shape[1] - 1
+    while (q.shape[1] % max_steps) != 0:
+        max_steps -= 1
+
     if mem_required > mem_free_total:
         steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
 
@@ -186,8 +193,10 @@ def slice_attention(q, k, v):
             break
         except model_management.OOM_EXCEPTION as e:
             model_management.soft_empty_cache(True)
-            steps *= 2
-            if steps > 128:
+            steps += 1
+            while (q.shape[1] % steps) != 0 and steps < max_steps:
+                steps += 1
+            if steps > max_steps:
                 raise e
             print("out of memory error, increasing steps and trying again", steps)
 
 comfy/sd.py | 30
@@ -217,9 +217,20 @@ class VAE:
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            pixel_samples = self.decode_tiled_(samples_in)
+            tile_size = 64
+            while tile_size >= 8:
+                overlap = tile_size // 4
+                print(f"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding with tile size {tile_size} and overlap {overlap}.")
+                try:
+                    pixel_samples = self.decode_tiled_(samples_in, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
+                    break
+                except model_management.OOM_EXCEPTION as e:
+                    pass
+                tile_size -= 8
 
-        self.first_stage_model = self.first_stage_model.to(self.offload_device)
+            if pixel_samples is None:
+                raise e
+        finally:
+            self.first_stage_model = self.first_stage_model.to(self.offload_device)
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
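Previously VAE.decode fell back to a single tiled attempt at the default tile size; the hunk above retries with progressively smaller tiles, from 64 down to 8 latent units in steps of 8, each with an overlap of a quarter tile, and re-raises the original out-of-memory error only if every size fails. The finally block moves the VAE weights back to the offload device however the method exits. A generic sketch of the fallback loop, assuming a decode_fn(samples, tile_x, tile_y, overlap) callable (the name is illustrative); it tracks the last error explicitly rather than relying on the exception variable surviving the loop:

    def decode_with_shrinking_tiles(decode_fn, samples, start=64, stop=8, step=8):
        # Retry tiled decoding with progressively smaller tiles until one fits.
        last_err = None
        for tile in range(start, stop - 1, -step):
            try:
                return decode_fn(samples, tile_x=tile, tile_y=tile, overlap=tile // 4)
            except RuntimeError as e:  # stand-in for model_management.OOM_EXCEPTION
                last_err = e
        raise last_err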
@@ -245,9 +256,20 @@ class VAE:
                 samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
 
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            samples = self.encode_tiled_(pixel_samples)
+            tile_size = 512
+            while tile_size >= 64:
+                overlap = tile_size // 8
+                print(f"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding with tile size {tile_size} and overlap {overlap}.")
+                try:
+                    samples = self.encode_tiled_(pixel_samples, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
+                    break
+                except model_management.OOM_EXCEPTION as e:
+                    pass
+                tile_size -= 64
 
-        self.first_stage_model = self.first_stage_model.to(self.offload_device)
+            if samples is None:
+                raise e
+        finally:
+            self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return samples
 
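The encode path mirrors the decode fallback but counts in pixels rather than latents: tiles shrink from 512 to 64 pixels in steps of 64, with an overlap of one eighth of the tile. Assuming the usual 8x spatial compression of the Stable Diffusion VAE (an assumption, not stated in the diff), each pixel tile in the encode schedule covers the same footprint as the corresponding latent tile in the decode schedule, which is presumably why both schedules have the same number of steps. A quick check of that correspondence:

    VAE_SCALE = 8  # assumed SD-style VAE downscale factor

    decode_tiles = [(t, t // 4) for t in range(64, 7, -8)]      # (tile, overlap) in latent units
    encode_tiles = [(t, t // 8) for t in range(512, 63, -64)]   # (tile, overlap) in pixel units

    for (px, _), (lat, _) in zip(encode_tiles, decode_tiles):
        assert px == lat * VAE_SCALE  # each encode tile matches a decode tile in latent space
    print(decode_tiles[0], encode_tiles[0])  # (64, 16) (512, 64)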
@@ -6,6 +6,7 @@ from comfy.cli_args import args, LatentPreviewMethod
 from comfy.taesd.taesd import TAESD
 import folder_paths
 import comfy.utils
+from comfy import model_management
 
 MAX_PREVIEW_RESOLUTION = 512
 
@@ -18,11 +19,12 @@ class LatentPreviewer:
         return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
 
 class TAESDPreviewerImpl(LatentPreviewer):
-    def __init__(self, taesd):
+    def __init__(self, taesd, device):
         self.taesd = taesd
+        self.device = device
 
     def decode_latent_to_preview(self, x0):
-        x_sample = self.taesd.decoder(x0[:1])[0].detach()
+        x_sample = self.taesd.decoder(x0[:1].to(self.device))[0].detach()
         # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
         x_sample = x_sample.sub(0.5).mul(2)
 
@@ -52,6 +54,8 @@ class Latent2RGBPreviewer(LatentPreviewer):
 def get_previewer(device, latent_format):
     previewer = None
     method = args.preview_method
+    if args.preview_cpu:
+        device = torch.device("cpu")
     if method != LatentPreviewMethod.NoPreviews:
         # TODO previewer methods
         taesd_decoder_path = None
@@ -71,7 +75,7 @@ def get_previewer(device, latent_format):
         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:
                 taesd = TAESD(None, taesd_decoder_path).to(device)
-                previewer = TAESDPreviewerImpl(taesd)
+                previewer = TAESDPreviewerImpl(taesd, device)
             else:
                 print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
 
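Together with the new --preview-cpu flag, these previewer changes let TAESD previews run on a device of their own: get_previewer swaps the device for the CPU when the flag is set, and TAESDPreviewerImpl now remembers that device and moves the latent onto it before decoding, so a preview no longer has to compete with sampling for VRAM. A self-contained mock of the same pattern (DummyTAESDDecoder and CPUPreviewer are stand-ins, not ComfyUI classes):

    import torch

    class DummyTAESDDecoder(torch.nn.Module):
        # Stand-in for the TAESD decoder: latent (B, 4, h, w) -> image (B, 3, 8h, 8w)
        def forward(self, x):
            return torch.rand(x.shape[0], 3, x.shape[2] * 8, x.shape[3] * 8)

    class CPUPreviewer:
        def __init__(self, decoder, device):
            self.decoder = decoder.to(device)
            self.device = device

        def decode_latent_to_preview(self, x0):
            # Move only the first latent of the batch to the preview device, as above.
            return self.decoder(x0[:1].to(self.device))[0].detach()

    previewer = CPUPreviewer(DummyTAESDDecoder(), torch.device("cpu"))
    preview = previewer.decode_latent_to_preview(torch.randn(2, 4, 64, 64))
    print(preview.shape)  # torch.Size([3, 512, 512])

The final hunk below completes the picture by swallowing out-of-memory errors during preview generation, so a failed preview simply leaves preview_bytes as None instead of aborting the sampling step.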
@@ -94,7 +98,10 @@ def prepare_callback(model, steps, x0_output_dict=None):
 
         preview_bytes = None
         if previewer:
-            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+            try:
+                preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+            except model_management.OOM_EXCEPTION as e:
+                pass
         pbar.update_absolute(step + 1, total_steps, preview_bytes)
     return callback
 