Add tiled VAE lane to MultiGPU Work Units

This commit is contained in:
John Pollock 2026-05-22 12:32:30 -05:00 committed by Jedrzej Kosinski
parent 425b8edc0b
commit 7e2bcb47a8
6 changed files with 366 additions and 21 deletions

View File

@ -182,18 +182,23 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
""" """
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus - 1] limit_extra_devices = full_extra_devices[:max_gpus - 1]
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
return upscale_model
cloned = copy.copy(upscale_model) cloned = copy.copy(upscale_model)
existing = getattr(upscale_model, 'multigpu_clones', None) existing = getattr(upscale_model, 'multigpu_clones', None)
clones: dict[torch.device, object] = dict(existing) if existing else {} limit_extra_device_set = set(limit_extra_devices)
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
if hasattr(cloned, 'multigpu_clones'):
del cloned.multigpu_clones
return cloned
for device in limit_extra_devices: for device in limit_extra_devices:
if device in clones: if device in clones:
continue continue
clone_desc = copy.deepcopy(upscale_model) clone_source = copy.copy(upscale_model)
if hasattr(clone_source, 'multigpu_clones'):
del clone_source.multigpu_clones
clone_desc = copy.deepcopy(clone_source)
clone_desc.model.eval() clone_desc.model.eval()
for p in clone_desc.model.parameters(): for p in clone_desc.model.parameters():
p.requires_grad_(False) p.requires_grad_(False)
@ -205,6 +210,53 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
return cloned return cloned
def create_vae_multigpu_deepclones(vae, max_gpus: int):
"""Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE
deepclones, one per extra CUDA device up to ``max_gpus``.
"""
vae.throw_exception_if_invalid()
vae_device = torch.device(vae.device)
cloned = copy.copy(vae)
if hasattr(cloned, 'multigpu_clones'):
del cloned.multigpu_clones
if vae_device.type == "cpu":
logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
return cloned
full_extra_devices = comfy.model_management.get_all_torch_devices()
def is_vae_device(device):
return device.type == vae_device.type and device.index == vae_device.index
limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1]
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
return cloned
existing = getattr(vae, 'multigpu_clones', None)
limit_extra_device_set = set(limit_extra_devices)
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
for device in limit_extra_devices:
if device in clones:
continue
cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
clone_vae = copy.copy(vae)
if hasattr(clone_vae, 'multigpu_clones'):
del clone_vae.multigpu_clones
clone_vae.first_stage_model = cloned_patcher.model
clone_vae.patcher = cloned_patcher
clone_vae.first_stage_model.eval()
for p in clone_vae.first_stage_model.parameters():
p.requires_grad_(False)
clone_vae.first_stage_model.to("cpu")
clones[device] = clone_vae
logging.info(f"Created CPU VAE deepclone for {device}")
cloned.multigpu_clones = clones
return cloned
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'

View File

@ -972,6 +972,26 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
/ 3.0)
return output
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
output = self.process_output( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -981,16 +1001,49 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32): def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3: if samples.ndim == 3:
memory_shape = samples.shape
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
else: else:
og_shape = samples.shape og_shape = samples.shape
memory_shape = og_shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_decode_fn_factory(c, dev)
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -1000,6 +1053,25 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -1009,6 +1081,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1: if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
out_channels = self.latent_channels out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
else: else:
@ -1018,8 +1091,24 @@ class VAE:
overlap = overlap // extra_channel_size overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_encode_fn_factory(c, dev)
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
else:
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1: if self.latent_dim == 1:
return out return out
else: else:
@ -1027,6 +1116,21 @@ class VAE:
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
@ -1727,8 +1831,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if out[0] is not None: if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
return out return out
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
_, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic)
return vae.patcher
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory, embedding_directory=embedding_directory,
@ -1954,6 +2064,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model return model
def load_vae_patcher(vae_path, metadata=None, device=None):
"""Reload a VAE from disk and return its patcher.
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
fresh VAE patcher with no inherited source-device storage tracking. The
optional device matches the source loader's VAE initialization path; the
cloned patcher's load_device still controls the device targeted by the
multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper
carries dynamic-VRAM allocator state forward to the clone, which causes
per-device worker threads in tiled encode/decode dispatch to access weights
through the source-device buffer."""
if metadata is None:
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
else:
sd = comfy.utils.load_torch_file(vae_path)
vae = VAE(sd=sd, metadata=metadata, device=device)
vae.throw_exception_if_invalid()
return vae.patcher
def load_unet(unet_path, dtype=None): def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype}) return load_diffusion_model(unet_path, model_options={"dtype": dtype})

View File

@ -1264,9 +1264,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
continue continue
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
all_positions = list(itertools.product(*positions)) split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))}
split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))}
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
@ -1278,7 +1276,8 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
def worker(device, my_positions): def worker(device, my_positions):
try: try:
torch.cuda.set_device(device) if device.type == "cuda":
torch.cuda.set_device(device)
fn = functions[device] fn = functions[device]
local_buf = bufs[device] local_buf = bufs[device]
local_div = divs[device] local_div = divs[device]
@ -1307,17 +1306,24 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
o = local_buf o = local_buf
o_d = local_div o_d = local_div
ps_view = ps
mask_view = mask
for d in range(dims): for d in range(dims):
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o = o.narrow(d + 2, upscaled[d], l)
o_d = o_d.narrow(d + 2, upscaled[d], l)
if l < ps_view.shape[d + 2]:
ps_view = ps_view.narrow(d + 2, 0, l)
mask_view = mask_view.narrow(d + 2, 0, l)
o.add_(ps * mask) o.add_(ps_view * mask_view)
o_d.add_(mask) o_d.add_(mask_view)
if pbar is not None: if pbar is not None:
with pbar_lock: with pbar_lock:
pbar.update(1) pbar.update(1)
torch.cuda.synchronize(device) if device.type == "cuda":
torch.cuda.synchronize(device)
except BaseException as e: except BaseException as e:
with worker_lock: with worker_lock:
worker_errors.append(e) worker_errors.append(e)
@ -1331,7 +1337,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
raise worker_errors[0] raise worker_errors[0]
combined_buf = sum(bufs.values()) combined_buf = sum(bufs.values())
combined_div = sum(divs.values()).clamp_(min=1e-12) combined_div = sum(divs.values())
output[b:b+1] = combined_buf / combined_div output[b:b+1] = combined_buf / combined_div
return output return output

View File

@ -13,8 +13,8 @@ import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode): class MultiGPUCFGSplitNode(io.ComfyNode):
""" """
Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so
nodes that recognize the attached state dispatch their work across multiple GPUs. downstream nodes that recognize the attached state dispatch their work across multiple GPUs.
Place after nodes that modify the model object itself (compile, attention-switch, etc.). Place after nodes that modify the model object itself (compile, attention-switch, etc.).
Otherwise position is not order-sensitive. Otherwise position is not order-sensitive.
@ -30,21 +30,25 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
inputs=[ inputs=[
io.Model.Input("model", optional=True), io.Model.Input("model", optional=True),
io.UpscaleModel.Input("upscale_model", optional=True), io.UpscaleModel.Input("upscale_model", optional=True),
io.Vae.Input("vae", optional=True),
io.Int.Input("max_gpus", default=2, min=1, step=1), io.Int.Input("max_gpus", default=2, min=1, step=1),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),
io.UpscaleModel.Output(), io.UpscaleModel.Output(),
io.Vae.Output(),
], ],
) )
@classmethod @classmethod
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput: def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput:
if model is not None: if model is not None:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
if upscale_model is not None: if upscale_model is not None:
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
return io.NodeOutput(model, upscale_model) if vae is not None:
vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus)
return io.NodeOutput(model, upscale_model, vae)
class MultiGPUOptionsNode(io.ComfyNode): class MultiGPUOptionsNode(io.ComfyNode):

View File

@ -869,6 +869,7 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name, device="default"): def load_vae(self, vae_name, device="default"):
metadata = None metadata = None
vae_path = None
if vae_name == "pixel_space": if vae_name == "pixel_space":
sd = {} sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0) sd["pixel_space_vae"] = torch.tensor(1.0)
@ -888,6 +889,11 @@ class VAELoader:
resolved = comfy.model_management.resolve_gpu_device_option(device) resolved = comfy.model_management.resolve_gpu_device_option(device)
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
vae.throw_exception_if_invalid() vae.throw_exception_if_invalid()
# Register a reload factory on the patcher so MultiGPU work-units can use
# ModelPatcher.deepclone_multigpu to produce per-device clones from the
# same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader).
if vae_path is not None:
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved))
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:

View File

@ -0,0 +1,147 @@
import importlib
import sys
import types
import torch
import comfy.utils
def install_fake_comfy_aimdo(monkeypatch):
package = types.ModuleType("comfy_aimdo")
package.__path__ = []
monkeypatch.setitem(sys.modules, "comfy_aimdo", package)
for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"):
module = types.ModuleType(f"comfy_aimdo.{name}")
monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module)
setattr(package, name, module)
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
monkeypatch.setattr(torch.cuda, "set_device", lambda device: None)
monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None)
scale = 1.1
def upscale(a):
return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device)
samples = torch.ones((1, 1, 11))
devices = [torch.device("cpu:0"), torch.device("cpu:1")]
actual = comfy.utils.tiled_scale_multidim_multigpu(
samples,
{device: upscale for device in devices},
tile=(5,),
overlap=2,
upscale_amount=scale,
out_channels=1,
output_device="cpu",
)
expected = comfy.utils.tiled_scale_multidim(
samples,
upscale,
tile=(5,),
overlap=2,
upscale_amount=scale,
out_channels=1,
output_device="cpu",
)
assert actual.shape == expected.shape == (1, 1, 12)
torch.testing.assert_close(actual, expected)
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
class FakeModel:
def __init__(self):
self.param = torch.nn.Parameter(torch.ones(1))
def eval(self):
return self
def parameters(self):
return [self.param]
class FakeDescriptor:
def __init__(self):
self.model = FakeModel()
self.device = None
def to(self, device):
self.device = device
return self
first_device = torch.device("cpu:0")
second_device = torch.device("cpu:1")
stale_device = torch.device("cpu:2")
existing_clone = FakeDescriptor()
stale_clone = FakeDescriptor()
source = FakeDescriptor()
source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}
fake_model_management = types.ModuleType("comfy.model_management")
fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device]
monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)
import comfy
monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)
import comfy.multigpu
importlib.reload(comfy.multigpu)
cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3)
assert cloned is not source
assert cloned.multigpu_clones[first_device] is existing_clone
assert stale_device not in cloned.multigpu_clones
assert second_device in cloned.multigpu_clones
assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones")
assert cloned.multigpu_clones[second_device].device == "cpu"
assert not cloned.multigpu_clones[second_device].model.param.requires_grad
single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1)
assert single_gpu_clone is not source
assert not hasattr(single_gpu_clone, "multigpu_clones")
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
install_fake_comfy_aimdo(monkeypatch)
import comfy.sd
importlib.reload(comfy.sd)
class FakeVAE:
def __init__(self):
self.patcher = types.SimpleNamespace(cached_patcher_init=None)
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
vae = FakeVAE()
metadata = {"format": "checkpoint"}
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
monkeypatch.setattr(
comfy.sd,
"load_state_dict_guess_config",
lambda *args, **kwargs: (model_patcher, None, vae, None),
)
comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True)
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher
assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors"
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
install_fake_comfy_aimdo(monkeypatch)
import comfy.sd
importlib.reload(comfy.sd)
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
placeholder_vae = types.SimpleNamespace()
metadata = {"format": "checkpoint"}
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
monkeypatch.setattr(
comfy.sd,
"load_state_dict_guess_config",
lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None),
)
assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config