Fix CodeRabbit findings in worksplit-multigpu (#14017)

Fix CodeRabbit findings in worksplit-multigpu
This commit is contained in:
Jedrzej Kosinski 2026-05-21 11:42:08 -07:00 committed by GitHub
commit 1417b711ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 59 additions and 25 deletions

View File

@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.") parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group() cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@ -357,6 +357,10 @@ class QwenFunControlNet(ControlNet):
super().pre_run(model, percent_to_timestep_function) super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model) self.set_extra_arg("base_model", model.diffusion_model)
def cleanup(self):
self.extra_args.pop("base_model", None)
super().cleanup()
def copy(self): def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model c.control_model = self.control_model

View File

@ -215,17 +215,19 @@ def get_all_torch_devices(exclude_current=False):
if cpu_state == CPUState.GPU: if cpu_state == CPUState.GPU:
if is_nvidia(): if is_nvidia():
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
devices.append(torch.device(i)) devices.append(torch.device("cuda", i))
elif is_intel_xpu(): elif is_intel_xpu():
for i in range(torch.xpu.device_count()): for i in range(torch.xpu.device_count()):
devices.append(torch.device(i)) devices.append(torch.device("xpu", i))
elif is_ascend_npu(): elif is_ascend_npu():
for i in range(torch.npu.device_count()): for i in range(torch.npu.device_count()):
devices.append(torch.device(i)) devices.append(torch.device("npu", i))
else: else:
devices.append(get_torch_device()) devices.append(get_torch_device())
if exclude_current: if exclude_current:
devices.remove(get_torch_device()) current = get_torch_device()
if current in devices:
devices.remove(current)
return devices return devices
def get_gpu_device_options(): def get_gpu_device_options():

View File

@ -1361,13 +1361,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self) callback(self)
def prepare_state(self, timestep, model_options, ignore_multigpu=False): def prepare_state(self, timestep, model_options):
ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep, model_options, ignore_multigpu) callback(self, timestep, model_options)
if not ignore_multigpu and "multigpu_clones" in model_options: if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values(): model_options["ignore_multigpu"] = True
p: ModelPatcher try:
p.prepare_state(timestep, model_options, ignore_multigpu=True) for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options)
finally:
model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self): def restore_hook_patches(self):
if self.hook_patches_backup is not None: if self.hook_patches_backup is not None:

View File

@ -162,16 +162,16 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
gpu_options.register(model) gpu_options.register(model)
else: else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# TODO: only keep model clones that don't go 'past' the intended max_gpu count # only keep model clones that don't go 'past' the intended max_gpu count;
# multigpu_models = model.get_additional_models_with_key("multigpu") # this prunes any inherited multigpu clones whose load_device is no longer allowed
# new_multigpu_models = [] # when max_gpus is lowered between runs.
# for m in multigpu_models: allowed_devices = set(limit_extra_devices)
# if m.load_device in limit_extra_devices: allowed_devices.add(model.load_device)
# new_multigpu_models.append(m) multigpu_models = model.get_additional_models_with_key("multigpu")
# model.set_additional_models("multigpu", new_multigpu_models) new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
# persist skip_devices for use in sampling code if len(new_multigpu_models) != len(multigpu_models):
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: model.set_additional_models("multigpu", new_multigpu_models)
# model.model_options["multigpu_skip_devices"] = skip_devices model.match_multigpu_clones()
return model return model

View File

@ -433,7 +433,11 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
for i in range(1, len(to_batch_temp) + 1): for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i] batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory: cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
to_batch = batch_amount to_batch = batch_amount
break break
@ -866,14 +870,21 @@ def calculate_start_end_timesteps(model, conds):
def pre_run_control(model, conds): def pre_run_control(model, conds):
s = model.model_sampling s = model.model_sampling
# Per-device model lookup so multigpu control clones get the matching
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
device_models: dict = {}
patcher = getattr(model, "current_patcher", None)
if patcher is not None:
for p in patcher.get_additional_models_with_key("multigpu"):
device_models[p.load_device] = p.model
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
percent_to_timestep_function = lambda a: s.percent_to_sigma(a) percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x: if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function) x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values(): for device, device_cnet in x['control'].multigpu_clones.items():
device_cnet.pre_run(model, percent_to_timestep_function) device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []

View File

@ -1688,7 +1688,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
return out return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):

View File

@ -45,6 +45,16 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
class MultiGPUOptionsNode(io.ComfyNode): class MultiGPUOptionsNode(io.ComfyNode):
""" """
Select the relative speed of GPUs in the special case they have significantly different performance from one another. Select the relative speed of GPUs in the special case they have significantly different performance from one another.
NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
relative_speed when distributing conds across devices; it uses a uniform conds_per_device
round-robin via next_available_device(). Before re-enabling this node, wire its
relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
which already implements the proportional split) so the input actually affects work
distribution.
""" """
@classmethod @classmethod
@ -68,7 +78,8 @@ class MultiGPUOptionsNode(io.ComfyNode):
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
if not gpu_options: if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options = comfy.multigpu.GPUOptionsGroup()
gpu_options.clone() else:
gpu_options = gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt) gpu_options.add(opt)