Update controlnet.py

Make it compatible with the changes in nodes.py.
Added parameter soft_injection to the class ControlBase
Added control weight modulator
This commit is contained in:
Marco 2023-09-29 01:44:02 +02:00
parent a05005ab9d
commit 6e1a07d301

View File

@ -30,6 +30,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
class ControlBase:
def __init__(self, device=None):
self.scaled_weight = None
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
@ -41,13 +42,17 @@ class ControlBase:
self.device = device
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
self.soft_injection = False
def set_cond_hint(self, cond_hint, strength=1.0, soft_injection=False, timestep_percent_range=(1.0, 0.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
self.soft_injection = soft_injection
return self
def set_cond_scaled_weight(self, scaled_weight):
self.scaled_weight = scaled_weight #Added an object to keep the scaled weights
def pre_run(self, model, percent_to_timestep_function):
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
if self.previous_controlnet is not None:
@ -127,9 +132,10 @@ class ControlBase:
return out
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
def __init__(self, control_model, soft_inject=False, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.soft_injection = soft_inject
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
@ -160,6 +166,8 @@ class ControlNet(ControlBase):
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
if self.soft_injection:
control = [c * scale for c, scale in zip(control, self.scaled_weight)]
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):