mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
MultiGPU Work Units For Accelerated Sampling (CORE-184) (#7063)
This commit is contained in:
parent
04879a8113
commit
0a2dd86e78
@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
|||||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
|
||||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||||
cm_group = parser.add_mutually_exclusive_group()
|
cm_group = parser.add_mutually_exclusive_group()
|
||||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
|
|||||||
@ -15,13 +15,14 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import copy
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
|
|||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
import comfy.ldm.qwen_image.controlnet
|
import comfy.ldm.qwen_image.controlnet
|
||||||
import comfy.cldm.dit_embedder
|
import comfy.cldm.dit_embedder
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Union
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.hooks import HookGroup
|
from comfy.hooks import HookGroup
|
||||||
|
|
||||||
@ -64,6 +65,18 @@ class StrengthType(Enum):
|
|||||||
CONSTANT = 1
|
CONSTANT = 1
|
||||||
LINEAR_UP = 2
|
LINEAR_UP = 2
|
||||||
|
|
||||||
|
class ControlIsolation:
|
||||||
|
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||||
|
def __init__(self, control: ControlBase):
|
||||||
|
self.control = control
|
||||||
|
self.orig_previous_controlnet = control.previous_controlnet
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.control.previous_controlnet = None
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||||
|
|
||||||
class ControlBase:
|
class ControlBase:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
@ -77,7 +90,7 @@ class ControlBase:
|
|||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
self.upscale_algorithm = 'nearest-exact'
|
||||||
self.extra_args = {}
|
self.extra_args = {}
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet: Union[ControlBase, None] = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
self.concat_mask = False
|
self.concat_mask = False
|
||||||
@ -85,6 +98,7 @@ class ControlBase:
|
|||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.extra_hooks: HookGroup = None
|
self.extra_hooks: HookGroup = None
|
||||||
self.preprocess_image = lambda a: a
|
self.preprocess_image = lambda a: a
|
||||||
|
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@ -111,17 +125,38 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
|
for device_cnet in self.multigpu_clones.values():
|
||||||
|
with ControlIsolation(device_cnet):
|
||||||
|
device_cnet.cleanup()
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = []
|
||||||
|
for device_cnet in self.multigpu_clones.values():
|
||||||
|
out += device_cnet.get_models_only_self()
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def get_models_only_self(self):
|
||||||
|
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||||
|
with ControlIsolation(self):
|
||||||
|
return self.get_models()
|
||||||
|
|
||||||
|
def get_instance_for_device(self, device):
|
||||||
|
'Returns instance of this Control object intended for selected device.'
|
||||||
|
return self.multigpu_clones.get(device, self)
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
'''
|
||||||
|
Create deep clone of Control object where model(s) is set to other devices.
|
||||||
|
|
||||||
|
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||||
|
'''
|
||||||
|
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||||
|
|
||||||
def get_extra_hooks(self):
|
def get_extra_hooks(self):
|
||||||
out = []
|
out = []
|
||||||
if self.extra_hooks is not None:
|
if self.extra_hooks is not None:
|
||||||
@ -130,7 +165,7 @@ class ControlBase:
|
|||||||
out += self.previous_controlnet.get_extra_hooks()
|
out += self.previous_controlnet.get_extra_hooks()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def copy_to(self, c):
|
def copy_to(self, c: ControlBase):
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
@ -284,6 +319,14 @@ class ControlNet(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
c = self.copy()
|
||||||
|
c.control_model = copy.deepcopy(c.control_model)
|
||||||
|
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||||
|
if autoregister:
|
||||||
|
self.multigpu_clones[load_device] = c
|
||||||
|
return c
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = super().get_models()
|
out = super().get_models()
|
||||||
out.append(self.control_model_wrapped)
|
out.append(self.control_model_wrapped)
|
||||||
@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
|
|||||||
super().pre_run(model, percent_to_timestep_function)
|
super().pre_run(model, percent_to_timestep_function)
|
||||||
self.set_extra_arg("base_model", model.diffusion_model)
|
self.set_extra_arg("base_model", model.diffusion_model)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
self.extra_args.pop("base_model", None)
|
||||||
|
super().cleanup()
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
c.control_model = self.control_model
|
c.control_model = self.control_model
|
||||||
@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
c = self.copy()
|
||||||
|
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||||
|
c.device = load_device
|
||||||
|
if autoregister:
|
||||||
|
self.multigpu_clones[load_device] = c
|
||||||
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|||||||
@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
|
|||||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||||
|
|
||||||
x = x.movedim(-1, -2)
|
x = x.movedim(-1, -2)
|
||||||
if context.shape[0] >= 2:
|
|
||||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
swap_cfg_halves = context.shape[0] >= 2
|
||||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
|
||||||
|
if swap_cfg_halves:
|
||||||
|
first_half, second_half = context.chunk(2, dim = 0)
|
||||||
|
context = torch.cat([second_half, first_half], dim = 0)
|
||||||
|
|
||||||
main_condition = context
|
main_condition = context
|
||||||
|
|
||||||
t = 1.0 - t
|
t = 1.0 - t
|
||||||
@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
|
|||||||
output = self.final_layer(combined)
|
output = self.final_layer(combined)
|
||||||
output = output.movedim(-2, -1) * (-1.0)
|
output = output.movedim(-2, -1) * (-1.0)
|
||||||
|
|
||||||
if output.shape[0] >= 2:
|
if swap_cfg_halves:
|
||||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
first_half, second_half = output.chunk(2, dim = 0)
|
||||||
return torch.cat([uncond_emb, cond_emb])
|
output = torch.cat([second_half, first_half], dim = 0)
|
||||||
else:
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
import ctypes
|
import ctypes
|
||||||
import threading
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
|
|||||||
|
|
||||||
class TensorFileSlice(NamedTuple):
|
class TensorFileSlice(NamedTuple):
|
||||||
file_ref: object
|
file_ref: object
|
||||||
thread_id: int
|
lock: object
|
||||||
offset: int
|
offset: int
|
||||||
size: int
|
size: int
|
||||||
|
|
||||||
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
file_obj = info.file_ref
|
file_obj = info.file_ref
|
||||||
if (destination.device.type != "cpu"
|
if (destination.device.type != "cpu"
|
||||||
or file_obj is None
|
or file_obj is None
|
||||||
or threading.get_ident() != info.thread_id
|
|
||||||
or destination.numel() * destination.element_size() < info.size
|
or destination.numel() * destination.element_size() < info.size
|
||||||
or tensor.numel() * tensor.element_size() != info.size
|
or tensor.numel() * tensor.element_size() != info.size
|
||||||
or tensor.storage_offset() != 0
|
or tensor.storage_offset() != 0
|
||||||
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
if hostbuf is not None:
|
if hostbuf is not None:
|
||||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
with info.lock:
|
||||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||||
stream=stream_ptr,
|
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||||
device_ptr=device_ptr,
|
stream=stream_ptr,
|
||||||
device=None if destination2 is None else destination2.device.index)
|
device_ptr=device_ptr,
|
||||||
|
device=None if destination2 is None else destination2.device.index)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
buf_type = ctypes.c_ubyte * info.size
|
buf_type = ctypes.c_ubyte * info.size
|
||||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_obj.seek(info.offset)
|
with info.lock:
|
||||||
done = 0
|
file_obj.seek(info.offset)
|
||||||
while done < info.size:
|
done = 0
|
||||||
try:
|
while done < info.size:
|
||||||
n = file_obj.readinto(view[done:])
|
try:
|
||||||
except OSError:
|
n = file_obj.readinto(view[done:])
|
||||||
return False
|
except OSError:
|
||||||
if n <= 0:
|
return False
|
||||||
return False
|
if n <= 0:
|
||||||
done += n
|
return False
|
||||||
|
done += n
|
||||||
return True
|
return True
|
||||||
finally:
|
finally:
|
||||||
view.release()
|
view.release()
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import logging
|
import logging
|
||||||
@ -27,13 +28,18 @@ import platform
|
|||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
import comfy_aimdo.host_buffer
|
import comfy_aimdo.host_buffer
|
||||||
import comfy_aimdo.vram_buffer
|
import comfy_aimdo.vram_buffer
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@ -204,6 +210,107 @@ def get_torch_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
|
def get_all_torch_devices(exclude_current=False):
|
||||||
|
global cpu_state
|
||||||
|
devices = []
|
||||||
|
if cpu_state == CPUState.GPU:
|
||||||
|
# NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
|
||||||
|
# without the AMD arm, single-GPU ROCm users get an empty list
|
||||||
|
# which silently turns unload_all_models() into a no-op.
|
||||||
|
if is_nvidia() or is_amd():
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
devices.append(torch.device("cuda", i))
|
||||||
|
elif is_intel_xpu():
|
||||||
|
for i in range(torch.xpu.device_count()):
|
||||||
|
devices.append(torch.device("xpu", i))
|
||||||
|
elif is_ascend_npu():
|
||||||
|
for i in range(torch.npu.device_count()):
|
||||||
|
devices.append(torch.device("npu", i))
|
||||||
|
elif is_mlu():
|
||||||
|
for i in range(torch.mlu.device_count()):
|
||||||
|
devices.append(torch.device("mlu", i))
|
||||||
|
else:
|
||||||
|
# Fallback for unhandled GPU backends (e.g. DirectML): at least
|
||||||
|
# report the current device so callers like unload_all_models()
|
||||||
|
# do not silently no-op.
|
||||||
|
devices.append(get_torch_device())
|
||||||
|
else:
|
||||||
|
devices.append(get_torch_device())
|
||||||
|
if exclude_current:
|
||||||
|
current = get_torch_device()
|
||||||
|
if current in devices:
|
||||||
|
devices.remove(current)
|
||||||
|
return devices
|
||||||
|
|
||||||
|
def get_gpu_device_options():
|
||||||
|
"""Return list of device option strings for node widgets.
|
||||||
|
|
||||||
|
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||||
|
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||||
|
"""
|
||||||
|
options = ["default", "cpu"]
|
||||||
|
devices = get_all_torch_devices()
|
||||||
|
if len(devices) > 1:
|
||||||
|
for i in range(len(devices)):
|
||||||
|
options.append(f"gpu:{i}")
|
||||||
|
return options
|
||||||
|
|
||||||
|
def get_gpu_device_options_no_cpu():
|
||||||
|
"""Variant of get_gpu_device_options that omits "cpu".
|
||||||
|
|
||||||
|
Intended for components like the VAE selector where running on CPU
|
||||||
|
is impractical and should not be offered as a choice.
|
||||||
|
"""
|
||||||
|
return [o for o in get_gpu_device_options() if o != "cpu"]
|
||||||
|
|
||||||
|
def resolve_gpu_device_option(option: str):
|
||||||
|
"""Resolve a device option string to a torch.device.
|
||||||
|
|
||||||
|
Returns None for "default" (let the caller use its normal default).
|
||||||
|
Returns torch.device("cpu") for "cpu".
|
||||||
|
For "gpu:N", returns the Nth torch device. Returns None if the
|
||||||
|
index is out of range, the option string is malformed, or
|
||||||
|
unrecognized (callers are expected to log their own context-rich
|
||||||
|
message before falling back to the default device).
|
||||||
|
"""
|
||||||
|
if option is None or option == "default":
|
||||||
|
return None
|
||||||
|
if option == "cpu":
|
||||||
|
return torch.device("cpu")
|
||||||
|
if option.startswith("gpu:"):
|
||||||
|
try:
|
||||||
|
idx = int(option[4:])
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
devices = get_all_torch_devices()
|
||||||
|
if 0 <= idx < len(devices):
|
||||||
|
return devices[idx]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def cuda_device_context(device):
|
||||||
|
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||||
|
|
||||||
|
Used when running operations on a non-default CUDA device so that custom
|
||||||
|
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||||
|
device index. The previous device is restored on exit.
|
||||||
|
|
||||||
|
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||||
|
the current device.
|
||||||
|
"""
|
||||||
|
prev = None
|
||||||
|
if device.type == "cuda" and device.index is not None:
|
||||||
|
prev = torch.cuda.current_device()
|
||||||
|
if prev != device.index:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
else:
|
||||||
|
prev = None
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if prev is not None:
|
||||||
|
torch.cuda.set_device(prev)
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
@ -492,9 +599,13 @@ try:
|
|||||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||||
except:
|
except:
|
||||||
logging.warning("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
|
try:
|
||||||
|
for device in get_all_torch_devices(exclude_current=True):
|
||||||
|
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
current_loaded_models: list[LoadedModel] = []
|
||||||
current_loaded_models = []
|
|
||||||
|
|
||||||
DIRTY_MMAPS = set()
|
DIRTY_MMAPS = set()
|
||||||
|
|
||||||
@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False):
|
|||||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model: ModelPatcher):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
@ -562,7 +673,7 @@ class LoadedModel:
|
|||||||
self.model_finalizer = None
|
self.model_finalizer = None
|
||||||
self._patcher_finalizer = None
|
self._patcher_finalizer = None
|
||||||
|
|
||||||
def _set_model(self, model):
|
def _set_model(self, model: ModelPatcher):
|
||||||
self._model = weakref.ref(model)
|
self._model = weakref.ref(model)
|
||||||
if model.parent is not None:
|
if model.parent is not None:
|
||||||
self._parent_model = weakref.ref(model.parent)
|
self._parent_model = weakref.ref(model.parent)
|
||||||
@ -573,6 +684,7 @@ class LoadedModel:
|
|||||||
model = self._parent_model()
|
model = self._parent_model()
|
||||||
if model is not None:
|
if model is not None:
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
|
self.device = model.load_device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False):
|
|||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
for device in get_all_torch_devices():
|
||||||
|
free_memory(1e30, device)
|
||||||
|
|
||||||
|
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||||
|
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||||
|
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||||
|
additional_models = []
|
||||||
|
if unload_additional_models:
|
||||||
|
additional_models = model.get_nested_additional_models()
|
||||||
|
keep_loaded = []
|
||||||
|
for loaded_model in initial_keep_loaded:
|
||||||
|
if loaded_model.model is not None:
|
||||||
|
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||||
|
continue
|
||||||
|
# check additional models if they are a match
|
||||||
|
skip = False
|
||||||
|
for add_model in additional_models:
|
||||||
|
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||||
|
skip = True
|
||||||
|
break
|
||||||
|
if skip:
|
||||||
|
continue
|
||||||
|
keep_loaded.append(loaded_model)
|
||||||
|
if not all_devices:
|
||||||
|
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||||
|
else:
|
||||||
|
for device in get_all_torch_devices():
|
||||||
|
free_memory(1e30, device, keep_loaded)
|
||||||
|
|
||||||
def debug_memory_summary():
|
def debug_memory_summary():
|
||||||
if is_amd() or is_nvidia():
|
if is_amd() or is_nvidia():
|
||||||
|
|||||||
@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
def create_model_options_clone(orig_model_options: dict):
|
def create_model_options_clone(orig_model_options: dict):
|
||||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||||
|
|
||||||
def create_hook_patches_clone(orig_hook_patches):
|
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||||
new_hook_patches = {}
|
new_hook_patches = {}
|
||||||
for hook_ref in orig_hook_patches:
|
for hook_ref in orig_hook_patches:
|
||||||
new_hook_patches[hook_ref] = {}
|
new_hook_patches[hook_ref] = {}
|
||||||
for k in orig_hook_patches[hook_ref]:
|
for k in orig_hook_patches[hook_ref]:
|
||||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||||
|
if copy_tuples:
|
||||||
|
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||||
|
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||||
return new_hook_patches
|
return new_hook_patches
|
||||||
|
|
||||||
def wipe_lowvram_weight(m):
|
def wipe_lowvram_weight(m):
|
||||||
@ -329,7 +332,10 @@ class ModelPatcher:
|
|||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
|
||||||
|
self.is_multigpu_base_clone = False
|
||||||
|
self.clone_base_uuid = uuid.uuid4()
|
||||||
|
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@ -366,7 +372,8 @@ class ModelPatcher:
|
|||||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||||
aimdo_mem = 0
|
aimdo_mem = 0
|
||||||
if comfy.memory_management.aimdo_enabled:
|
if comfy.memory_management.aimdo_enabled:
|
||||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
|
||||||
|
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
|
||||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||||
|
|
||||||
def get_clone_model_override(self):
|
def get_clone_model_override(self):
|
||||||
@ -380,6 +387,8 @@ class ModelPatcher:
|
|||||||
if self.cached_patcher_init is None:
|
if self.cached_patcher_init is None:
|
||||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||||
|
if len(self.cached_patcher_init) > 2:
|
||||||
|
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||||
model_override = temp_model_patcher.get_clone_model_override()
|
model_override = temp_model_patcher.get_clone_model_override()
|
||||||
if model_override is None:
|
if model_override is None:
|
||||||
model_override = self.get_clone_model_override()
|
model_override = self.get_clone_model_override()
|
||||||
@ -438,19 +447,113 @@ class ModelPatcher:
|
|||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
n.cached_patcher_init = self.cached_patcher_init
|
n.cached_patcher_init = self.cached_patcher_init
|
||||||
|
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||||
|
n.clone_base_uuid = self.clone_base_uuid
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
callback(self, n)
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||||
|
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||||
|
if self.cached_patcher_init is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
|
||||||
|
"the loader that produced this model does not support multigpu "
|
||||||
|
"(cached_patcher_init is not initialized). Use a core loader "
|
||||||
|
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
|
||||||
|
"or have the custom loader register a cached_patcher_init factory."
|
||||||
|
)
|
||||||
|
comfy.model_management.unload_model_and_clones(self)
|
||||||
|
# Produce a freshly-loaded patcher from the loader factory so the multigpu
|
||||||
|
# clone owns its own untainted model weights (rather than relying on
|
||||||
|
# copy.deepcopy of an already-patched/already-loaded module).
|
||||||
|
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||||
|
if len(self.cached_patcher_init) > 2:
|
||||||
|
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||||
|
# Override clone()'s normal "share self.model + share backup containers" with
|
||||||
|
# the pristine model from temp_model_patcher plus empty backup containers --
|
||||||
|
# the fresh model has no patches applied, so any deepcopy of self's stale
|
||||||
|
# backup/object_patches_backup/pinned would just propagate dead state that
|
||||||
|
# no longer corresponds to anything in n.model.
|
||||||
|
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
|
||||||
|
n = self.clone(model_override=model_override)
|
||||||
|
# clone() copies hook_backup by reference from self; reset since model is pristine.
|
||||||
|
n.hook_backup = {}
|
||||||
|
# set load device, if present
|
||||||
|
if new_load_device is not None:
|
||||||
|
n.load_device = new_load_device
|
||||||
|
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
|
||||||
|
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
|
||||||
|
# __init__ only registered its own (default) load_device.
|
||||||
|
if hasattr(n, "register_load_device"):
|
||||||
|
n.register_load_device(n.load_device)
|
||||||
|
# multigpu clone should not have multigpu additional_models entry
|
||||||
|
n.remove_additional_models("multigpu")
|
||||||
|
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||||
|
if models_cache is None:
|
||||||
|
models_cache = {}
|
||||||
|
for key, model_list in n.additional_models.items():
|
||||||
|
for i in range(len(model_list)):
|
||||||
|
add_model = n.additional_models[key][i]
|
||||||
|
if add_model.clone_base_uuid not in models_cache:
|
||||||
|
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||||
|
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||||
|
callback(self, n)
|
||||||
|
return n
|
||||||
|
|
||||||
|
def match_multigpu_clones(self):
|
||||||
|
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) > 0:
|
||||||
|
new_multigpu_models = []
|
||||||
|
for mm in multigpu_models:
|
||||||
|
# clone main model, but bring over relevant props from existing multigpu clone
|
||||||
|
n = self.clone()
|
||||||
|
n.load_device = mm.load_device
|
||||||
|
n.backup = mm.backup
|
||||||
|
n.object_patches_backup = mm.object_patches_backup
|
||||||
|
n.hook_backup = mm.hook_backup
|
||||||
|
n.model = mm.model
|
||||||
|
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||||
|
n.remove_additional_models("multigpu")
|
||||||
|
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||||
|
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||||
|
# figure out which additional models are not present in multigpu clone
|
||||||
|
models_cache = {}
|
||||||
|
for mm_add_model in mm.get_additional_models():
|
||||||
|
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||||
|
remove_models_uuids = set(list(models_cache.keys()))
|
||||||
|
for key, model_list in orig_additional_models.items():
|
||||||
|
for orig_add_model in model_list:
|
||||||
|
if orig_add_model.clone_base_uuid not in models_cache:
|
||||||
|
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||||
|
existing_list = n.get_additional_models_with_key(key)
|
||||||
|
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||||
|
n.set_additional_models(key, existing_list)
|
||||||
|
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||||
|
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||||
|
# remove duplicate additional models
|
||||||
|
for key, model_list in n.additional_models.items():
|
||||||
|
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||||
|
n.set_additional_models(key, new_model_list)
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||||
|
callback(self, n)
|
||||||
|
new_multigpu_models.append(n)
|
||||||
|
self.set_additional_models("multigpu", new_multigpu_models)
|
||||||
|
|
||||||
def is_clone(self, other):
|
def is_clone(self, other):
|
||||||
if hasattr(other, 'model') and self.model is other.model:
|
if hasattr(other, 'model') and self.model is other.model:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||||
if not self.is_clone(clone):
|
if allow_multigpu:
|
||||||
return False
|
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if not self.is_clone(clone):
|
||||||
|
return False
|
||||||
|
|
||||||
if self.current_hooks != clone.current_hooks:
|
if self.current_hooks != clone.current_hooks:
|
||||||
return False
|
return False
|
||||||
@ -1232,7 +1335,7 @@ class ModelPatcher:
|
|||||||
return self.additional_models.get(key, [])
|
return self.additional_models.get(key, [])
|
||||||
|
|
||||||
def get_additional_models(self):
|
def get_additional_models(self):
|
||||||
all_models = []
|
all_models: list[ModelPatcher] = []
|
||||||
for models in self.additional_models.values():
|
for models in self.additional_models.values():
|
||||||
all_models.extend(models)
|
all_models.extend(models)
|
||||||
return all_models
|
return all_models
|
||||||
@ -1286,9 +1389,18 @@ class ModelPatcher:
|
|||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def prepare_state(self, timestep):
|
def prepare_state(self, timestep, model_options):
|
||||||
|
ignore_multigpu = model_options.get("ignore_multigpu", False)
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||||
callback(self, timestep)
|
callback(self, timestep, model_options)
|
||||||
|
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||||
|
model_options["ignore_multigpu"] = True
|
||||||
|
try:
|
||||||
|
for p in model_options["multigpu_clones"].values():
|
||||||
|
p: ModelPatcher
|
||||||
|
p.prepare_state(timestep, model_options)
|
||||||
|
finally:
|
||||||
|
model_options.pop("ignore_multigpu", None)
|
||||||
|
|
||||||
def restore_hook_patches(self):
|
def restore_hook_patches(self):
|
||||||
if self.hook_patches_backup is not None:
|
if self.hook_patches_backup is not None:
|
||||||
@ -1301,12 +1413,18 @@ class ModelPatcher:
|
|||||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||||
curr_t = t[0]
|
curr_t = t[0]
|
||||||
reset_current_hooks = False
|
reset_current_hooks = False
|
||||||
|
multigpu_kf_changed_cache = None
|
||||||
transformer_options = model_options.get("transformer_options", {})
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
for hook in hook_group.hooks:
|
for hook in hook_group.hooks:
|
||||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||||
# this will cause the weights to be recalculated when sampling
|
# this will cause the weights to be recalculated when sampling
|
||||||
if changed:
|
if changed:
|
||||||
|
# cache changed for multigpu usage
|
||||||
|
if "multigpu_clones" in model_options:
|
||||||
|
if multigpu_kf_changed_cache is None:
|
||||||
|
multigpu_kf_changed_cache = []
|
||||||
|
multigpu_kf_changed_cache.append(hook)
|
||||||
# reset current_hooks if contains hook that changed
|
# reset current_hooks if contains hook that changed
|
||||||
if self.current_hooks is not None:
|
if self.current_hooks is not None:
|
||||||
for current_hook in self.current_hooks.hooks:
|
for current_hook in self.current_hooks.hooks:
|
||||||
@ -1318,6 +1436,28 @@ class ModelPatcher:
|
|||||||
self.cached_hook_patches.pop(cached_group)
|
self.cached_hook_patches.pop(cached_group)
|
||||||
if reset_current_hooks:
|
if reset_current_hooks:
|
||||||
self.patch_hooks(None)
|
self.patch_hooks(None)
|
||||||
|
if "multigpu_clones" in model_options:
|
||||||
|
for p in model_options["multigpu_clones"].values():
|
||||||
|
p: ModelPatcher
|
||||||
|
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||||
|
|
||||||
|
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||||
|
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||||
|
if kf_changed_cache is None:
|
||||||
|
return
|
||||||
|
reset_current_hooks = False
|
||||||
|
# reset current_hooks if contains hook that changed
|
||||||
|
for hook in kf_changed_cache:
|
||||||
|
if self.current_hooks is not None:
|
||||||
|
for current_hook in self.current_hooks.hooks:
|
||||||
|
if current_hook == hook:
|
||||||
|
reset_current_hooks = True
|
||||||
|
break
|
||||||
|
for cached_group in list(self.cached_hook_patches.keys()):
|
||||||
|
if cached_group.contains(hook):
|
||||||
|
self.cached_hook_patches.pop(cached_group)
|
||||||
|
if reset_current_hooks:
|
||||||
|
self.patch_hooks(None)
|
||||||
|
|
||||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||||
registered: comfy.hooks.HookGroup = None):
|
registered: comfy.hooks.HookGroup = None):
|
||||||
@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
self.model.dynamic_vbars = {}
|
self.model.dynamic_vbars = {}
|
||||||
if not hasattr(self.model, "dynamic_pins"):
|
if not hasattr(self.model, "dynamic_pins"):
|
||||||
self.model.dynamic_pins = {}
|
self.model.dynamic_pins = {}
|
||||||
if self.load_device not in self.model.dynamic_pins:
|
self.register_load_device(self.load_device)
|
||||||
self.model.dynamic_pins[self.load_device] = {
|
self.non_dynamic_delegate_model = None
|
||||||
|
assert load_device is not None
|
||||||
|
|
||||||
|
def register_load_device(self, device):
|
||||||
|
"""Ensure dynamic_pins has an entry for *device*.
|
||||||
|
|
||||||
|
Called from __init__ and also from any code that retargets an
|
||||||
|
already-constructed patcher to a new load_device (e.g. the
|
||||||
|
Select{Model,CLIP,VAE}Device selector nodes); without this entry
|
||||||
|
partially_unload_ram() raises KeyError when it tries to read the
|
||||||
|
per-device pin state.
|
||||||
|
"""
|
||||||
|
if device not in self.model.dynamic_pins:
|
||||||
|
self.model.dynamic_pins[device] = {
|
||||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||||
"hostbufs_initialized": False,
|
"hostbufs_initialized": False,
|
||||||
"failed": False,
|
"failed": False,
|
||||||
"active": False,
|
"active": False,
|
||||||
}
|
}
|
||||||
self.non_dynamic_delegate_model = None
|
|
||||||
assert load_device is not None
|
|
||||||
|
|
||||||
def is_dynamic(self):
|
def is_dynamic(self):
|
||||||
return True
|
return True
|
||||||
|
|||||||
248
comfy/multigpu.py
Normal file
248
comfy/multigpu.py
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.patcher_extension
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGPUThreadPool:
|
||||||
|
"""Persistent thread pool for multi-GPU work distribution.
|
||||||
|
|
||||||
|
Maintains one worker thread per extra GPU device. Each thread calls
|
||||||
|
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||||
|
(inductor/triton) stay warm across diffusion steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, devices: list[torch.device]):
|
||||||
|
self._workers: list[threading.Thread] = []
|
||||||
|
self._work_queues: dict[torch.device, queue.Queue] = {}
|
||||||
|
self._result_queues: dict[torch.device, queue.Queue] = {}
|
||||||
|
|
||||||
|
for device in devices:
|
||||||
|
wq = queue.Queue()
|
||||||
|
rq = queue.Queue()
|
||||||
|
self._work_queues[device] = wq
|
||||||
|
self._result_queues[device] = rq
|
||||||
|
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
|
||||||
|
t.start()
|
||||||
|
self._workers.append(t)
|
||||||
|
|
||||||
|
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||||
|
try:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||||
|
while True:
|
||||||
|
item = work_q.get()
|
||||||
|
if item is None:
|
||||||
|
return
|
||||||
|
result_q.put((None, e))
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
item = work_q.get()
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
fn, args, kwargs = item
|
||||||
|
try:
|
||||||
|
result = fn(*args, **kwargs)
|
||||||
|
result_q.put((result, None))
|
||||||
|
except Exception as e:
|
||||||
|
result_q.put((None, e))
|
||||||
|
|
||||||
|
def submit(self, device: torch.device, fn, *args, **kwargs):
|
||||||
|
self._work_queues[device].put((fn, args, kwargs))
|
||||||
|
|
||||||
|
def get_result(self, device: torch.device):
|
||||||
|
return self._result_queues[device].get()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def devices(self) -> list[torch.device]:
|
||||||
|
return list(self._work_queues.keys())
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
for wq in self._work_queues.values():
|
||||||
|
wq.put(None) # sentinel
|
||||||
|
for t in self._workers:
|
||||||
|
t.join(timeout=5.0)
|
||||||
|
|
||||||
|
|
||||||
|
class GPUOptions:
|
||||||
|
def __init__(self, device_index: int, relative_speed: float):
|
||||||
|
self.device_index = device_index
|
||||||
|
self.relative_speed = relative_speed
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
return GPUOptions(self.device_index, self.relative_speed)
|
||||||
|
|
||||||
|
def create_dict(self):
|
||||||
|
return {
|
||||||
|
"relative_speed": self.relative_speed
|
||||||
|
}
|
||||||
|
|
||||||
|
class GPUOptionsGroup:
|
||||||
|
def __init__(self):
|
||||||
|
self.options: dict[int, GPUOptions] = {}
|
||||||
|
|
||||||
|
def add(self, info: GPUOptions):
|
||||||
|
self.options[info.device_index] = info
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = GPUOptionsGroup()
|
||||||
|
for opt in self.options.values():
|
||||||
|
c.add(opt)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def register(self, model: ModelPatcher):
|
||||||
|
opts_dict = {}
|
||||||
|
# get devices that are valid for this model
|
||||||
|
devices: list[torch.device] = [model.load_device]
|
||||||
|
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||||
|
extra_model: ModelPatcher
|
||||||
|
devices.append(extra_model.load_device)
|
||||||
|
# create dictionary with actual device mapped to its GPUOptions
|
||||||
|
device_opts_list: list[GPUOptions] = []
|
||||||
|
for device in devices:
|
||||||
|
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||||
|
opts_dict[device] = device_opts.create_dict()
|
||||||
|
device_opts_list.append(device_opts)
|
||||||
|
# make relative_speed relative to 1.0
|
||||||
|
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||||
|
for value in opts_dict.values():
|
||||||
|
value['relative_speed'] /= min_speed
|
||||||
|
model.model_options['multigpu_options'] = opts_dict
|
||||||
|
|
||||||
|
|
||||||
|
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||||
|
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||||
|
model = model.clone()
|
||||||
|
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||||
|
skip_devices = set()
|
||||||
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) > 0:
|
||||||
|
for mm in multigpu_models:
|
||||||
|
skip_devices.add(mm.load_device)
|
||||||
|
skip_devices = list(skip_devices)
|
||||||
|
|
||||||
|
# Exclude the primary model's actual device, not the global current device:
|
||||||
|
# after SelectModelDevice(gpu:N) the primary may not live on the process's
|
||||||
|
# current CUDA device, and excluding the wrong device picks bad extras.
|
||||||
|
all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
|
||||||
|
full_extra_devices = [d for d in all_devices if d != model.load_device]
|
||||||
|
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||||
|
extra_devices = limit_extra_devices.copy()
|
||||||
|
# exclude skipped devices
|
||||||
|
for skip in skip_devices:
|
||||||
|
if skip in extra_devices:
|
||||||
|
extra_devices.remove(skip)
|
||||||
|
# create new deepclones
|
||||||
|
if len(extra_devices) > 0:
|
||||||
|
for device in extra_devices:
|
||||||
|
device_patcher = None
|
||||||
|
if reuse_loaded:
|
||||||
|
# Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
|
||||||
|
# patcher on the same device shares clone_base_uuid but has
|
||||||
|
# is_multigpu_base_clone=False, which would later be filtered out by
|
||||||
|
# prepare_model_patcher_multigpu_clones() and silently shrink the
|
||||||
|
# work split back to one GPU.
|
||||||
|
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||||
|
for lm in loaded_models:
|
||||||
|
if lm.model is None:
|
||||||
|
continue
|
||||||
|
if lm.load_device != device:
|
||||||
|
continue
|
||||||
|
if lm.clone_base_uuid != model.clone_base_uuid:
|
||||||
|
continue
|
||||||
|
if not getattr(lm, "is_multigpu_base_clone", False):
|
||||||
|
continue
|
||||||
|
device_patcher = lm.clone()
|
||||||
|
logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||||
|
break
|
||||||
|
if device_patcher is None:
|
||||||
|
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||||
|
# Always flag the clone; whether reused or freshly deepcloned, it must
|
||||||
|
# advertise itself as a MultiGPU base clone so the cond scheduler picks
|
||||||
|
# it up in prepare_model_patcher_multigpu_clones().
|
||||||
|
device_patcher.is_multigpu_base_clone = True
|
||||||
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
multigpu_models.append(device_patcher)
|
||||||
|
model.set_additional_models("multigpu", multigpu_models)
|
||||||
|
model.match_multigpu_clones()
|
||||||
|
if gpu_options is None:
|
||||||
|
gpu_options = GPUOptionsGroup()
|
||||||
|
gpu_options.register(model)
|
||||||
|
else:
|
||||||
|
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||||
|
# only keep model clones that don't go 'past' the intended max_gpu count;
|
||||||
|
# this prunes any inherited multigpu clones whose load_device is no longer allowed
|
||||||
|
# when max_gpus is lowered between runs.
|
||||||
|
allowed_devices = set(limit_extra_devices)
|
||||||
|
allowed_devices.add(model.load_device)
|
||||||
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
|
||||||
|
if len(new_multigpu_models) != len(multigpu_models):
|
||||||
|
model.set_additional_models("multigpu", new_multigpu_models)
|
||||||
|
model.match_multigpu_clones()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||||
|
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||||
|
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||||
|
opts_dict = model_options['multigpu_options']
|
||||||
|
devices = list(model_options['multigpu_clones'].keys())
|
||||||
|
speed_per_device = []
|
||||||
|
work_per_device = []
|
||||||
|
# get sum of each device's relative_speed
|
||||||
|
total_speed = 0.0
|
||||||
|
for opts in opts_dict.values():
|
||||||
|
total_speed += opts['relative_speed']
|
||||||
|
# get relative work for each device;
|
||||||
|
# obtained by w = (W*r)/R
|
||||||
|
for device in devices:
|
||||||
|
relative_speed = opts_dict[device]['relative_speed']
|
||||||
|
relative_work = (total_work*relative_speed) / total_speed
|
||||||
|
speed_per_device.append(relative_speed)
|
||||||
|
work_per_device.append(relative_work)
|
||||||
|
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||||
|
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||||
|
work_per_device = round_preserved(work_per_device)
|
||||||
|
dict_work_per_device = {}
|
||||||
|
for device, relative_work in zip(devices, work_per_device):
|
||||||
|
dict_work_per_device[device] = relative_work
|
||||||
|
if not return_idle_time:
|
||||||
|
return LoadBalance(dict_work_per_device, None)
|
||||||
|
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||||
|
# time here is relative and does not correspond to real-world units
|
||||||
|
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||||
|
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||||
|
idle_time = abs(min(completion_time) - max(completion_time))
|
||||||
|
# if need to compare work idle time, need to normalize to a common total work
|
||||||
|
if work_normalized:
|
||||||
|
idle_time *= (work_normalized/total_work)
|
||||||
|
|
||||||
|
return LoadBalance(dict_work_per_device, idle_time)
|
||||||
|
|
||||||
|
def round_preserved(values: list[float]):
|
||||||
|
'Round all values in a list, preserving the combined sum of values.'
|
||||||
|
# get floor of values; casting to int does it too
|
||||||
|
floored = [int(x) for x in values]
|
||||||
|
total_floored = sum(floored)
|
||||||
|
# get remainder to distribute
|
||||||
|
remainder = round(sum(values)) - total_floored
|
||||||
|
# pair values with fractional portions
|
||||||
|
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||||
|
# sort by fractional part in descending order
|
||||||
|
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
# distribute the remainder
|
||||||
|
for i in range(remainder):
|
||||||
|
index = fractional[i][0]
|
||||||
|
floored[index] += 1
|
||||||
|
return floored
|
||||||
@ -3,6 +3,8 @@ from typing import Callable
|
|||||||
|
|
||||||
class CallbacksMP:
|
class CallbacksMP:
|
||||||
ON_CLONE = "on_clone"
|
ON_CLONE = "on_clone"
|
||||||
|
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||||
|
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||||
ON_LOAD = "on_load_after"
|
ON_LOAD = "on_load_after"
|
||||||
ON_DETACH = "on_detach_after"
|
ON_DETACH = "on_detach_after"
|
||||||
ON_CLEANUP = "on_cleanup"
|
ON_CLEANUP = "on_cleanup"
|
||||||
|
|||||||
@ -1,16 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
import uuid
|
import uuid
|
||||||
import math
|
import math
|
||||||
import collections
|
import collections
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
|
import comfy.model_patcher
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.model_patcher import ModelPatcher
|
|
||||||
from comfy.model_base import BaseModel
|
from comfy.model_base import BaseModel
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
from comfy.controlnet import ControlBase
|
from comfy.controlnet import ControlBase
|
||||||
|
|
||||||
def prepare_mask(noise_mask, shape, device):
|
def prepare_mask(noise_mask, shape, device):
|
||||||
@ -119,6 +121,47 @@ def cleanup_additional_models(models):
|
|||||||
if hasattr(m, 'cleanup'):
|
if hasattr(m, 'cleanup'):
|
||||||
m.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
|
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||||
|
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||||
|
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) == 0:
|
||||||
|
return
|
||||||
|
extra_devices = [x.load_device for x in multigpu_models]
|
||||||
|
# handle controlnets
|
||||||
|
controlnets: set[ControlBase] = set()
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
if 'control' in kk:
|
||||||
|
controlnets.add(kk['control'])
|
||||||
|
if len(controlnets) > 0:
|
||||||
|
# first, unload all controlnet clones
|
||||||
|
for cnet in list(controlnets):
|
||||||
|
cnet_models = cnet.get_models()
|
||||||
|
for cm in cnet_models:
|
||||||
|
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||||
|
|
||||||
|
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||||
|
for cnet in controlnets:
|
||||||
|
curr_cnet = cnet
|
||||||
|
while curr_cnet is not None:
|
||||||
|
for device in extra_devices:
|
||||||
|
if device not in curr_cnet.multigpu_clones:
|
||||||
|
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||||
|
curr_cnet = curr_cnet.previous_controlnet
|
||||||
|
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||||
|
for cnet in controlnets:
|
||||||
|
curr_cnet = cnet
|
||||||
|
while curr_cnet is not None:
|
||||||
|
prev_cnet = curr_cnet.previous_controlnet
|
||||||
|
for device in extra_devices:
|
||||||
|
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||||
|
prev_device_cnet = None
|
||||||
|
if prev_cnet is not None:
|
||||||
|
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||||
|
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||||
|
curr_cnet = prev_cnet
|
||||||
|
# potentially handle gligen - since not widely used, ignored for now
|
||||||
|
|
||||||
def estimate_memory(model, noise_shape, conds):
|
def estimate_memory(model, noise_shape, conds):
|
||||||
cond_shapes = collections.defaultdict(list)
|
cond_shapes = collections.defaultdict(list)
|
||||||
cond_shapes_min = {}
|
cond_shapes_min = {}
|
||||||
@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
|||||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
real_model: BaseModel = None
|
model.match_multigpu_clones()
|
||||||
|
preprocess_multigpu_conds(conds, model, model_options)
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
|||||||
memory_required += inference_memory
|
memory_required += inference_memory
|
||||||
minimum_memory_required += inference_memory
|
minimum_memory_required += inference_memory
|
||||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||||
real_model = model.model
|
real_model: BaseModel = model.model
|
||||||
|
|
||||||
return real_model, conds, models
|
return real_model, conds, models
|
||||||
|
|
||||||
@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
|||||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||||
copy_dict1=False)
|
copy_dict1=False)
|
||||||
return to_load_options
|
return to_load_options
|
||||||
|
|
||||||
|
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||||
|
'''
|
||||||
|
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||||
|
'''
|
||||||
|
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||||
|
if len(multigpu_patchers) > 0:
|
||||||
|
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||||
|
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||||
|
for x in multigpu_patchers:
|
||||||
|
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||||
|
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||||
|
multigpu_dict[x.load_device] = x
|
||||||
|
model_options["multigpu_clones"] = multigpu_dict
|
||||||
|
return multigpu_patchers
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
from comfy.model_base import BaseModel
|
from comfy.model_base import BaseModel
|
||||||
@ -16,6 +18,7 @@ import comfy.model_patcher
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.context_windows
|
import comfy.context_windows
|
||||||
|
import comfy.multigpu
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
|
|||||||
|
|
||||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||||
|
|
||||||
def cond_cat(c_list):
|
def cond_cat(c_list, device=None):
|
||||||
temp = {}
|
temp = {}
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
for k in x:
|
for k in x:
|
||||||
@ -153,6 +156,8 @@ def cond_cat(c_list):
|
|||||||
for k in temp:
|
for k in temp:
|
||||||
conds = temp[k]
|
conds = temp[k]
|
||||||
out[k] = conds[0].concat(conds[1:])
|
out[k] = conds[0].concat(conds[1:])
|
||||||
|
if device is not None and hasattr(out[k], 'to'):
|
||||||
|
out[k] = out[k].to(device)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
|||||||
)
|
)
|
||||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
|
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
|
||||||
|
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
|
||||||
|
# aggregation) is duplicated there with per-device scheduling layered on top.
|
||||||
|
if 'multigpu_clones' in model_options:
|
||||||
|
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||||
out_conds = []
|
out_conds = []
|
||||||
out_counts = []
|
out_counts = []
|
||||||
# separate conds by matching hooks
|
# separate conds by matching hooks
|
||||||
@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
if has_default_conds:
|
if has_default_conds:
|
||||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
model.current_patcher.prepare_state(timestep)
|
model.current_patcher.prepare_state(timestep, model_options)
|
||||||
|
|
||||||
# run every hooked_to_run separately
|
# run every hooked_to_run separately
|
||||||
for hooks, to_run in hooked_to_run.items():
|
for hooks, to_run in hooked_to_run.items():
|
||||||
@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
|
|
||||||
return out_conds
|
return out_conds
|
||||||
|
|
||||||
|
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
|
# NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
|
||||||
|
# accumulation, memory-fit batching, and output aggregation, but adds a
|
||||||
|
# per-device scheduler, per-device patcher/control lookup, tensor .to(device)
|
||||||
|
# placement, and MultiGPUThreadPool dispatch around the inner loop.
|
||||||
|
out_conds = []
|
||||||
|
out_counts = []
|
||||||
|
# separate conds by matching hooks
|
||||||
|
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||||
|
default_conds = []
|
||||||
|
has_default_conds = False
|
||||||
|
|
||||||
|
output_device = x_in.device
|
||||||
|
|
||||||
|
for i in range(len(conds)):
|
||||||
|
out_conds.append(torch.zeros_like(x_in))
|
||||||
|
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||||
|
|
||||||
|
cond = conds[i]
|
||||||
|
default_c = []
|
||||||
|
if cond is not None:
|
||||||
|
for x in cond:
|
||||||
|
if 'default' in x:
|
||||||
|
default_c.append(x)
|
||||||
|
has_default_conds = True
|
||||||
|
continue
|
||||||
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
|
if p is None:
|
||||||
|
continue
|
||||||
|
if p.hooks is not None:
|
||||||
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||||
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
default_conds.append(default_c)
|
||||||
|
|
||||||
|
if has_default_conds:
|
||||||
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
|
model.current_patcher.prepare_state(timestep, model_options)
|
||||||
|
|
||||||
|
devices = list(model_options['multigpu_clones'].keys())
|
||||||
|
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||||
|
# Track conds currently scheduled per device; single source of truth for capacity checks.
|
||||||
|
device_load: dict[torch.device, int] = {d: 0 for d in devices}
|
||||||
|
|
||||||
|
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
|
||||||
|
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
||||||
|
|
||||||
|
def next_available_device(start: int) -> tuple[int, torch.device]:
|
||||||
|
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
|
||||||
|
|
||||||
|
Scans at most len(devices) positions, so this always terminates. Raises if no device
|
||||||
|
has remaining capacity, which would indicate a bug in conds_per_device accounting.
|
||||||
|
"""
|
||||||
|
for offset in range(len(devices)):
|
||||||
|
i = (start + offset) % len(devices)
|
||||||
|
if device_load[devices[i]] < conds_per_device:
|
||||||
|
return i, devices[i]
|
||||||
|
raise RuntimeError(
|
||||||
|
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
|
||||||
|
f"({conds_per_device}) but conds remain to schedule"
|
||||||
|
)
|
||||||
|
|
||||||
|
# run every hooked_to_run separately
|
||||||
|
index_device = 0
|
||||||
|
for hooks, to_run in hooked_to_run.items():
|
||||||
|
while len(to_run) > 0:
|
||||||
|
index_device, current_device = next_available_device(index_device)
|
||||||
|
remaining_capacity = conds_per_device - device_load[current_device]
|
||||||
|
|
||||||
|
first = to_run[0]
|
||||||
|
first_shape = first[0][0].shape
|
||||||
|
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
|
||||||
|
to_batch_temp = []
|
||||||
|
for x in range(len(to_run)):
|
||||||
|
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
||||||
|
to_batch_temp += [x]
|
||||||
|
|
||||||
|
to_batch_temp.reverse()
|
||||||
|
to_batch = to_batch_temp[:1]
|
||||||
|
|
||||||
|
free_memory = comfy.model_management.get_free_memory(current_device)
|
||||||
|
for i in range(1, len(to_batch_temp) + 1):
|
||||||
|
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||||
|
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||||
|
cond_shapes = collections.defaultdict(list)
|
||||||
|
for tt in batch_amount:
|
||||||
|
for k, v in to_run[tt][0].conditioning.items():
|
||||||
|
cond_shapes[k].append(v.size())
|
||||||
|
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||||
|
to_batch = batch_amount
|
||||||
|
break
|
||||||
|
|
||||||
|
conds_to_batch = [to_run.pop(x) for x in to_batch]
|
||||||
|
device_load[current_device] += len(conds_to_batch)
|
||||||
|
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
|
||||||
|
|
||||||
|
if device_load[current_device] >= conds_per_device:
|
||||||
|
index_device += 1
|
||||||
|
|
||||||
|
class thread_result(NamedTuple):
|
||||||
|
output: Any
|
||||||
|
mult: Any
|
||||||
|
area: Any
|
||||||
|
batch_chunks: int
|
||||||
|
cond_or_uncond: Any
|
||||||
|
error: Exception = None
|
||||||
|
|
||||||
|
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||||
|
try:
|
||||||
|
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||||
|
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||||
|
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||||
|
# run every hooked_to_run separately
|
||||||
|
with torch.no_grad():
|
||||||
|
for hooks, to_batch in batch_tuple:
|
||||||
|
input_x = []
|
||||||
|
mult = []
|
||||||
|
c = []
|
||||||
|
cond_or_uncond = []
|
||||||
|
uuids = []
|
||||||
|
area = []
|
||||||
|
control: ControlBase = None
|
||||||
|
patches = None
|
||||||
|
for x in to_batch:
|
||||||
|
o = x
|
||||||
|
p = o[0]
|
||||||
|
input_x.append(p.input_x)
|
||||||
|
mult.append(p.mult)
|
||||||
|
c.append(p.conditioning)
|
||||||
|
area.append(p.area)
|
||||||
|
cond_or_uncond.append(o[1])
|
||||||
|
uuids.append(p.uuid)
|
||||||
|
control = p.control
|
||||||
|
patches = p.patches
|
||||||
|
|
||||||
|
batch_chunks = len(cond_or_uncond)
|
||||||
|
input_x = torch.cat(input_x).to(device)
|
||||||
|
c = cond_cat(c, device=device)
|
||||||
|
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||||
|
|
||||||
|
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||||
|
if 'transformer_options' in model_options:
|
||||||
|
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||||
|
model_options['transformer_options'],
|
||||||
|
copy_dict1=False)
|
||||||
|
|
||||||
|
if patches is not None:
|
||||||
|
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||||
|
transformer_options.get("patches", {}),
|
||||||
|
patches
|
||||||
|
)
|
||||||
|
|
||||||
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
|
transformer_options["uuids"] = uuids[:]
|
||||||
|
transformer_options["sigmas"] = timestep.to(device)
|
||||||
|
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||||
|
transformer_options["multigpu_thread_device"] = device
|
||||||
|
|
||||||
|
cast_transformer_options(transformer_options, device=device)
|
||||||
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
|
if control is not None:
|
||||||
|
device_control = control.get_instance_for_device(device)
|
||||||
|
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||||
|
|
||||||
|
if 'model_function_wrapper' in model_options:
|
||||||
|
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||||
|
else:
|
||||||
|
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||||
|
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
|
||||||
|
# above are async on CUDA, so the main thread's aggregation
|
||||||
|
# could race with in-flight transfers. CUDA-only QA has not
|
||||||
|
# surfaced this in practice, but before extending multigpu
|
||||||
|
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
|
||||||
|
# here (guarded by `output_device.type == "cuda"`).
|
||||||
|
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||||
|
except Exception as e:
|
||||||
|
results.append(thread_result(None, None, None, None, None, error=e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_batch_pooled(device, batch_tuple):
|
||||||
|
worker_results = []
|
||||||
|
_handle_batch(device, batch_tuple, worker_results)
|
||||||
|
return worker_results
|
||||||
|
|
||||||
|
results: list[thread_result] = []
|
||||||
|
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||||
|
|
||||||
|
# Submit all GPU work to pool threads
|
||||||
|
pool_devices = []
|
||||||
|
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||||
|
if thread_pool is not None:
|
||||||
|
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||||
|
pool_devices.append(device)
|
||||||
|
else:
|
||||||
|
# Fallback: no pool, run everything on main thread
|
||||||
|
_handle_batch(device, batch_tuple, results)
|
||||||
|
|
||||||
|
# Collect results from pool workers
|
||||||
|
for device in pool_devices:
|
||||||
|
worker_results, error = thread_pool.get_result(device)
|
||||||
|
if error is not None:
|
||||||
|
raise error
|
||||||
|
results.extend(worker_results)
|
||||||
|
|
||||||
|
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||||
|
if error is not None:
|
||||||
|
raise error
|
||||||
|
for o in range(batch_chunks):
|
||||||
|
cond_index = cond_or_uncond[o]
|
||||||
|
a = area[o]
|
||||||
|
if a is None:
|
||||||
|
out_conds[cond_index] += output[o] * mult[o]
|
||||||
|
out_counts[cond_index] += mult[o]
|
||||||
|
else:
|
||||||
|
out_c = out_conds[cond_index]
|
||||||
|
out_cts = out_counts[cond_index]
|
||||||
|
dims = len(a) // 2
|
||||||
|
for i in range(dims):
|
||||||
|
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||||
|
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||||
|
out_c += output[o] * mult[o]
|
||||||
|
out_cts += mult[o]
|
||||||
|
|
||||||
|
for i in range(len(out_conds)):
|
||||||
|
out_conds[i] /= out_counts[i]
|
||||||
|
|
||||||
|
return out_conds
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||||
@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds):
|
|||||||
|
|
||||||
def pre_run_control(model, conds):
|
def pre_run_control(model, conds):
|
||||||
s = model.model_sampling
|
s = model.model_sampling
|
||||||
|
# Per-device model lookup so multigpu control clones get the matching
|
||||||
|
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
|
||||||
|
device_models: dict = {}
|
||||||
|
patcher = getattr(model, "current_patcher", None)
|
||||||
|
if patcher is not None:
|
||||||
|
for p in patcher.get_additional_models_with_key("multigpu"):
|
||||||
|
device_models[p.load_device] = p.model
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
|
||||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x:
|
if 'control' in x:
|
||||||
x['control'].pre_run(model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
|
for device, device_cnet in x['control'].multigpu_clones.items():
|
||||||
|
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
to_load_options = model_options.get("to_load_options", None)
|
to_load_options = model_options.get("to_load_options", None)
|
||||||
if to_load_options is None:
|
if to_load_options is None:
|
||||||
return
|
return
|
||||||
|
cast_transformer_options(to_load_options, device, dtype)
|
||||||
|
|
||||||
|
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||||
casts = []
|
casts = []
|
||||||
if device is not None:
|
if device is not None:
|
||||||
casts.append(device)
|
casts.append(device)
|
||||||
@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
# if nothing to apply, do nothing
|
# if nothing to apply, do nothing
|
||||||
if len(casts) == 0:
|
if len(casts) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# try to call .to on patches
|
# try to call .to on patches
|
||||||
if "patches" in to_load_options:
|
if "patches" in transformer_options:
|
||||||
patches = to_load_options["patches"]
|
patches = transformer_options["patches"]
|
||||||
for name in patches:
|
for name in patches:
|
||||||
patch_list = patches[name]
|
patch_list = patches[name]
|
||||||
for i in range(len(patch_list)):
|
for i in range(len(patch_list)):
|
||||||
if hasattr(patch_list[i], "to"):
|
if hasattr(patch_list[i], "to"):
|
||||||
for cast in casts:
|
for cast in casts:
|
||||||
patch_list[i] = patch_list[i].to(cast)
|
patch_list[i] = patch_list[i].to(cast)
|
||||||
if "patches_replace" in to_load_options:
|
if "patches_replace" in transformer_options:
|
||||||
patches = to_load_options["patches_replace"]
|
patches = transformer_options["patches_replace"]
|
||||||
for name in patches:
|
for name in patches:
|
||||||
patch_list = patches[name]
|
patch_list = patches[name]
|
||||||
for k in patch_list:
|
for k in patch_list:
|
||||||
@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
# try to call .to on any wrappers/callbacks
|
# try to call .to on any wrappers/callbacks
|
||||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||||
for wc_name in wrappers_and_callbacks:
|
for wc_name in wrappers_and_callbacks:
|
||||||
if wc_name in to_load_options:
|
if wc_name in transformer_options:
|
||||||
wc: dict[str, list] = to_load_options[wc_name]
|
wc: dict[str, list] = transformer_options[wc_name]
|
||||||
for wc_dict in wc.values():
|
for wc_dict in wc.values():
|
||||||
for wc_list in wc_dict.values():
|
for wc_list in wc_dict.values():
|
||||||
for i in range(len(wc_list)):
|
for i in range(len(wc_list)):
|
||||||
@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
for cast in casts:
|
for cast in casts:
|
||||||
wc_list[i] = wc_list[i].to(cast)
|
wc_list[i] = wc_list[i].to(cast)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher: ModelPatcher):
|
def __init__(self, model_patcher: ModelPatcher):
|
||||||
self.model_patcher = model_patcher
|
self.model_patcher = model_patcher
|
||||||
@ -984,16 +1236,32 @@ class CFGGuider:
|
|||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
noise = noise.to(device=device, dtype=torch.float32)
|
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
|
||||||
sigmas = sigmas.to(device)
|
|
||||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
|
||||||
|
|
||||||
try:
|
# Create persistent thread pool for all GPU devices (main + extras)
|
||||||
self.model_patcher.pre_run()
|
if multigpu_patchers:
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||||
finally:
|
all_devices = [device] + extra_devices
|
||||||
self.model_patcher.cleanup()
|
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||||
|
|
||||||
|
with comfy.model_management.cuda_device_context(device):
|
||||||
|
try:
|
||||||
|
noise = noise.to(device=device, dtype=torch.float32)
|
||||||
|
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||||
|
sigmas = sigmas.to(device)
|
||||||
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
|
|
||||||
|
self.model_patcher.pre_run()
|
||||||
|
for multigpu_patcher in multigpu_patchers:
|
||||||
|
multigpu_patcher.pre_run()
|
||||||
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
|
finally:
|
||||||
|
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||||
|
if thread_pool is not None:
|
||||||
|
thread_pool.shutdown()
|
||||||
|
self.model_patcher.cleanup()
|
||||||
|
for multigpu_patcher in multigpu_patchers:
|
||||||
|
multigpu_patcher.cleanup()
|
||||||
|
|
||||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
del self.inner_model
|
del self.inner_model
|
||||||
|
|||||||
381
comfy/sd.py
381
comfy/sd.py
@ -335,41 +335,43 @@ class CLIP:
|
|||||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
device = self.patcher.load_device
|
||||||
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
all_hooks.reset()
|
all_hooks.reset()
|
||||||
self.patcher.patch_hooks(None)
|
self.patcher.patch_hooks(None)
|
||||||
if show_pbar:
|
if show_pbar:
|
||||||
pbar = ProgressBar(len(scheduled_keyframes))
|
pbar = ProgressBar(len(scheduled_keyframes))
|
||||||
|
|
||||||
for scheduled_opts in scheduled_keyframes:
|
with model_management.cuda_device_context(device):
|
||||||
t_range = scheduled_opts[0]
|
for scheduled_opts in scheduled_keyframes:
|
||||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
t_range = scheduled_opts[0]
|
||||||
if "start_percent" in add_dict:
|
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||||
if t_range[1] < add_dict["start_percent"]:
|
if "start_percent" in add_dict:
|
||||||
continue
|
if t_range[1] < add_dict["start_percent"]:
|
||||||
if "end_percent" in add_dict:
|
continue
|
||||||
if t_range[0] > add_dict["end_percent"]:
|
if "end_percent" in add_dict:
|
||||||
continue
|
if t_range[0] > add_dict["end_percent"]:
|
||||||
hooks_keyframes = scheduled_opts[1]
|
continue
|
||||||
for hook, keyframe in hooks_keyframes:
|
hooks_keyframes = scheduled_opts[1]
|
||||||
hook.hook_keyframe._current_keyframe = keyframe
|
for hook, keyframe in hooks_keyframes:
|
||||||
# apply appropriate hooks with values that match new hook_keyframe
|
hook.hook_keyframe._current_keyframe = keyframe
|
||||||
self.patcher.patch_hooks(all_hooks)
|
# apply appropriate hooks with values that match new hook_keyframe
|
||||||
# perform encoding as normal
|
self.patcher.patch_hooks(all_hooks)
|
||||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
# perform encoding as normal
|
||||||
cond, pooled = o[:2]
|
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
pooled_dict = {"pooled_output": pooled}
|
cond, pooled = o[:2]
|
||||||
# add clip_start_percent and clip_end_percent in pooled
|
pooled_dict = {"pooled_output": pooled}
|
||||||
pooled_dict["clip_start_percent"] = t_range[0]
|
# add clip_start_percent and clip_end_percent in pooled
|
||||||
pooled_dict["clip_end_percent"] = t_range[1]
|
pooled_dict["clip_start_percent"] = t_range[0]
|
||||||
# add/update any keys with the provided add_dict
|
pooled_dict["clip_end_percent"] = t_range[1]
|
||||||
pooled_dict.update(add_dict)
|
# add/update any keys with the provided add_dict
|
||||||
# add hooks stored on clip
|
pooled_dict.update(add_dict)
|
||||||
self.add_hooks_to_dict(pooled_dict)
|
# add hooks stored on clip
|
||||||
all_cond_pooled.append([cond, pooled_dict])
|
self.add_hooks_to_dict(pooled_dict)
|
||||||
if show_pbar:
|
all_cond_pooled.append([cond, pooled_dict])
|
||||||
pbar.update(1)
|
if show_pbar:
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
pbar.update(1)
|
||||||
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
all_hooks.reset()
|
all_hooks.reset()
|
||||||
return all_cond_pooled
|
return all_cond_pooled
|
||||||
|
|
||||||
@ -383,8 +385,12 @@ class CLIP:
|
|||||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
device = self.patcher.load_device
|
||||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
|
|
||||||
|
with model_management.cuda_device_context(device):
|
||||||
|
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
|
|
||||||
cond, pooled = o[:2]
|
cond, pooled = o[:2]
|
||||||
if return_dict:
|
if return_dict:
|
||||||
out = {"cond": cond, "pooled_output": pooled}
|
out = {"cond": cond, "pooled_output": pooled}
|
||||||
@ -446,9 +452,12 @@ class CLIP:
|
|||||||
self.cond_stage_model.reset_clip_options()
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
|
device = self.patcher.load_device
|
||||||
self.cond_stage_model.set_clip_options({"layer": None})
|
self.cond_stage_model.set_clip_options({"layer": None})
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
|
||||||
|
with model_management.cuda_device_context(device):
|
||||||
|
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=True):
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
@ -1026,50 +1035,52 @@ class VAE:
|
|||||||
do_tile = False
|
do_tile = False
|
||||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||||
samples_in = samples_in[:, :, 0]
|
samples_in = samples_in[:, :, 0]
|
||||||
try:
|
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
|
||||||
free_memory = self.patcher.get_free_memory(self.device)
|
|
||||||
batch_number = int(free_memory / memory_used)
|
|
||||||
batch_number = max(1, batch_number)
|
|
||||||
|
|
||||||
# Pre-allocate output for VAEs that support direct buffer writes
|
with model_management.cuda_device_context(self.device):
|
||||||
preallocated = False
|
try:
|
||||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
preallocated = True
|
free_memory = self.patcher.get_free_memory(self.device)
|
||||||
|
batch_number = int(free_memory / memory_used)
|
||||||
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
# Pre-allocate output for VAEs that support direct buffer writes
|
||||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
preallocated = False
|
||||||
if preallocated:
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
else:
|
preallocated = True
|
||||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
|
||||||
if pixel_samples is None:
|
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
|
||||||
pixel_samples[x:x+batch_number].copy_(out)
|
|
||||||
del out
|
|
||||||
self.process_output(pixel_samples[x:x+batch_number])
|
|
||||||
except Exception as e:
|
|
||||||
model_management.raise_non_oom(e)
|
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
|
||||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
|
||||||
#exception and the exception itself refs them all until we get out of this except block.
|
|
||||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
|
||||||
#exception is fully off the books.
|
|
||||||
do_tile = True
|
|
||||||
|
|
||||||
if do_tile:
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
comfy.model_management.soft_empty_cache()
|
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||||
dims = samples_in.ndim - 2
|
if preallocated:
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
else:
|
||||||
elif dims == 2:
|
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
if pixel_samples is None:
|
||||||
elif dims == 3:
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
tile = 256 // self.spacial_compression_decode()
|
pixel_samples[x:x+batch_number].copy_(out)
|
||||||
overlap = tile // 4
|
del out
|
||||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
self.process_output(pixel_samples[x:x+batch_number])
|
||||||
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
|
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||||
|
#exception and the exception itself refs them all until we get out of this except block.
|
||||||
|
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||||
|
#exception is fully off the books.
|
||||||
|
do_tile = True
|
||||||
|
|
||||||
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
dims = samples_in.ndim - 2
|
||||||
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
|
elif dims == 2:
|
||||||
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
|
elif dims == 3:
|
||||||
|
tile = 256 // self.spacial_compression_decode()
|
||||||
|
overlap = tile // 4
|
||||||
|
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
@ -1087,20 +1098,21 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
with model_management.cuda_device_context(self.device):
|
||||||
args.pop("tile_y")
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
output = self.decode_tiled_1d(samples, **args)
|
args.pop("tile_y")
|
||||||
elif dims == 2:
|
output = self.decode_tiled_1d(samples, **args)
|
||||||
output = self.decode_tiled_(samples, **args)
|
elif dims == 2:
|
||||||
elif dims == 3:
|
output = self.decode_tiled_(samples, **args)
|
||||||
if overlap_t is None:
|
elif dims == 3:
|
||||||
args["overlap"] = (1, overlap, overlap)
|
if overlap_t is None:
|
||||||
else:
|
args["overlap"] = (1, overlap, overlap)
|
||||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
else:
|
||||||
if tile_t is not None:
|
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||||
args["tile_t"] = max(2, tile_t)
|
if tile_t is not None:
|
||||||
|
args["tile_t"] = max(2, tile_t)
|
||||||
|
|
||||||
output = self.decode_tiled_3d(samples, **args)
|
output = self.decode_tiled_3d(samples, **args)
|
||||||
return output.movedim(1, -1)
|
return output.movedim(1, -1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
@ -1113,44 +1125,46 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
pixel_samples = pixel_samples.unsqueeze(2)
|
pixel_samples = pixel_samples.unsqueeze(2)
|
||||||
try:
|
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
with model_management.cuda_device_context(self.device):
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
try:
|
||||||
free_memory = self.patcher.get_free_memory(self.device)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
batch_number = int(free_memory / max(1, memory_used))
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
batch_number = max(1, batch_number)
|
free_memory = self.patcher.get_free_memory(self.device)
|
||||||
samples = None
|
batch_number = int(free_memory / max(1, memory_used))
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
batch_number = max(1, batch_number)
|
||||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
samples = None
|
||||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||||
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
|
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||||
|
else:
|
||||||
|
pixels_in = pixels_in.to(self.device)
|
||||||
|
out = self.first_stage_model.encode(pixels_in)
|
||||||
|
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||||
|
if samples is None:
|
||||||
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
|
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||||
|
#exception and the exception itself refs them all until we get out of this except block.
|
||||||
|
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||||
|
#exception is fully off the books.
|
||||||
|
do_tile = True
|
||||||
|
|
||||||
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
if self.latent_dim == 3:
|
||||||
|
tile = 256
|
||||||
|
overlap = tile // 4
|
||||||
|
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
|
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||||
|
samples = self.encode_tiled_1d(pixel_samples)
|
||||||
else:
|
else:
|
||||||
pixels_in = pixels_in.to(self.device)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
out = self.first_stage_model.encode(pixels_in)
|
|
||||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
|
||||||
if samples is None:
|
|
||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
|
||||||
samples[x:x + batch_number] = out
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
model_management.raise_non_oom(e)
|
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
|
||||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
|
||||||
#exception and the exception itself refs them all until we get out of this except block.
|
|
||||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
|
||||||
#exception is fully off the books.
|
|
||||||
do_tile = True
|
|
||||||
|
|
||||||
if do_tile:
|
|
||||||
comfy.model_management.soft_empty_cache()
|
|
||||||
if self.latent_dim == 3:
|
|
||||||
tile = 256
|
|
||||||
overlap = tile // 4
|
|
||||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
|
||||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
|
||||||
samples = self.encode_tiled_1d(pixel_samples)
|
|
||||||
else:
|
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -1176,26 +1190,27 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1:
|
with model_management.cuda_device_context(self.device):
|
||||||
args.pop("tile_y")
|
if dims == 1:
|
||||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
args.pop("tile_y")
|
||||||
elif dims == 2:
|
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||||
samples = self.encode_tiled_(pixel_samples, **args)
|
elif dims == 2:
|
||||||
elif dims == 3:
|
samples = self.encode_tiled_(pixel_samples, **args)
|
||||||
if tile_t is not None:
|
elif dims == 3:
|
||||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
if tile_t is not None:
|
||||||
else:
|
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||||
tile_t_latent = 9999
|
else:
|
||||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
tile_t_latent = 9999
|
||||||
|
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||||
|
|
||||||
if overlap_t is None:
|
if overlap_t is None:
|
||||||
args["overlap"] = (1, overlap, overlap)
|
args["overlap"] = (1, overlap, overlap)
|
||||||
else:
|
else:
|
||||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||||
maximum = pixel_samples.shape[2]
|
maximum = pixel_samples.shape[2]
|
||||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||||
|
|
||||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -1710,12 +1725,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if out is None:
|
if out is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
if output_model and out[0] is not None:
|
if out[0] is not None:
|
||||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||||
if output_clip and out[1] is not None:
|
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
|
||||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
|
||||||
|
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
|
||||||
|
# already-loaded module.
|
||||||
|
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
|
||||||
|
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
|
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
|
||||||
|
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
"""Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
|
||||||
|
factory for the CLIP returned by load_checkpoint_guess_config."""
|
||||||
|
_, clip, _, _ = load_checkpoint_guess_config(
|
||||||
|
ckpt_path,
|
||||||
|
output_vae=False,
|
||||||
|
output_clip=True,
|
||||||
|
output_clipvision=False,
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
output_model=False,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic,
|
||||||
|
)
|
||||||
|
return clip.patcher
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
"""Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
|
||||||
|
factory for the VAE returned by load_checkpoint_guess_config."""
|
||||||
|
_, _, vae, _ = load_checkpoint_guess_config(
|
||||||
|
ckpt_path,
|
||||||
|
output_vae=True,
|
||||||
|
output_clip=False,
|
||||||
|
output_clipvision=False,
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
output_model=False,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic,
|
||||||
|
)
|
||||||
|
return vae.patcher
|
||||||
|
|
||||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||||
embedding_directory=embedding_directory,
|
embedding_directory=embedding_directory,
|
||||||
@ -1742,7 +1797,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||||
|
|
||||||
custom_operations = model_options.get("custom_operations", None)
|
custom_operations = model_options.get("custom_operations", None)
|
||||||
if custom_operations is None:
|
if custom_operations is None:
|
||||||
@ -1782,13 +1837,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||||
|
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
vae_device = model_options.get("load_device", None)
|
||||||
|
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
if te_model_options.get("custom_operations", None) is None:
|
if te_model_options.get("custom_operations", None) is None:
|
||||||
@ -1872,7 +1929,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
|||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||||
|
|
||||||
if model_config is not None:
|
if model_config is not None:
|
||||||
@ -1897,7 +1954,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
|||||||
else:
|
else:
|
||||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
if model_config.quant_config is not None:
|
if model_config.quant_config is not None:
|
||||||
weight_dtype = None
|
weight_dtype = None
|
||||||
@ -1939,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
|||||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
|
||||||
|
"""Reload a disk-backed VAE from ``vae_path`` and return its patcher.
|
||||||
|
|
||||||
|
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
|
||||||
|
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
|
||||||
|
fresh, untainted VAE patcher (no inherited per-device load state, no
|
||||||
|
in-place quantization fallout) for multigpu work-units and the
|
||||||
|
SelectVAEDevice node. The optional ``device`` matches the source loader's
|
||||||
|
VAE initialization path; the deepclone's ``load_device`` still controls
|
||||||
|
where the cloned patcher is targeted.
|
||||||
|
"""
|
||||||
|
if metadata is None:
|
||||||
|
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||||
|
else:
|
||||||
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
|
vae = VAE(sd=sd, metadata=metadata, device=device)
|
||||||
|
vae.throw_exception_if_invalid()
|
||||||
|
return vae.patcher
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
def load_unet(unet_path, dtype=None):
|
||||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||||
|
|||||||
@ -86,6 +86,7 @@ def load_safetensors(ckpt):
|
|||||||
import comfy_aimdo.model_mmap
|
import comfy_aimdo.model_mmap
|
||||||
|
|
||||||
f = open(ckpt, "rb", buffering=0)
|
f = open(ckpt, "rb", buffering=0)
|
||||||
|
file_lock = threading.Lock()
|
||||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||||
file_size = os.path.getsize(ckpt)
|
file_size = os.path.getsize(ckpt)
|
||||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||||
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
|
|||||||
storage = tensor.untyped_storage()
|
storage = tensor.untyped_storage()
|
||||||
setattr(storage,
|
setattr(storage,
|
||||||
"_comfy_tensor_file_slice",
|
"_comfy_tensor_file_slice",
|
||||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||||
sd[name] = tensor
|
sd[name] = tensor
|
||||||
|
|
||||||
|
|||||||
412
comfy_extras/nodes_multigpu.py
Normal file
412
comfy_extras/nodes_multigpu.py
Normal file
@ -0,0 +1,412 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
from inspect import cleandoc
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
from comfy.sd import CLIP, VAE
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.multigpu
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Prepares model to have sampling accelerated via splitting work units.
|
||||||
|
|
||||||
|
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||||
|
|
||||||
|
Other than those exceptions, this node can be placed in any order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="MultiGPU_WorkUnits",
|
||||||
|
display_name="MultiGPU CFG Split",
|
||||||
|
category="advanced/multigpu",
|
||||||
|
description=cleandoc(cls.__doc__),
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
||||||
|
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _force_fp32_cpu_compute(patcher: ModelPatcher):
|
||||||
|
"""Force fp32 inference dtype for CPU.
|
||||||
|
|
||||||
|
PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16
|
||||||
|
and run ~500-600x slower than fp32, which makes a normal-sized workflow
|
||||||
|
look frozen for hours. Routing through set_model_compute_dtype leaves the
|
||||||
|
weights as-is and casts at use, so peak memory does not blow up."""
|
||||||
|
dtype = patcher.model_dtype()
|
||||||
|
if dtype in (torch.float16, torch.bfloat16):
|
||||||
|
logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).")
|
||||||
|
patcher.set_model_compute_dtype(torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _remember_base_devices(patcher: ModelPatcher):
|
||||||
|
"""Stash the original load/offload device on the underlying model.
|
||||||
|
|
||||||
|
Stored on patcher.model (which is shared with the input patcher), so
|
||||||
|
later "default" selections can recover the loader's original routing.
|
||||||
|
Only the first Select on a given chain writes these attrs; subsequent
|
||||||
|
deepclones inherit them onto their freshly-loaded model below.
|
||||||
|
"""
|
||||||
|
if not hasattr(patcher.model, "_select_base_load_device"):
|
||||||
|
patcher.model._select_base_load_device = patcher.load_device
|
||||||
|
patcher.model._select_base_offload_device = patcher.offload_device
|
||||||
|
|
||||||
|
|
||||||
|
def _propagate_base_devices(src_model, dst_model):
|
||||||
|
"""Carry the loader-original device attrs onto the freshly-deepcloned model."""
|
||||||
|
if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
|
||||||
|
dst_model._select_base_load_device = src_model._select_base_load_device
|
||||||
|
dst_model._select_base_offload_device = src_model._select_base_offload_device
|
||||||
|
|
||||||
|
|
||||||
|
def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
|
||||||
|
"""Return a patcher whose actual model weights live on *target_load_device*.
|
||||||
|
|
||||||
|
If *patcher* is already on *target_load_device* we just retarget the
|
||||||
|
(already-cloned) patcher's metadata in place. Otherwise we call
|
||||||
|
:meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
|
||||||
|
the loader's ``cached_patcher_init`` factory -- the only safe way to
|
||||||
|
move weights that may already be partially loaded onto another device.
|
||||||
|
|
||||||
|
NOTE: reusing the input patcher's model when the requested device
|
||||||
|
matches its current load_device is a deliberate fast path. Anything
|
||||||
|
that has already mutated the original model (e.g. a prior KSampler
|
||||||
|
invocation on the same model) will be observed here. This is by
|
||||||
|
design and documented on the SelectXDeviceNode docstrings -- placing
|
||||||
|
Select X Device after a node that consumes the same model is not
|
||||||
|
recommended.
|
||||||
|
"""
|
||||||
|
if patcher.load_device == target_load_device:
|
||||||
|
# Fast path: weights already on the desired device, just update offload.
|
||||||
|
patcher.offload_device = target_offload_device
|
||||||
|
return patcher
|
||||||
|
src_model = patcher.model
|
||||||
|
patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
|
||||||
|
patcher.offload_device = target_offload_device
|
||||||
|
_propagate_base_devices(src_model, patcher.model)
|
||||||
|
if hasattr(patcher, "register_load_device"):
|
||||||
|
patcher.register_load_device(patcher.load_device)
|
||||||
|
return patcher
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
|
||||||
|
"""Resolve the requested device and produce a patcher routed there.
|
||||||
|
|
||||||
|
For "default" we restore the loader's original load/offload pair.
|
||||||
|
For CPU we pin both load and offload to CPU (and, on a dynamic
|
||||||
|
patcher, downgrade to a plain ModelPatcher so the dynamic-only
|
||||||
|
code paths are bypassed).
|
||||||
|
For an explicit GPU we keep the loader's original offload but
|
||||||
|
target the requested load device; if that differs from the current
|
||||||
|
load device the patcher is deepcloned onto the new device.
|
||||||
|
"""
|
||||||
|
_remember_base_devices(patcher)
|
||||||
|
base_load = patcher.model._select_base_load_device
|
||||||
|
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
|
||||||
|
|
||||||
|
if resolved is None:
|
||||||
|
# "default" -> route back to the loader's original devices.
|
||||||
|
return _retarget_patcher(patcher, base_load, base_offload)
|
||||||
|
if resolved.type == "cpu":
|
||||||
|
if patcher.is_dynamic():
|
||||||
|
# clone(disable_dynamic=True) requires cached_patcher_init; let the
|
||||||
|
# exception surface to the caller (Select*DeviceNode.execute), which
|
||||||
|
# will translate it into a passthrough+log so unsupported loaders
|
||||||
|
# don't hard-fail the workflow.
|
||||||
|
patcher = patcher.clone(disable_dynamic=True)
|
||||||
|
patcher.load_device = resolved
|
||||||
|
patcher.offload_device = resolved
|
||||||
|
return patcher
|
||||||
|
return _retarget_patcher(patcher, resolved, base_offload)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
|
||||||
|
"""Drop any multigpu clone whose load_device matches *primary_device*.
|
||||||
|
|
||||||
|
Without pruning, MultiGPU CFG Split would have stacked a clone on
|
||||||
|
the same device the primary now occupies (i.e. the workflow places
|
||||||
|
MultiGPU CFG Split before Select Model Device). Keeps the clone set
|
||||||
|
consistent with the new primary placement.
|
||||||
|
"""
|
||||||
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
if not multigpu_models:
|
||||||
|
return
|
||||||
|
filtered = [m for m in multigpu_models if m.load_device != primary_device]
|
||||||
|
if len(filtered) != len(multigpu_models):
|
||||||
|
logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.")
|
||||||
|
model.set_additional_models("multigpu", filtered)
|
||||||
|
if hasattr(model, "match_multigpu_clones"):
|
||||||
|
model.match_multigpu_clones()
|
||||||
|
|
||||||
|
|
||||||
|
class SelectModelDeviceNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Place the diffusion model on a specific device (default / cpu / gpu:N).
|
||||||
|
|
||||||
|
- "default" restores the device assigned by the loader (even after a
|
||||||
|
prior Select Model Device call).
|
||||||
|
- "cpu" pins both the load and offload device to CPU.
|
||||||
|
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
||||||
|
device is restored to the loader's original choice.
|
||||||
|
|
||||||
|
When the requested device differs from the device the input model is
|
||||||
|
already on, a fresh model is spawned via the loader's reload factory
|
||||||
|
(cached_patcher_init) so the new patcher owns independent weights on
|
||||||
|
the new device. Loaders that don't support multigpu (no factory) will
|
||||||
|
cause the node to pass through unchanged with a warning.
|
||||||
|
|
||||||
|
If the workflow already has MultiGPU CFG Split applied and the chosen
|
||||||
|
GPU collides with one of the existing multigpu clones, that clone is
|
||||||
|
dropped so two patchers don't end up bound to the same device.
|
||||||
|
|
||||||
|
When the selected device does not exist on the current machine
|
||||||
|
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||||
|
the node passes the model through unchanged and logs a message
|
||||||
|
instead of failing.
|
||||||
|
|
||||||
|
NOTE: Placing Select Model Device *after* a node that has already
|
||||||
|
consumed the same model (e.g. a KSampler that ran on this model on
|
||||||
|
the original device) is not recommended -- any state the prior
|
||||||
|
consumer mutated on the original model will be observed when the
|
||||||
|
selected device matches the original (fast path). Place Select Model
|
||||||
|
Device before any consumer of the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SelectModelDevice",
|
||||||
|
display_name="Select Model Device",
|
||||||
|
category="advanced/multigpu",
|
||||||
|
description=cleandoc(cls.__doc__),
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, device="default"):
|
||||||
|
# Allow unknown gpu:N values so portable workflows do not error
|
||||||
|
# at validation time; runtime fallback will handle them.
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
|
||||||
|
model = model.clone()
|
||||||
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
|
if resolved is None and device not in (None, "default"):
|
||||||
|
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
try:
|
||||||
|
model = _apply_patcher_device(model, resolved)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
if resolved is not None:
|
||||||
|
if resolved.type == "cpu":
|
||||||
|
_force_fp32_cpu_compute(model)
|
||||||
|
_prune_multigpu_collision(model, model.load_device)
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
|
class SelectCLIPDeviceNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Place the CLIP text encoder on a specific device (default / cpu / gpu:N).
|
||||||
|
|
||||||
|
- "default" restores the device assigned by the loader.
|
||||||
|
- "cpu" pins both the load and offload device to CPU.
|
||||||
|
- "gpu:N" pins the load device to the Nth available GPU.
|
||||||
|
|
||||||
|
When the selected device does not exist on the current machine
|
||||||
|
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||||
|
the node passes the CLIP through unchanged and logs a message
|
||||||
|
instead of failing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SelectCLIPDevice",
|
||||||
|
display_name="Select CLIP Device",
|
||||||
|
category="advanced/multigpu",
|
||||||
|
description=cleandoc(cls.__doc__),
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput:
|
||||||
|
clip = clip.clone()
|
||||||
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
|
if resolved is None and device not in (None, "default"):
|
||||||
|
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
|
||||||
|
return io.NodeOutput(clip)
|
||||||
|
try:
|
||||||
|
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
|
||||||
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
|
||||||
|
class SelectVAEDeviceNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Place the VAE on a specific device (default / gpu:N).
|
||||||
|
|
||||||
|
- "default" restores the device assigned by the loader.
|
||||||
|
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
||||||
|
device is set to the standard VAE offload device.
|
||||||
|
|
||||||
|
CPU is intentionally not exposed in the UI for the VAE; if a workflow
|
||||||
|
supplies "cpu" anyway (e.g. opened from another machine), the request
|
||||||
|
is dropped with a log message and the VAE is passed through unchanged.
|
||||||
|
|
||||||
|
When the selected device does not exist on the current machine
|
||||||
|
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||||
|
the node passes the VAE through unchanged and logs a message
|
||||||
|
instead of failing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SelectVAEDevice",
|
||||||
|
display_name="Select VAE Device",
|
||||||
|
category="advanced/multigpu",
|
||||||
|
description=cleandoc(cls.__doc__),
|
||||||
|
inputs=[
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Vae.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, device="default"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput:
|
||||||
|
# VAE has no .clone(); shallow-copy the wrapper and clone the patcher
|
||||||
|
# so we can retarget load/offload device without affecting the input VAE.
|
||||||
|
vae = copy.copy(vae)
|
||||||
|
vae.patcher = vae.patcher.clone()
|
||||||
|
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||||
|
if resolved is None and device not in (None, "default"):
|
||||||
|
logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.")
|
||||||
|
return io.NodeOutput(vae)
|
||||||
|
if resolved is not None and resolved.type == "cpu":
|
||||||
|
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
|
||||||
|
return io.NodeOutput(vae)
|
||||||
|
if not hasattr(vae, "_select_base_device"):
|
||||||
|
vae._select_base_device = vae.device
|
||||||
|
try:
|
||||||
|
vae.patcher = _apply_patcher_device(
|
||||||
|
vae.patcher, resolved,
|
||||||
|
base_offload_override=comfy.model_management.vae_offload_device(),
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
|
||||||
|
return io.NodeOutput(vae)
|
||||||
|
# Keep VAE wrapper in sync with whatever model the patcher now owns;
|
||||||
|
# deepclone_multigpu may have produced a fresh first_stage_model.
|
||||||
|
vae.first_stage_model = vae.patcher.model
|
||||||
|
vae.device = vae._select_base_device if resolved is None else resolved
|
||||||
|
return io.NodeOutput(vae)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGPUOptionsNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||||
|
|
||||||
|
NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
|
||||||
|
The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
|
||||||
|
model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
|
||||||
|
scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
|
||||||
|
relative_speed when distributing conds across devices; it uses a uniform conds_per_device
|
||||||
|
round-robin via next_available_device(). Before re-enabling this node, wire its
|
||||||
|
relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
|
||||||
|
which already implements the proportional split) so the input actually affects work
|
||||||
|
distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="MultiGPU_Options",
|
||||||
|
display_name="MultiGPU Options",
|
||||||
|
category="advanced/multigpu",
|
||||||
|
description=cleandoc(cls.__doc__),
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("device_index", default=0, min=0, max=64),
|
||||||
|
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
|
||||||
|
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Custom("GPU_OPTIONS").Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
|
||||||
|
if not gpu_options:
|
||||||
|
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||||
|
else:
|
||||||
|
gpu_options = gpu_options.clone()
|
||||||
|
|
||||||
|
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||||
|
gpu_options.add(opt)
|
||||||
|
|
||||||
|
return io.NodeOutput(gpu_options)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGPUExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
MultiGPUCFGSplitNode,
|
||||||
|
SelectModelDeviceNode,
|
||||||
|
SelectCLIPDeviceNode,
|
||||||
|
SelectVAEDeviceNode,
|
||||||
|
# MultiGPUOptionsNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> MultiGPUExtension:
|
||||||
|
return MultiGPUExtension()
|
||||||
2
main.py
2
main.py
@ -218,7 +218,7 @@ import comfy.model_patcher
|
|||||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()):
|
||||||
if args.verbose == 'DEBUG':
|
if args.verbose == 'DEBUG':
|
||||||
comfy_aimdo.control.set_log_debug()
|
comfy_aimdo.control.set_log_debug()
|
||||||
elif args.verbose == 'CRITICAL':
|
elif args.verbose == 'CRITICAL':
|
||||||
|
|||||||
10
nodes.py
10
nodes.py
@ -795,6 +795,7 @@ class VAELoader:
|
|||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
metadata = None
|
metadata = None
|
||||||
|
vae_path = None
|
||||||
if vae_name == "pixel_space":
|
if vae_name == "pixel_space":
|
||||||
sd = {}
|
sd = {}
|
||||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||||
@ -813,6 +814,14 @@ class VAELoader:
|
|||||||
metadata["tae_latent_channels"] = 128
|
metadata["tae_latent_channels"] = 128
|
||||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
|
# Register a reload factory on the patcher so multigpu deepclones
|
||||||
|
# (Select VAE Device, future MultiGPU VAE work-units) can produce
|
||||||
|
# per-device clones from the same loader context. Only set when we
|
||||||
|
# actually have a single backing file -- pixel_space and the
|
||||||
|
# image TAESDs (composed from separate encoder/decoder files via
|
||||||
|
# load_taesd) are not addressable by a single vae_path.
|
||||||
|
if vae_path is not None:
|
||||||
|
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None))
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class ControlNetLoader:
|
class ControlNetLoader:
|
||||||
@ -2389,6 +2398,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_lt_audio.py",
|
"nodes_lt_audio.py",
|
||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
"nodes_hooks.py",
|
"nodes_hooks.py",
|
||||||
|
"nodes_multigpu.py",
|
||||||
"nodes_load_3d.py",
|
"nodes_load_3d.py",
|
||||||
"nodes_cosmos.py",
|
"nodes_cosmos.py",
|
||||||
"nodes_video.py",
|
"nodes_video.py",
|
||||||
|
|||||||
39
server.py
39
server.py
@ -646,18 +646,37 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/system_stats")
|
@routes.get("/system_stats")
|
||||||
async def system_stats(request):
|
async def system_stats(request):
|
||||||
device = comfy.model_management.get_torch_device()
|
primary_device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
|
||||||
cpu_device = comfy.model_management.torch.device("cpu")
|
cpu_device = comfy.model_management.torch.device("cpu")
|
||||||
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||||
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
|
||||||
required_frontend_version = FrontendManager.get_required_frontend_version()
|
required_frontend_version = FrontendManager.get_required_frontend_version()
|
||||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||||
required_templates_version = FrontendManager.get_required_templates_version()
|
required_templates_version = FrontendManager.get_required_templates_version()
|
||||||
comfy_package_versions = FrontendManager.get_comfy_package_versions()
|
comfy_package_versions = FrontendManager.get_comfy_package_versions()
|
||||||
|
|
||||||
|
# Report every torch device visible to multigpu, with the primary
|
||||||
|
# device first so existing clients that read devices[0] keep working.
|
||||||
|
torch_devices = comfy.model_management.get_all_torch_devices()
|
||||||
|
if primary_device in torch_devices:
|
||||||
|
torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device]
|
||||||
|
else:
|
||||||
|
torch_devices = [primary_device] + list(torch_devices)
|
||||||
|
|
||||||
|
device_entries = []
|
||||||
|
for d in torch_devices:
|
||||||
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True)
|
||||||
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True)
|
||||||
|
device_entries.append({
|
||||||
|
"name": comfy.model_management.get_torch_device_name(d),
|
||||||
|
"type": d.type,
|
||||||
|
"index": d.index,
|
||||||
|
"vram_total": vram_total,
|
||||||
|
"vram_free": vram_free,
|
||||||
|
"torch_vram_total": torch_vram_total,
|
||||||
|
"torch_vram_free": torch_vram_free,
|
||||||
|
})
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": sys.platform,
|
"os": sys.platform,
|
||||||
@ -673,17 +692,7 @@ class PromptServer():
|
|||||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||||
"argv": sys.argv
|
"argv": sys.argv
|
||||||
},
|
},
|
||||||
"devices": [
|
"devices": device_entries
|
||||||
{
|
|
||||||
"name": device_name,
|
|
||||||
"type": device.type,
|
|
||||||
"index": device.index,
|
|
||||||
"vram_total": vram_total,
|
|
||||||
"vram_free": vram_free,
|
|
||||||
"torch_vram_total": torch_vram_total,
|
|
||||||
"torch_vram_free": torch_vram_free,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
return web.json_response(system_stats)
|
return web.json_response(system_stats)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user