mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 07:57:29 +08:00
Defer @pollockjj's tiled-VAE and UPSCALE_MODEL MultiGPU lanes (#14066)
* Revert "Add tiled VAE lane to MultiGPU Work Units" This reverts commit 4d3d68e473. The tiled VAE lane will land as part of a follow-up PR alongside the UPSCALE_MODEL lane, separated from the threaded-loader fix PR (#14052) to keep the upstream merge focused. * Revert "Add UPSCALE_MODEL lane to MultiGPU CFG Split" This reverts commit 74b0a826ea. The UPSCALE_MODEL lane will land as part of a follow-up PR alongside the tiled VAE lane, separated from the threaded-loader fix PR (#14052) to keep the upstream merge focused. --------- Co-authored-by: John Pollock <pollockjj@gmail.com>
This commit is contained in:
parent
cb83c41db7
commit
5dc4e38b89
@ -1,5 +1,4 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import copy
|
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import torch
|
import torch
|
||||||
@ -176,87 +175,6 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
    """Return a shallow copy of ``upscale_model`` carrying per-device CPU deepclones.

    The returned copy gets a ``multigpu_clones`` dict that maps each selected
    extra torch device to a deep copy of the descriptor. Every new clone is put
    in eval mode, has gradients disabled on its parameters and is moved to the
    CPU. Clones already attached to ``upscale_model`` for a device that is still
    selected are reused instead of being rebuilt.

    Args:
        upscale_model: Descriptor to clone. It must expose ``model`` (with
            ``eval()`` and ``parameters()``) and ``to()``.
        max_gpus: Total number of devices to use, counting the current one, so
            at most ``max_gpus - 1`` extra devices are selected. Values below 1
            select no extra devices.

    Returns:
        A shallow copy of ``upscale_model``. When no extra device is selected
        the copy has no ``multigpu_clones`` attribute.
    """
    full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
    # Clamp the slice end: without it, max_gpus <= 0 yields a negative end index
    # and the slice would keep all-but-the-last extra devices instead of none.
    limit_extra_devices = full_extra_devices[:max(0, max_gpus - 1)]
    cloned = copy.copy(upscale_model)

    if not limit_extra_devices:
        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
        # The shallow copy may have inherited stale clones from the source.
        if hasattr(cloned, 'multigpu_clones'):
            del cloned.multigpu_clones
        return cloned

    # Keep previously built clones only for devices that are still selected.
    existing = getattr(upscale_model, 'multigpu_clones', None)
    limit_extra_device_set = set(limit_extra_devices)
    clones: dict[torch.device, object] = {}
    if existing:
        clones = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set}

    for device in limit_extra_devices:
        if device in clones:
            continue
        # Drop the clone dict from a shallow copy first so the deepcopy below
        # does not also duplicate the clones that already exist.
        clone_source = copy.copy(upscale_model)
        if hasattr(clone_source, 'multigpu_clones'):
            del clone_source.multigpu_clones
        clone_desc = copy.deepcopy(clone_source)
        clone_desc.model.eval()
        for p in clone_desc.model.parameters():
            p.requires_grad_(False)
        clone_desc.to("cpu")
        clones[device] = clone_desc
        logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")

    cloned.multigpu_clones = clones
    return cloned
|
|
||||||
|
|
||||||
|
|
||||||
def create_vae_multigpu_deepclones(vae, max_gpus: int):
    """Return a shallow copy of ``vae`` carrying per-device CPU VAE deepclones.

    The returned copy gets a ``multigpu_clones`` dict that maps each selected
    extra torch device to a VAE clone. Each clone's patcher comes from
    ``vae.patcher.deepclone_multigpu(new_load_device=device)``; its
    ``first_stage_model`` is put in eval mode, has gradients disabled and is
    moved to the CPU. Clones already attached to ``vae`` for a device that is
    still selected are reused instead of being rebuilt.

    Args:
        vae: Source VAE. ``throw_exception_if_invalid()`` is called on it first.
        max_gpus: Total number of devices to use, counting the VAE's own device,
            so at most ``max_gpus - 1`` extra devices are selected. Values below
            1 select no extra devices.

    Returns:
        A shallow copy of ``vae``. When the VAE lives on the CPU or no extra
        device is selected, the copy has no ``multigpu_clones`` attribute.
    """
    vae.throw_exception_if_invalid()
    vae_device = torch.device(vae.device)
    cloned = copy.copy(vae)
    # Start from a clean slate; the shallow copy may have inherited stale clones.
    if hasattr(cloned, 'multigpu_clones'):
        del cloned.multigpu_clones
    if vae_device.type == "cpu":
        logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
        return cloned

    full_extra_devices = comfy.model_management.get_all_torch_devices()

    def is_vae_device(device):
        return device.type == vae_device.type and device.index == vae_device.index

    candidate_devices = [d for d in full_extra_devices if not is_vae_device(d)]
    # Clamp the slice end: without it, max_gpus <= 0 yields a negative end index
    # and the slice would keep all-but-the-last candidate instead of none.
    limit_extra_devices = candidate_devices[:max(0, max_gpus - 1)]
    if not limit_extra_devices:
        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
        return cloned

    # Keep previously built clones only for devices that are still selected.
    existing = getattr(vae, 'multigpu_clones', None)
    limit_extra_device_set = set(limit_extra_devices)
    clones: dict[torch.device, object] = {}
    if existing:
        clones = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set}

    for device in limit_extra_devices:
        if device in clones:
            continue
        cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
        clone_vae = copy.copy(vae)
        if hasattr(clone_vae, 'multigpu_clones'):
            del clone_vae.multigpu_clones
        # Point the shallow VAE copy at the freshly cloned patcher and its model.
        clone_vae.first_stage_model = cloned_patcher.model
        clone_vae.patcher = cloned_patcher
        clone_vae.first_stage_model.eval()
        for p in clone_vae.first_stage_model.parameters():
            p.requires_grad_(False)
        clone_vae.first_stage_model.to("cpu")
        clones[device] = clone_vae
        logging.info(f"Created CPU VAE deepclone for {device}")

    cloned.multigpu_clones = clones
    return cloned
|
|
||||||
|
|
||||||
|
|
||||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||||
|
|||||||
132
comfy/sd.py
132
comfy/sd.py
@ -972,26 +972,6 @@ class VAE:
|
|||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
functions = {self.device: decode_fn}
|
|
||||||
try:
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
|
|
||||||
c.first_stage_model.to(dev)
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
|
||||||
output = self.process_output(
|
|
||||||
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
|
||||||
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
|
||||||
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
|
|
||||||
/ 3.0)
|
|
||||||
return output
|
|
||||||
finally:
|
|
||||||
for c in multigpu_clones.values():
|
|
||||||
c.first_stage_model.to("cpu")
|
|
||||||
|
|
||||||
output = self.process_output(
|
output = self.process_output(
|
||||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||||
@ -1001,49 +981,16 @@ class VAE:
|
|||||||
|
|
||||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||||
if samples.ndim == 3:
|
if samples.ndim == 3:
|
||||||
memory_shape = samples.shape
|
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
|
||||||
else:
|
else:
|
||||||
og_shape = samples.shape
|
og_shape = samples.shape
|
||||||
memory_shape = og_shape
|
|
||||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
|
||||||
|
|
||||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
functions = {self.device: decode_fn}
|
|
||||||
try:
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
|
|
||||||
c.first_stage_model.to(dev)
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
functions[dev] = clone_decode_fn_factory(c, dev)
|
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
|
||||||
finally:
|
|
||||||
for c in multigpu_clones.values():
|
|
||||||
c.first_stage_model.to("cpu")
|
|
||||||
|
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||||
|
|
||||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
functions = {self.device: decode_fn}
|
|
||||||
try:
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
|
|
||||||
c.first_stage_model.to(dev)
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
|
||||||
finally:
|
|
||||||
for c in multigpu_clones.values():
|
|
||||||
c.first_stage_model.to("cpu")
|
|
||||||
|
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
@ -1053,25 +1000,6 @@ class VAE:
|
|||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
functions = {self.device: encode_fn}
|
|
||||||
try:
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
|
|
||||||
c.first_stage_model.to(dev)
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
|
||||||
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
|
||||||
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
|
||||||
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
|
||||||
samples /= 3.0
|
|
||||||
return samples
|
|
||||||
finally:
|
|
||||||
for c in multigpu_clones.values():
|
|
||||||
c.first_stage_model.to("cpu")
|
|
||||||
|
|
||||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
@ -1081,7 +1009,6 @@ class VAE:
|
|||||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||||
if self.latent_dim == 1:
|
if self.latent_dim == 1:
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
|
||||||
out_channels = self.latent_channels
|
out_channels = self.latent_channels
|
||||||
upscale_amount = 1 / self.downscale_ratio
|
upscale_amount = 1 / self.downscale_ratio
|
||||||
else:
|
else:
|
||||||
@ -1091,24 +1018,8 @@ class VAE:
|
|||||||
overlap = overlap // extra_channel_size
|
overlap = overlap // extra_channel_size
|
||||||
upscale_amount = 1 / self.downscale_ratio
|
upscale_amount = 1 / self.downscale_ratio
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||||
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))
|
|
||||||
|
|
||||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
functions = {self.device: encode_fn}
|
|
||||||
try:
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
|
|
||||||
c.first_stage_model.to(dev)
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
functions[dev] = clone_encode_fn_factory(c, dev)
|
|
||||||
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
|
||||||
finally:
|
|
||||||
for c in multigpu_clones.values():
|
|
||||||
c.first_stage_model.to("cpu")
|
|
||||||
else:
|
|
||||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
|
||||||
|
|
||||||
|
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||||
if self.latent_dim == 1:
|
if self.latent_dim == 1:
|
||||||
return out
|
return out
|
||||||
else:
|
else:
|
||||||
@ -1116,21 +1027,6 @@ class VAE:
|
|||||||
|
|
||||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
functions = {self.device: encode_fn}
|
|
||||||
try:
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
|
|
||||||
c.first_stage_model.to(dev)
|
|
||||||
for dev, c in multigpu_clones.items():
|
|
||||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
|
||||||
return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
|
||||||
finally:
|
|
||||||
for c in multigpu_clones.values():
|
|
||||||
c.first_stage_model.to("cpu")
|
|
||||||
|
|
||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in, vae_options={}):
|
def decode(self, samples_in, vae_options={}):
|
||||||
@ -1831,14 +1727,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
if out[0] is not None:
|
if out[0] is not None:
|
||||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||||
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
|
|
||||||
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload a checkpoint with only the VAE output enabled and return that VAE's patcher.

    All arguments are forwarded unchanged to ``load_checkpoint_guess_config``.
    """
    # Ask for the VAE alone: model, CLIP and CLIP-vision outputs are switched off.
    loader_kwargs = dict(
        output_vae=True,
        output_clip=False,
        output_clipvision=False,
        embedding_directory=embedding_directory,
        output_model=False,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
    )
    _, _, vae, _ = load_checkpoint_guess_config(ckpt_path, **loader_kwargs)
    return vae.patcher
|
|
||||||
|
|
||||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||||
embedding_directory=embedding_directory,
|
embedding_directory=embedding_directory,
|
||||||
@ -2064,26 +1954,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
|||||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_vae_patcher(vae_path, metadata=None, device=None):
    """Load a VAE from ``vae_path`` and hand back its patcher.

    Intended as the ``cached_patcher_init`` factory on ``VAE.patcher``, so that
    :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can build a
    fresh VAE patcher that inherits no source-device storage tracking. A bare
    ``copy.deepcopy`` of the VAE wrapper would instead carry dynamic-VRAM
    allocator state into the clone, making per-device worker threads in tiled
    encode/decode dispatch read weights through the source-device buffer.

    ``metadata`` is read from the file only when the caller does not supply it.
    ``device`` mirrors the source loader's VAE initialization path; the cloned
    patcher's load_device still decides which device the multigpu clone targets.
    """
    if metadata is not None:
        # Caller already has the metadata; only the state dict is needed.
        state_dict = comfy.utils.load_torch_file(vae_path)
    else:
        state_dict, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)

    reloaded = VAE(sd=state_dict, metadata=metadata, device=device)
    reloaded.throw_exception_if_invalid()
    return reloaded.patcher
|
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
def load_unet(unet_path, dtype=None):
|
||||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||||
|
|||||||
157
comfy/utils.py
157
comfy/utils.py
@ -28,13 +28,13 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
import itertools
|
import itertools
|
||||||
import threading
|
|
||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
@ -1187,161 +1187,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||||
|
|
||||||
|
|
||||||
def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
|
|
||||||
"""Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable].
|
|
||||||
|
|
||||||
Round-robin dispatches tile positions across devices via threading. Each thread maintains
|
|
||||||
its own per-device CPU output and divisor buffer, applying the same feathered overlap mask
|
|
||||||
formula as the single-device path. Buffers are summed at the end, producing output that is
|
|
||||||
bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise.
|
|
||||||
|
|
||||||
Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``.
|
|
||||||
Falls back to single-device on the "whole input fits in one tile" branch (no parallelism
|
|
||||||
available at that granularity).
|
|
||||||
"""
|
|
||||||
devices = list(functions.keys())
|
|
||||||
if len(devices) < 2:
|
|
||||||
only_fn = next(iter(functions.values())) if functions else None
|
|
||||||
return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap,
|
|
||||||
upscale_amount=upscale_amount, out_channels=out_channels,
|
|
||||||
output_device=output_device, downscale=downscale,
|
|
||||||
index_formulas=index_formulas, pbar=pbar)
|
|
||||||
|
|
||||||
dims = len(tile)
|
|
||||||
|
|
||||||
if not (isinstance(upscale_amount, (tuple, list))):
|
|
||||||
upscale_amount = [upscale_amount] * dims
|
|
||||||
if not (isinstance(overlap, (tuple, list))):
|
|
||||||
overlap = [overlap] * dims
|
|
||||||
if index_formulas is None:
|
|
||||||
index_formulas = upscale_amount
|
|
||||||
if not (isinstance(index_formulas, (tuple, list))):
|
|
||||||
index_formulas = [index_formulas] * dims
|
|
||||||
|
|
||||||
def get_upscale(dim, val):
|
|
||||||
up = upscale_amount[dim]
|
|
||||||
return up(val) if callable(up) else up * val
|
|
||||||
|
|
||||||
def get_downscale(dim, val):
|
|
||||||
up = upscale_amount[dim]
|
|
||||||
return up(val) if callable(up) else val / up
|
|
||||||
|
|
||||||
def get_upscale_pos(dim, val):
|
|
||||||
up = index_formulas[dim]
|
|
||||||
return up(val) if callable(up) else up * val
|
|
||||||
|
|
||||||
def get_downscale_pos(dim, val):
|
|
||||||
up = index_formulas[dim]
|
|
||||||
return up(val) if callable(up) else val / up
|
|
||||||
|
|
||||||
if downscale:
|
|
||||||
get_scale = get_downscale
|
|
||||||
get_pos = get_downscale_pos
|
|
||||||
else:
|
|
||||||
get_scale = get_upscale
|
|
||||||
get_pos = get_upscale_pos
|
|
||||||
|
|
||||||
def mult_list_upscale(a):
|
|
||||||
return [round(get_scale(i, a[i])) for i in range(len(a))]
|
|
||||||
|
|
||||||
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
|
||||||
merge_device = torch.device("cpu")
|
|
||||||
|
|
||||||
pbar_lock = threading.Lock() if pbar is not None else None
|
|
||||||
primary_device = devices[0]
|
|
||||||
|
|
||||||
samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)
|
|
||||||
|
|
||||||
for b in range(samples_staged.shape[0]):
|
|
||||||
s = samples_staged[b:b+1]
|
|
||||||
|
|
||||||
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
|
||||||
with torch.inference_mode():
|
|
||||||
output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
|
|
||||||
if pbar is not None:
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
|
||||||
split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))}
|
|
||||||
|
|
||||||
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
|
|
||||||
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
|
|
||||||
bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices}
|
|
||||||
divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices}
|
|
||||||
|
|
||||||
worker_errors: list[BaseException] = []
|
|
||||||
worker_lock = threading.Lock()
|
|
||||||
|
|
||||||
def worker(device, my_positions):
|
|
||||||
try:
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
fn = functions[device]
|
|
||||||
local_buf = bufs[device]
|
|
||||||
local_div = divs[device]
|
|
||||||
with torch.inference_mode():
|
|
||||||
for it in my_positions:
|
|
||||||
s_in = s
|
|
||||||
upscaled = []
|
|
||||||
for d in range(dims):
|
|
||||||
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
|
|
||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
|
||||||
upscaled.append(round(get_pos(d, pos)))
|
|
||||||
|
|
||||||
s_in_dev = s_in.to(device, non_blocking=True)
|
|
||||||
ps = fn(s_in_dev).to(merge_device)
|
|
||||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)
|
|
||||||
|
|
||||||
for d in range(2, dims + 2):
|
|
||||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
|
||||||
if feather >= mask.shape[d]:
|
|
||||||
continue
|
|
||||||
for t in range(feather):
|
|
||||||
a = (t + 1) / feather
|
|
||||||
mask.narrow(d, t, 1).mul_(a)
|
|
||||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
|
||||||
|
|
||||||
o = local_buf
|
|
||||||
o_d = local_div
|
|
||||||
ps_view = ps
|
|
||||||
mask_view = mask
|
|
||||||
for d in range(dims):
|
|
||||||
l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
|
|
||||||
o = o.narrow(d + 2, upscaled[d], l)
|
|
||||||
o_d = o_d.narrow(d + 2, upscaled[d], l)
|
|
||||||
if l < ps_view.shape[d + 2]:
|
|
||||||
ps_view = ps_view.narrow(d + 2, 0, l)
|
|
||||||
mask_view = mask_view.narrow(d + 2, 0, l)
|
|
||||||
|
|
||||||
o.add_(ps_view * mask_view)
|
|
||||||
o_d.add_(mask_view)
|
|
||||||
|
|
||||||
if pbar is not None:
|
|
||||||
with pbar_lock:
|
|
||||||
pbar.update(1)
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.synchronize(device)
|
|
||||||
except BaseException as e:
|
|
||||||
with worker_lock:
|
|
||||||
worker_errors.append(e)
|
|
||||||
|
|
||||||
threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices]
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
if worker_errors:
|
|
||||||
raise worker_errors[0]
|
|
||||||
|
|
||||||
combined_buf = sum(bufs.values())
|
|
||||||
combined_div = sum(divs.values())
|
|
||||||
output[b:b+1] = combined_buf / combined_div
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def model_trange(*args, **kwargs):
|
def model_trange(*args, **kwargs):
|
||||||
if not comfy.memory_management.aimdo_enabled:
|
if not comfy.memory_management.aimdo_enabled:
|
||||||
return trange(*args, **kwargs)
|
return trange(*args, **kwargs)
|
||||||
|
|||||||
@ -13,42 +13,33 @@ import comfy.multigpu
|
|||||||
|
|
||||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so
|
Prepares model to have sampling accelerated via splitting work units.
|
||||||
downstream nodes that recognize the attached state dispatch their work across multiple GPUs.
|
|
||||||
|
|
||||||
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
|
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||||
Otherwise position is not order-sensitive.
|
|
||||||
|
Other than those exceptions, this node can be placed in any order.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="MultiGPU_WorkUnits",
|
node_id="MultiGPU_WorkUnits",
|
||||||
display_name="MultiGPU Work Units",
|
display_name="MultiGPU CFG Split",
|
||||||
category="advanced/multigpu",
|
category="advanced/multigpu",
|
||||||
description=cleandoc(cls.__doc__),
|
description=cleandoc(cls.__doc__),
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", optional=True),
|
io.Model.Input("model"),
|
||||||
io.UpscaleModel.Input("upscale_model", optional=True),
|
|
||||||
io.Vae.Input("vae", optional=True),
|
|
||||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(),
|
io.Model.Output(),
|
||||||
io.UpscaleModel.Output(),
|
|
||||||
io.Vae.Output(),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput:
|
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
||||||
if model is not None:
|
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
return io.NodeOutput(model)
|
||||||
if upscale_model is not None:
|
|
||||||
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
|
|
||||||
if vae is not None:
|
|
||||||
vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus)
|
|
||||||
return io.NodeOutput(model, upscale_model, vae)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiGPUOptionsNode(io.ComfyNode):
|
class MultiGPUOptionsNode(io.ComfyNode):
|
||||||
|
|||||||
@ -81,33 +81,13 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
|
|
||||||
output_device = comfy.model_management.intermediate_device()
|
output_device = comfy.model_management.intermediate_device()
|
||||||
|
|
||||||
multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
|
|
||||||
if multigpu_clones:
|
|
||||||
for dev, desc in multigpu_clones.items():
|
|
||||||
model_management.free_memory(memory_required, dev)
|
|
||||||
desc.to(dev)
|
|
||||||
|
|
||||||
oom = True
|
oom = True
|
||||||
try:
|
try:
|
||||||
while oom:
|
while oom:
|
||||||
try:
|
try:
|
||||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
if multigpu_clones:
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||||
functions = {device: lambda a: upscale_model(a.float())}
|
|
||||||
for dev, desc in multigpu_clones.items():
|
|
||||||
functions[dev] = lambda a, d=desc: d(a.float())
|
|
||||||
s = comfy.utils.tiled_scale_multidim_multigpu(
|
|
||||||
in_img,
|
|
||||||
functions,
|
|
||||||
tile=(tile, tile),
|
|
||||||
overlap=overlap,
|
|
||||||
upscale_amount=upscale_model.scale,
|
|
||||||
pbar=pbar,
|
|
||||||
output_device=output_device,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
|
||||||
oom = False
|
oom = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_management.raise_non_oom(e)
|
model_management.raise_non_oom(e)
|
||||||
@ -116,9 +96,6 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
upscale_model.to("cpu")
|
upscale_model.to("cpu")
|
||||||
if multigpu_clones:
|
|
||||||
for desc in multigpu_clones.values():
|
|
||||||
desc.to("cpu")
|
|
||||||
|
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|||||||
6
nodes.py
6
nodes.py
@ -869,7 +869,6 @@ class VAELoader:
|
|||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name, device="default"):
|
def load_vae(self, vae_name, device="default"):
|
||||||
metadata = None
|
metadata = None
|
||||||
vae_path = None
|
|
||||||
if vae_name == "pixel_space":
|
if vae_name == "pixel_space":
|
||||||
sd = {}
|
sd = {}
|
||||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||||
@ -889,11 +888,6 @@ class VAELoader:
|
|||||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
# Register a reload factory on the patcher so MultiGPU work-units can use
|
|
||||||
# ModelPatcher.deepclone_multigpu to produce per-device clones from the
|
|
||||||
# same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader).
|
|
||||||
if vae_path is not None:
|
|
||||||
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved))
|
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class ControlNetLoader:
|
class ControlNetLoader:
|
||||||
|
|||||||
@ -1,147 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
|
|
||||||
def install_fake_comfy_aimdo(monkeypatch):
    """Register an empty stand-in ``comfy_aimdo`` package in ``sys.modules``.

    Every registration goes through ``monkeypatch.setitem``, so the caller's
    ``monkeypatch`` object decides how and when the entries are undone.

    Args:
        monkeypatch: Object exposing ``setitem(mapping, key, value)``
            (the pytest ``monkeypatch`` fixture in the tests of this file).
    """
    root = types.ModuleType("comfy_aimdo")
    # A module carrying a ``__path__`` attribute is treated as a package, so
    # dotted imports of the submodules registered below can resolve.
    root.__path__ = []
    monkeypatch.setitem(sys.modules, "comfy_aimdo", root)

    submodule_names = ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control")
    for leaf in submodule_names:
        stub = types.ModuleType(f"comfy_aimdo.{leaf}")
        monkeypatch.setitem(sys.modules, f"comfy_aimdo.{leaf}", stub)
        # Also expose the stub as an attribute of the parent package.
        setattr(root, leaf, stub)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
    """Multi-device tiled scaling must match the single-device reference.

    A 1-D input of length 11 is scaled by a non-integer factor (1.1) with
    tile size 5 and overlap 2; both code paths are expected to produce an
    output of shape ``(1, 1, 12)`` with matching values.
    """
    def _noop(device):
        return None

    # Replace the CUDA hooks with no-ops for the duration of the test.
    monkeypatch.setattr(torch.cuda, "set_device", _noop)
    monkeypatch.setattr(torch.cuda, "synchronize", _noop)

    scale = 1.1

    def upscale(a):
        # All-ones tile whose last dimension is the rounded scaled length.
        scaled_len = round(a.shape[-1] * scale)
        return torch.ones((a.shape[0], 1, scaled_len), dtype=a.dtype, device=a.device)

    samples = torch.ones((1, 1, 11))
    devices = [torch.device("cpu:0"), torch.device("cpu:1")]

    # Identical tiling options for both implementations under comparison.
    tiling_options = dict(
        tile=(5,),
        overlap=2,
        upscale_amount=scale,
        out_channels=1,
        output_device="cpu",
    )

    per_device_functions = dict.fromkeys(devices, upscale)
    actual = comfy.utils.tiled_scale_multidim_multigpu(samples, per_device_functions, **tiling_options)
    expected = comfy.utils.tiled_scale_multidim(samples, upscale, **tiling_options)

    assert actual.shape == expected.shape == (1, 1, 12)
    torch.testing.assert_close(actual, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
    """Check clone reuse, pruning and creation for upscale-model descriptors.

    With ``max_gpus=3`` and two extra devices reported, the result must:
    reuse the already-attached clone for the first device, drop the clone for
    a device that is no longer reported, and add a fresh clone for the second
    device that carries no ``multigpu_clones`` of its own, was moved to
    ``"cpu"`` and has ``requires_grad`` disabled on its parameter.
    With ``max_gpus=1`` the result must have no ``multigpu_clones`` attribute.
    """
    class FakeModel:
        def __init__(self):
            self.param = torch.nn.Parameter(torch.ones(1))

        def eval(self):
            return self

        def parameters(self):
            return [self.param]

    class FakeDescriptor:
        def __init__(self):
            self.model = FakeModel()
            self.device = None

        def to(self, device):
            # Record the requested device so the test can inspect it later.
            self.device = device
            return self

    first_device = torch.device("cpu:0")
    second_device = torch.device("cpu:1")
    stale_device = torch.device("cpu:2")

    existing_clone = FakeDescriptor()
    stale_clone = FakeDescriptor()
    source = FakeDescriptor()
    source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}

    def get_all_torch_devices(exclude_current=True):
        # Only the first two devices are reported; the stale one is absent.
        return [first_device, second_device]

    fake_model_management = types.ModuleType("comfy.model_management")
    fake_model_management.get_all_torch_devices = get_all_torch_devices
    monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)

    import comfy
    monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)

    import comfy.multigpu
    # Re-execute the module after the fake model_management has been installed.
    importlib.reload(comfy.multigpu)

    make_clones = comfy.multigpu.create_upscale_model_multigpu_deepclones

    cloned = make_clones(source, max_gpus=3)
    clone_map = cloned.multigpu_clones

    assert cloned is not source
    assert clone_map[first_device] is existing_clone
    assert stale_device not in clone_map
    assert second_device in clone_map

    fresh_clone = clone_map[second_device]
    assert not hasattr(fresh_clone, "multigpu_clones")
    assert fresh_clone.device == "cpu"
    assert not fresh_clone.model.param.requires_grad

    single_gpu_clone = make_clones(source, max_gpus=1)
    assert single_gpu_clone is not source
    assert not hasattr(single_gpu_clone, "multigpu_clones")
|
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
    """Checkpoint loading must register reload factories on both patchers.

    After ``load_checkpoint_guess_config`` runs against stubbed loaders, the
    model patcher's ``cached_patcher_init`` must start with
    ``load_checkpoint_guess_config`` and the VAE patcher's must start with
    ``load_checkpoint_vae_patcher`` followed by arguments whose first entry is
    the checkpoint path.
    """
    install_fake_comfy_aimdo(monkeypatch)
    import comfy.sd
    importlib.reload(comfy.sd)

    class FakeVAE:
        def __init__(self):
            self.patcher = types.SimpleNamespace(cached_patcher_init=None)

    model_patcher = types.SimpleNamespace(cached_patcher_init=None)
    vae = FakeVAE()
    metadata = {"format": "checkpoint"}

    def fake_load_torch_file(path, return_metadata=False):
        # Empty state dict plus the canned metadata, whatever the path.
        return ({}, metadata)

    def fake_guess_config(*args, **kwargs):
        return (model_patcher, None, vae, None)

    monkeypatch.setattr(comfy.utils, "load_torch_file", fake_load_torch_file)
    monkeypatch.setattr(comfy.sd, "load_state_dict_guess_config", fake_guess_config)

    comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True)

    model_init = model_patcher.cached_patcher_init
    vae_init = vae.patcher.cached_patcher_init
    assert model_init[0] is comfy.sd.load_checkpoint_guess_config
    assert vae_init[0] is comfy.sd.load_checkpoint_vae_patcher
    assert vae_init[1][0] == "checkpoint.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
    """A VAE object without a ``patcher`` attribute must pass through as-is.

    The stubbed loader returns a bare namespace in the VAE slot; the call must
    return that same object at index 2 while still registering
    ``load_checkpoint_guess_config`` on the model patcher.
    """
    install_fake_comfy_aimdo(monkeypatch)
    import comfy.sd
    importlib.reload(comfy.sd)

    model_patcher = types.SimpleNamespace(cached_patcher_init=None)
    placeholder_vae = types.SimpleNamespace()
    metadata = {"format": "checkpoint"}

    def fake_load_torch_file(path, return_metadata=False):
        # Empty state dict plus the canned metadata, whatever the path.
        return ({}, metadata)

    def fake_guess_config(*args, **kwargs):
        return (model_patcher, None, placeholder_vae, None)

    monkeypatch.setattr(comfy.utils, "load_torch_file", fake_load_torch_file)
    monkeypatch.setattr(comfy.sd, "load_state_dict_guess_config", fake_guess_config)

    loaded = comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)

    assert loaded[2] is placeholder_vae
    assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
|
|
||||||
Loading…
Reference in New Issue
Block a user