diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ea219c7e5..5d7699fba 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -30,6 +30,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): class ControlBase: def __init__(self, device=None): + self.scaled_weight = None self.cond_hint_original = None self.cond_hint = None self.strength = 1.0 @@ -41,13 +42,17 @@ class ControlBase: self.device = device self.previous_controlnet = None self.global_average_pooling = False - - def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + self.soft_injection = False + def set_cond_hint(self, cond_hint, strength=1.0, soft_injection=False, timestep_percent_range=(1.0, 0.0)): self.cond_hint_original = cond_hint self.strength = strength self.timestep_percent_range = timestep_percent_range + self.soft_injection = soft_injection return self + def set_cond_scaled_weight(self, scaled_weight): + self.scaled_weight = scaled_weight #Added an object to keep the scaled weights + def pre_run(self, model, percent_to_timestep_function): self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) if self.previous_controlnet is not None: @@ -127,9 +132,10 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): + def __init__(self, control_model, soft_inject=False, global_average_pooling=False, device=None): super().__init__(device) self.control_model = control_model + self.soft_injection = soft_inject self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling @@ -160,6 +166,8 @@ class ControlNet(ControlBase): if y is not None: y = y.to(self.control_model.dtype) control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + if self.soft_injection: + control = [c * scale for c, scale in zip(control, self.scaled_weight)] return self.control_merge(None, control, control_prev, output_dtype) def copy(self):