diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f1cd2caf5..935e5a32a 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -167,7 +167,7 @@ class ControlNet(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: y = y.to(dtype) diff --git a/comfy/model_base.py b/comfy/model_base.py index ded8fd198..66f816b9c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -151,6 +151,11 @@ class BaseModel(torch.nn.Module): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDCrossAttn(cross_attn) + + cross_attn_cnet = kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = conds.CONDCrossAttn(cross_attn_cnet) + return out def load_model_weights(self, sd, unet_prefix=""):