mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 16:07:30 +08:00
Add tiled VAE lane to MultiGPU Work Units
This commit is contained in:
parent
74b0a826ea
commit
4d3d68e473
@ -182,18 +182,23 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
|
||||
"""
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus - 1]
|
||||
if len(limit_extra_devices) == 0:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
|
||||
return upscale_model
|
||||
|
||||
cloned = copy.copy(upscale_model)
|
||||
existing = getattr(upscale_model, 'multigpu_clones', None)
|
||||
clones: dict[torch.device, object] = dict(existing) if existing else {}
|
||||
limit_extra_device_set = set(limit_extra_devices)
|
||||
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
|
||||
if len(limit_extra_devices) == 0:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
|
||||
if hasattr(cloned, 'multigpu_clones'):
|
||||
del cloned.multigpu_clones
|
||||
return cloned
|
||||
|
||||
for device in limit_extra_devices:
|
||||
if device in clones:
|
||||
continue
|
||||
clone_desc = copy.deepcopy(upscale_model)
|
||||
clone_source = copy.copy(upscale_model)
|
||||
if hasattr(clone_source, 'multigpu_clones'):
|
||||
del clone_source.multigpu_clones
|
||||
clone_desc = copy.deepcopy(clone_source)
|
||||
clone_desc.model.eval()
|
||||
for p in clone_desc.model.parameters():
|
||||
p.requires_grad_(False)
|
||||
@ -205,6 +210,53 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
|
||||
return cloned
|
||||
|
||||
|
||||
def create_vae_multigpu_deepclones(vae, max_gpus: int):
|
||||
"""Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE
|
||||
deepclones, one per extra CUDA device up to ``max_gpus``.
|
||||
"""
|
||||
vae.throw_exception_if_invalid()
|
||||
vae_device = torch.device(vae.device)
|
||||
cloned = copy.copy(vae)
|
||||
if hasattr(cloned, 'multigpu_clones'):
|
||||
del cloned.multigpu_clones
|
||||
if vae_device.type == "cpu":
|
||||
logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
|
||||
return cloned
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices()
|
||||
|
||||
def is_vae_device(device):
|
||||
return device.type == vae_device.type and device.index == vae_device.index
|
||||
|
||||
limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1]
|
||||
if len(limit_extra_devices) == 0:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
|
||||
return cloned
|
||||
|
||||
existing = getattr(vae, 'multigpu_clones', None)
|
||||
limit_extra_device_set = set(limit_extra_devices)
|
||||
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
|
||||
|
||||
for device in limit_extra_devices:
|
||||
if device in clones:
|
||||
continue
|
||||
cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
|
||||
clone_vae = copy.copy(vae)
|
||||
if hasattr(clone_vae, 'multigpu_clones'):
|
||||
del clone_vae.multigpu_clones
|
||||
clone_vae.first_stage_model = cloned_patcher.model
|
||||
clone_vae.patcher = cloned_patcher
|
||||
clone_vae.first_stage_model.eval()
|
||||
for p in clone_vae.first_stage_model.parameters():
|
||||
p.requires_grad_(False)
|
||||
clone_vae.first_stage_model.to("cpu")
|
||||
clones[device] = clone_vae
|
||||
logging.info(f"Created CPU VAE deepclone for {device}")
|
||||
|
||||
cloned.multigpu_clones = clones
|
||||
return cloned
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
|
||||
132
comfy/sd.py
132
comfy/sd.py
@ -972,6 +972,26 @@ class VAE:
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: decode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
@ -981,16 +1001,49 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
memory_shape = samples.shape
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
memory_shape = og_shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: decode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = clone_decode_fn_factory(c, dev)
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: decode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@ -1000,6 +1053,25 @@ class VAE:
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: encode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -1009,6 +1081,7 @@ class VAE:
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
@ -1018,8 +1091,24 @@ class VAE:
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: encode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = clone_encode_fn_factory(c, dev)
|
||||
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
else:
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
return out
|
||||
else:
|
||||
@ -1027,6 +1116,21 @@ class VAE:
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: encode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
@ -1727,8 +1831,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
|
||||
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
|
||||
return out
|
||||
|
||||
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
_, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic)
|
||||
return vae.patcher
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||
embedding_directory=embedding_directory,
|
||||
@ -1954,6 +2064,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||
return model
|
||||
|
||||
def load_vae_patcher(vae_path, metadata=None, device=None):
|
||||
"""Reload a VAE from disk and return its patcher.
|
||||
|
||||
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that
|
||||
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
|
||||
fresh VAE patcher with no inherited source-device storage tracking. The
|
||||
optional device matches the source loader's VAE initialization path; the
|
||||
cloned patcher's load_device still controls the device targeted by the
|
||||
multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper
|
||||
carries dynamic-VRAM allocator state forward to the clone, which causes
|
||||
per-device worker threads in tiled encode/decode dispatch to access weights
|
||||
through the source-device buffer."""
|
||||
if metadata is None:
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
else:
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = VAE(sd=sd, metadata=metadata, device=device)
|
||||
vae.throw_exception_if_invalid()
|
||||
return vae.patcher
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
@ -1263,9 +1263,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
|
||||
continue
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
all_positions = list(itertools.product(*positions))
|
||||
|
||||
split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))}
|
||||
split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))}
|
||||
|
||||
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
|
||||
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
|
||||
@ -1277,7 +1275,8 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
|
||||
|
||||
def worker(device, my_positions):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(device)
|
||||
fn = functions[device]
|
||||
local_buf = bufs[device]
|
||||
local_div = divs[device]
|
||||
@ -1306,17 +1305,24 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
|
||||
|
||||
o = local_buf
|
||||
o_d = local_div
|
||||
ps_view = ps
|
||||
mask_view = mask
|
||||
for d in range(dims):
|
||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
|
||||
o = o.narrow(d + 2, upscaled[d], l)
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], l)
|
||||
if l < ps_view.shape[d + 2]:
|
||||
ps_view = ps_view.narrow(d + 2, 0, l)
|
||||
mask_view = mask_view.narrow(d + 2, 0, l)
|
||||
|
||||
o.add_(ps * mask)
|
||||
o_d.add_(mask)
|
||||
o.add_(ps_view * mask_view)
|
||||
o_d.add_(mask_view)
|
||||
|
||||
if pbar is not None:
|
||||
with pbar_lock:
|
||||
pbar.update(1)
|
||||
torch.cuda.synchronize(device)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize(device)
|
||||
except BaseException as e:
|
||||
with worker_lock:
|
||||
worker_errors.append(e)
|
||||
@ -1330,7 +1336,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
|
||||
raise worker_errors[0]
|
||||
|
||||
combined_buf = sum(bufs.values())
|
||||
combined_div = sum(divs.values()).clamp_(min=1e-12)
|
||||
combined_div = sum(divs.values())
|
||||
output[b:b+1] = combined_buf / combined_div
|
||||
|
||||
return output
|
||||
|
||||
@ -13,8 +13,8 @@ import comfy.multigpu
|
||||
|
||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
"""
|
||||
Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream
|
||||
nodes that recognize the attached state dispatch their work across multiple GPUs.
|
||||
Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so
|
||||
downstream nodes that recognize the attached state dispatch their work across multiple GPUs.
|
||||
|
||||
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
|
||||
Otherwise position is not order-sensitive.
|
||||
@ -30,21 +30,25 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Model.Input("model", optional=True),
|
||||
io.UpscaleModel.Input("upscale_model", optional=True),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.UpscaleModel.Output(),
|
||||
io.Vae.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput:
|
||||
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput:
|
||||
if model is not None:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
if upscale_model is not None:
|
||||
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
|
||||
return io.NodeOutput(model, upscale_model)
|
||||
if vae is not None:
|
||||
vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus)
|
||||
return io.NodeOutput(model, upscale_model, vae)
|
||||
|
||||
|
||||
class MultiGPUOptionsNode(io.ComfyNode):
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -869,6 +869,7 @@ class VAELoader:
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name, device="default"):
|
||||
metadata = None
|
||||
vae_path = None
|
||||
if vae_name == "pixel_space":
|
||||
sd = {}
|
||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||
@ -888,6 +889,11 @@ class VAELoader:
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
|
||||
vae.throw_exception_if_invalid()
|
||||
# Register a reload factory on the patcher so MultiGPU work-units can use
|
||||
# ModelPatcher.deepclone_multigpu to produce per-device clones from the
|
||||
# same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader).
|
||||
if vae_path is not None:
|
||||
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved))
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
|
||||
147
tests-unit/comfy_test/multigpu_test.py
Normal file
147
tests-unit/comfy_test/multigpu_test.py
Normal file
@ -0,0 +1,147 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def install_fake_comfy_aimdo(monkeypatch):
|
||||
package = types.ModuleType("comfy_aimdo")
|
||||
package.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "comfy_aimdo", package)
|
||||
for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"):
|
||||
module = types.ModuleType(f"comfy_aimdo.{name}")
|
||||
monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module)
|
||||
setattr(package, name, module)
|
||||
|
||||
|
||||
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
|
||||
monkeypatch.setattr(torch.cuda, "set_device", lambda device: None)
|
||||
monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None)
|
||||
|
||||
scale = 1.1
|
||||
|
||||
def upscale(a):
|
||||
return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device)
|
||||
|
||||
samples = torch.ones((1, 1, 11))
|
||||
devices = [torch.device("cpu:0"), torch.device("cpu:1")]
|
||||
|
||||
actual = comfy.utils.tiled_scale_multidim_multigpu(
|
||||
samples,
|
||||
{device: upscale for device in devices},
|
||||
tile=(5,),
|
||||
overlap=2,
|
||||
upscale_amount=scale,
|
||||
out_channels=1,
|
||||
output_device="cpu",
|
||||
)
|
||||
expected = comfy.utils.tiled_scale_multidim(
|
||||
samples,
|
||||
upscale,
|
||||
tile=(5,),
|
||||
overlap=2,
|
||||
upscale_amount=scale,
|
||||
out_channels=1,
|
||||
output_device="cpu",
|
||||
)
|
||||
|
||||
assert actual.shape == expected.shape == (1, 1, 12)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
|
||||
class FakeModel:
|
||||
def __init__(self):
|
||||
self.param = torch.nn.Parameter(torch.ones(1))
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def parameters(self):
|
||||
return [self.param]
|
||||
|
||||
class FakeDescriptor:
|
||||
def __init__(self):
|
||||
self.model = FakeModel()
|
||||
self.device = None
|
||||
|
||||
def to(self, device):
|
||||
self.device = device
|
||||
return self
|
||||
|
||||
first_device = torch.device("cpu:0")
|
||||
second_device = torch.device("cpu:1")
|
||||
stale_device = torch.device("cpu:2")
|
||||
existing_clone = FakeDescriptor()
|
||||
stale_clone = FakeDescriptor()
|
||||
source = FakeDescriptor()
|
||||
source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}
|
||||
fake_model_management = types.ModuleType("comfy.model_management")
|
||||
fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device]
|
||||
monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)
|
||||
import comfy
|
||||
monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)
|
||||
import comfy.multigpu
|
||||
importlib.reload(comfy.multigpu)
|
||||
|
||||
cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3)
|
||||
|
||||
assert cloned is not source
|
||||
assert cloned.multigpu_clones[first_device] is existing_clone
|
||||
assert stale_device not in cloned.multigpu_clones
|
||||
assert second_device in cloned.multigpu_clones
|
||||
assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones")
|
||||
assert cloned.multigpu_clones[second_device].device == "cpu"
|
||||
assert not cloned.multigpu_clones[second_device].model.param.requires_grad
|
||||
|
||||
single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1)
|
||||
assert single_gpu_clone is not source
|
||||
assert not hasattr(single_gpu_clone, "multigpu_clones")
|
||||
|
||||
|
||||
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
|
||||
install_fake_comfy_aimdo(monkeypatch)
|
||||
import comfy.sd
|
||||
importlib.reload(comfy.sd)
|
||||
|
||||
class FakeVAE:
|
||||
def __init__(self):
|
||||
self.patcher = types.SimpleNamespace(cached_patcher_init=None)
|
||||
|
||||
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
|
||||
vae = FakeVAE()
|
||||
metadata = {"format": "checkpoint"}
|
||||
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
|
||||
monkeypatch.setattr(
|
||||
comfy.sd,
|
||||
"load_state_dict_guess_config",
|
||||
lambda *args, **kwargs: (model_patcher, None, vae, None),
|
||||
)
|
||||
|
||||
comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True)
|
||||
|
||||
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
|
||||
assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher
|
||||
assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors"
|
||||
|
||||
|
||||
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
|
||||
install_fake_comfy_aimdo(monkeypatch)
|
||||
import comfy.sd
|
||||
importlib.reload(comfy.sd)
|
||||
|
||||
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
|
||||
placeholder_vae = types.SimpleNamespace()
|
||||
metadata = {"format": "checkpoint"}
|
||||
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
|
||||
monkeypatch.setattr(
|
||||
comfy.sd,
|
||||
"load_state_dict_guess_config",
|
||||
lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None),
|
||||
)
|
||||
|
||||
assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae
|
||||
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
|
||||
Loading…
Reference in New Issue
Block a user