From 4d3d68e4731cf366289f9f4ca11242f4a78956df Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 22 May 2026 12:32:30 -0500 Subject: [PATCH] Add tiled VAE lane to MultiGPU Work Units --- comfy/multigpu.py | 64 ++++++++++- comfy/sd.py | 132 +++++++++++++++++++++- comfy/utils.py | 26 +++-- comfy_extras/nodes_multigpu.py | 12 +- nodes.py | 6 + tests-unit/comfy_test/multigpu_test.py | 147 +++++++++++++++++++++++++ 6 files changed, 366 insertions(+), 21 deletions(-) create mode 100644 tests-unit/comfy_test/multigpu_test.py diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 7f90b7db7..2573185de 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -182,18 +182,23 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): """ full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) limit_extra_devices = full_extra_devices[:max_gpus - 1] - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") - return upscale_model - cloned = copy.copy(upscale_model) existing = getattr(upscale_model, 'multigpu_clones', None) - clones: dict[torch.device, object] = dict(existing) if existing else {} + limit_extra_device_set = set(limit_extra_devices) + clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} + if len(limit_extra_devices) == 0: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") + if hasattr(cloned, 'multigpu_clones'): + del cloned.multigpu_clones + return cloned for device in limit_extra_devices: if device in clones: continue - clone_desc = copy.deepcopy(upscale_model) + clone_source = copy.copy(upscale_model) + if hasattr(clone_source, 'multigpu_clones'): + del clone_source.multigpu_clones + clone_desc = copy.deepcopy(clone_source) clone_desc.model.eval() for p in clone_desc.model.parameters(): p.requires_grad_(False) @@ -205,6 +210,53 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): return cloned +def create_vae_multigpu_deepclones(vae, max_gpus: int): + """Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE + deepclones, one per extra CUDA device up to ``max_gpus``. + """ + vae.throw_exception_if_invalid() + vae_device = torch.device(vae.device) + cloned = copy.copy(vae) + if hasattr(cloned, 'multigpu_clones'): + del cloned.multigpu_clones + if vae_device.type == "cpu": + logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.") + return cloned + + full_extra_devices = comfy.model_management.get_all_torch_devices() + + def is_vae_device(device): + return device.type == vae_device.type and device.index == vae_device.index + + limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1] + if len(limit_extra_devices) == 0: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.") + return cloned + + existing = getattr(vae, 'multigpu_clones', None) + limit_extra_device_set = set(limit_extra_devices) + clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} + + for device in limit_extra_devices: + if device in clones: + continue + cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device) + clone_vae = copy.copy(vae) + if hasattr(clone_vae, 'multigpu_clones'): + del clone_vae.multigpu_clones + clone_vae.first_stage_model = cloned_patcher.model + clone_vae.patcher = cloned_patcher + clone_vae.first_stage_model.eval() + for p in clone_vae.first_stage_model.parameters(): + p.requires_grad_(False) + clone_vae.first_stage_model.to("cpu") + clones[device] = clone_vae + logging.info(f"Created CPU VAE deepclone for {device}") + + cloned.multigpu_clones = clones + return cloned + + LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' diff --git a/comfy/sd.py b/comfy/sd.py index 1670a0486..6401fdb14 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -972,6 +972,26 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: decode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + output = self.process_output( + (comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + + comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + + comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar)) + / 3.0) + return output + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + @@ -981,16 +1001,49 @@ class VAE: def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: + memory_shape = samples.shape decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) else: og_shape = samples.shape + memory_shape = og_shape samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: decode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = clone_decode_fn_factory(c, dev) + return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: decode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -1000,6 +1053,25 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: encode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples /= 3.0 + return samples + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -1009,6 +1081,7 @@ class VAE: def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): if self.latent_dim == 1: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) out_channels = self.latent_channels upscale_amount = 1 / self.downscale_ratio else: @@ -1018,8 +1091,24 @@ class VAE: overlap = overlap // extra_channel_size upscale_amount = 1 / self.downscale_ratio encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) + clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype())) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: encode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = clone_encode_fn_factory(c, dev) + out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + else: + out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) - out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) if self.latent_dim == 1: return out else: @@ -1027,6 +1116,21 @@ class VAE: def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: encode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in, vae_options={}): @@ -1727,8 +1831,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) if out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + if output_vae and out[2] is not None and hasattr(out[2], "patcher"): + out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic)) return out +def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic) + return vae.patcher + def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -1954,6 +2064,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model +def load_vae_patcher(vae_path, metadata=None, device=None): + """Reload a VAE from disk and return its patcher. + + Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that + :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a + fresh VAE patcher with no inherited source-device storage tracking. The + optional device matches the source loader's VAE initialization path; the + cloned patcher's load_device still controls the device targeted by the + multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper + carries dynamic-VRAM allocator state forward to the clone, which causes + per-device worker threads in tiled encode/decode dispatch to access weights + through the source-device buffer.""" + if metadata is None: + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + else: + sd = comfy.utils.load_torch_file(vae_path) + vae = VAE(sd=sd, metadata=metadata, device=device) + vae.throw_exception_if_invalid() + return vae.patcher + def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy/utils.py b/comfy/utils.py index c53e0cb91..6b12676d2 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1263,9 +1263,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, continue positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] - all_positions = list(itertools.product(*positions)) - - split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))} + split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))} out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) @@ -1277,7 +1275,8 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, def worker(device, my_positions): try: - torch.cuda.set_device(device) + if device.type == "cuda": + torch.cuda.set_device(device) fn = functions[device] local_buf = bufs[device] local_div = divs[device] @@ -1306,17 +1305,24 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, o = local_buf o_d = local_div + ps_view = ps + mask_view = mask for d in range(dims): - o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d]) + o = o.narrow(d + 2, upscaled[d], l) + o_d = o_d.narrow(d + 2, upscaled[d], l) + if l < ps_view.shape[d + 2]: + ps_view = ps_view.narrow(d + 2, 0, l) + mask_view = mask_view.narrow(d + 2, 0, l) - o.add_(ps * mask) - o_d.add_(mask) + o.add_(ps_view * mask_view) + o_d.add_(mask_view) if pbar is not None: with pbar_lock: pbar.update(1) - torch.cuda.synchronize(device) + if device.type == "cuda": + torch.cuda.synchronize(device) except BaseException as e: with worker_lock: worker_errors.append(e) @@ -1330,7 +1336,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, raise worker_errors[0] combined_buf = sum(bufs.values()) - combined_div = sum(divs.values()).clamp_(min=1e-12) + combined_div = sum(divs.values()) output[b:b+1] = combined_buf / combined_div return output diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 021dfca3f..dd0f76798 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -13,8 +13,8 @@ import comfy.multigpu class MultiGPUCFGSplitNode(io.ComfyNode): """ - Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream - nodes that recognize the attached state dispatch their work across multiple GPUs. + Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so + downstream nodes that recognize the attached state dispatch their work across multiple GPUs. Place after nodes that modify the model object itself (compile, attention-switch, etc.). Otherwise position is not order-sensitive. @@ -30,21 +30,25 @@ class MultiGPUCFGSplitNode(io.ComfyNode): inputs=[ io.Model.Input("model", optional=True), io.UpscaleModel.Input("upscale_model", optional=True), + io.Vae.Input("vae", optional=True), io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), io.UpscaleModel.Output(), + io.Vae.Output(), ], ) @classmethod - def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput: + def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput: if model is not None: model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) if upscale_model is not None: upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) - return io.NodeOutput(model, upscale_model) + if vae is not None: + vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus) + return io.NodeOutput(model, upscale_model, vae) class MultiGPUOptionsNode(io.ComfyNode): diff --git a/nodes.py b/nodes.py index 2f3856330..9193e9ddb 100644 --- a/nodes.py +++ b/nodes.py @@ -869,6 +869,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name, device="default"): metadata = None + vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -888,6 +889,11 @@ class VAELoader: resolved = comfy.model_management.resolve_gpu_device_option(device) vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae.throw_exception_if_invalid() + # Register a reload factory on the patcher so MultiGPU work-units can use + # ModelPatcher.deepclone_multigpu to produce per-device clones from the + # same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader). + if vae_path is not None: + vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved)) return (vae,) class ControlNetLoader: diff --git a/tests-unit/comfy_test/multigpu_test.py b/tests-unit/comfy_test/multigpu_test.py new file mode 100644 index 000000000..e7ba15df7 --- /dev/null +++ b/tests-unit/comfy_test/multigpu_test.py @@ -0,0 +1,147 @@ +import importlib +import sys +import types + +import torch + +import comfy.utils + + +def install_fake_comfy_aimdo(monkeypatch): + package = types.ModuleType("comfy_aimdo") + package.__path__ = [] + monkeypatch.setitem(sys.modules, "comfy_aimdo", package) + for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"): + module = types.ModuleType(f"comfy_aimdo.{name}") + monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module) + setattr(package, name, module) + + +def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch): + monkeypatch.setattr(torch.cuda, "set_device", lambda device: None) + monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None) + + scale = 1.1 + + def upscale(a): + return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device) + + samples = torch.ones((1, 1, 11)) + devices = [torch.device("cpu:0"), torch.device("cpu:1")] + + actual = comfy.utils.tiled_scale_multidim_multigpu( + samples, + {device: upscale for device in devices}, + tile=(5,), + overlap=2, + upscale_amount=scale, + out_channels=1, + output_device="cpu", + ) + expected = comfy.utils.tiled_scale_multidim( + samples, + upscale, + tile=(5,), + overlap=2, + upscale_amount=scale, + out_channels=1, + output_device="cpu", + ) + + assert actual.shape == expected.shape == (1, 1, 12) + torch.testing.assert_close(actual, expected) + + +def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch): + class FakeModel: + def __init__(self): + self.param = torch.nn.Parameter(torch.ones(1)) + + def eval(self): + return self + + def parameters(self): + return [self.param] + + class FakeDescriptor: + def __init__(self): + self.model = FakeModel() + self.device = None + + def to(self, device): + self.device = device + return self + + first_device = torch.device("cpu:0") + second_device = torch.device("cpu:1") + stale_device = torch.device("cpu:2") + existing_clone = FakeDescriptor() + stale_clone = FakeDescriptor() + source = FakeDescriptor() + source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone} + fake_model_management = types.ModuleType("comfy.model_management") + fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device] + monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management) + import comfy + monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False) + import comfy.multigpu + importlib.reload(comfy.multigpu) + + cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3) + + assert cloned is not source + assert cloned.multigpu_clones[first_device] is existing_clone + assert stale_device not in cloned.multigpu_clones + assert second_device in cloned.multigpu_clones + assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones") + assert cloned.multigpu_clones[second_device].device == "cpu" + assert not cloned.multigpu_clones[second_device].model.param.requires_grad + + single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1) + assert single_gpu_clone is not source + assert not hasattr(single_gpu_clone, "multigpu_clones") + + +def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch): + install_fake_comfy_aimdo(monkeypatch) + import comfy.sd + importlib.reload(comfy.sd) + + class FakeVAE: + def __init__(self): + self.patcher = types.SimpleNamespace(cached_patcher_init=None) + + model_patcher = types.SimpleNamespace(cached_patcher_init=None) + vae = FakeVAE() + metadata = {"format": "checkpoint"} + monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) + monkeypatch.setattr( + comfy.sd, + "load_state_dict_guess_config", + lambda *args, **kwargs: (model_patcher, None, vae, None), + ) + + comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True) + + assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config + assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher + assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors" + + +def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch): + install_fake_comfy_aimdo(monkeypatch) + import comfy.sd + importlib.reload(comfy.sd) + + model_patcher = types.SimpleNamespace(cached_patcher_init=None) + placeholder_vae = types.SimpleNamespace() + metadata = {"format": "checkpoint"} + monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) + monkeypatch.setattr( + comfy.sd, + "load_state_dict_guess_config", + lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None), + ) + + assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae + assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config