mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-13 23:12:35 +08:00
Update controlnet.py
Make it compatible with the changes in nodes.py. Added parameter soft_injection to the class ControlBase Added control weight modulator
This commit is contained in:
parent
a05005ab9d
commit
6e1a07d301
@ -30,6 +30,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|||||||
|
|
||||||
class ControlBase:
|
class ControlBase:
|
||||||
def __init__(self, device=None):
|
def __init__(self, device=None):
|
||||||
|
self.scaled_weight = None
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
@ -41,13 +42,17 @@ class ControlBase:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.global_average_pooling = False
|
self.global_average_pooling = False
|
||||||
|
self.soft_injection = False
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
def set_cond_hint(self, cond_hint, strength=1.0, soft_injection=False, timestep_percent_range=(1.0, 0.0)):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
|
self.soft_injection = soft_injection
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def set_cond_scaled_weight(self, scaled_weight):
|
||||||
|
self.scaled_weight = scaled_weight #Added an object to keep the scaled weights
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
|
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@ -127,9 +132,10 @@ class ControlBase:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
def __init__(self, control_model, soft_inject=False, global_average_pooling=False, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
|
self.soft_injection = soft_inject
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
|
|
||||||
@ -160,6 +166,8 @@ class ControlNet(ControlBase):
|
|||||||
if y is not None:
|
if y is not None:
|
||||||
y = y.to(self.control_model.dtype)
|
y = y.to(self.control_model.dtype)
|
||||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||||
|
if self.soft_injection:
|
||||||
|
control = [c * scale for c, scale in zip(control, self.scaled_weight)]
|
||||||
return self.control_merge(None, control, control_prev, output_dtype)
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user