Update controlnet.py

Make it compatible with the changes in nodes.py.
Added parameter soft_injection to the class ControlBase
Added control weight modulator
This commit is contained in:
Marco 2023-09-29 01:44:02 +02:00
parent a05005ab9d
commit 6e1a07d301

View File

@ -30,6 +30,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
class ControlBase: class ControlBase:
def __init__(self, device=None): def __init__(self, device=None):
self.scaled_weight = None
self.cond_hint_original = None self.cond_hint_original = None
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
@ -41,13 +42,17 @@ class ControlBase:
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.global_average_pooling = False self.global_average_pooling = False
self.soft_injection = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): def set_cond_hint(self, cond_hint, strength=1.0, soft_injection=False, timestep_percent_range=(1.0, 0.0)):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
self.strength = strength self.strength = strength
self.timestep_percent_range = timestep_percent_range self.timestep_percent_range = timestep_percent_range
self.soft_injection = soft_injection
return self return self
def set_cond_scaled_weight(self, scaled_weight):
self.scaled_weight = scaled_weight #Added an object to keep the scaled weights
def pre_run(self, model, percent_to_timestep_function): def pre_run(self, model, percent_to_timestep_function):
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -127,9 +132,10 @@ class ControlBase:
return out return out
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, soft_inject=False, global_average_pooling=False, device=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.soft_injection = soft_inject
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
@ -160,6 +166,8 @@ class ControlNet(ControlBase):
if y is not None: if y is not None:
y = y.to(self.control_model.dtype) y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
if self.soft_injection:
control = [c * scale for c, scale in zip(control, self.scaled_weight)]
return self.control_merge(None, control, control_prev, output_dtype) return self.control_merge(None, control, control_prev, output_dtype)
def copy(self): def copy(self):