mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 05:52:33 +08:00
Update controlnet.py
Make it compatible with the changes in nodes.py. Added parameter soft_injection to the class ControlBase Added control weight modulator
This commit is contained in:
parent
a05005ab9d
commit
6e1a07d301
@ -30,6 +30,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
self.scaled_weight = None
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
@ -41,13 +42,17 @@ class ControlBase:
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = False
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||
self.soft_injection = False
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, soft_injection=False, timestep_percent_range=(1.0, 0.0)):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
self.soft_injection = soft_injection
|
||||
return self
|
||||
|
||||
def set_cond_scaled_weight(self, scaled_weight):
|
||||
self.scaled_weight = scaled_weight #Added an object to keep the scaled weights
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
|
||||
if self.previous_controlnet is not None:
|
||||
@ -127,9 +132,10 @@ class ControlBase:
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
def __init__(self, control_model, soft_inject=False, global_average_pooling=False, device=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.soft_injection = soft_inject
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
@ -160,6 +166,8 @@ class ControlNet(ControlBase):
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
if self.soft_injection:
|
||||
control = [c * scale for c, scale in zip(control, self.scaled_weight)]
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user