diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py
index 9016abc44..0ea6e1bfe 100644
--- a/comfy/ldm/common_dit.py
+++ b/comfy/ldm/common_dit.py
@@ -1,5 +1,7 @@
 import torch
-import comfy.ops
+
+from .. import ops
+
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -9,13 +11,13 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
 
 try:
-    rms_norm_torch = torch.nn.functional.rms_norm
+    rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
 except:
     rms_norm_torch = None
 
 def rms_norm(x, weight, eps=1e-6):
     if rms_norm_torch is not None:
-        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+        return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
     else:
         rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+        return (x * rrms) * ops.cast_to(weight, dtype=x.dtype, device=x.device)
diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py
index 2c658a4b1..8d346398f 100644
--- a/comfy/ldm/flux/controlnet.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -10,7 +10,7 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                      timestep_embedding)
 
 from .model import Flux
-import comfy.ldm.common_dit
+from .. import common_dit
 
 
 class ControlNetFlux(Flux):
@@ -119,13 +119,13 @@ class ControlNetFlux(Flux):
     def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
         patch_size = 2
         if self.latent_input:
-            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
             hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         else:
             hint = hint * 2.0 - 1.0
 
         bs, c, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
 
diff --git a/comfy/ldm/flux/controlnet_instantx.py b/comfy/ldm/flux/controlnet_instantx.py
index 5fd1f8026..fffb8656a 100644
--- a/comfy/ldm/flux/controlnet_instantx.py
+++ b/comfy/ldm/flux/controlnet_instantx.py
@@ -11,9 +11,9 @@ from diffusers.utils.import_utils import is_torch_version
 from einops import rearrange, repeat
 from torch import Tensor, nn
 
-import comfy.ldm.common_dit
-from comfy.ldm.flux.layers import (timestep_embedding)
-from comfy.ldm.flux.model import Flux
+from ...ldm import common_dit
+from .layers import timestep_embedding
+from .model import Flux
 
 if is_torch_version(">=", "2.1.0"):
     LayerNorm = nn.LayerNorm
@@ -285,7 +285,7 @@ class InstantXControlNetFlux(Flux):
     def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         height_control_image, width_control_image = hint.shape[2:]
         num_channels_latents = self.in_channels // 4
diff --git a/comfy/ldm/hydit/controlnet.py b/comfy/ldm/hydit/controlnet.py
index cd71fca31..f34f855fd 100644
--- a/comfy/ldm/hydit/controlnet.py
+++ b/comfy/ldm/hydit/controlnet.py
@@ -1,24 +1,11 @@
-from typing import Any, Optional
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils import checkpoint
-
-from comfy.ldm.modules.diffusionmodules.mmdit import (
-    Mlp,
-    TimestepEmbedder,
-    PatchEmbed,
-    RMSNorm,
-)
-from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
-from .poolers import AttentionPool
-
-import comfy.latent_formats
+from ... import ops
+from ...latent_formats import SDXL
 from .models import HunYuanDiTBlock, calc_rope
-
-from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+from .poolers import AttentionPool
+from ..modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed
 
 
 class HunYuanControlNet(nn.Module):
@@ -93,7 +80,7 @@ class HunYuanControlNet(nn.Module):
         self.use_style_cond = use_style_cond
         self.norm = norm
         self.dtype = dtype
-        self.latent_format = comfy.latent_formats.SDXL
+        self.latent_format = SDXL
 
         self.mlp_t5 = nn.Sequential(
             nn.Linear(
@@ -261,7 +248,7 @@ class HunYuanControlNet(nn.Module):
         b_t5, l_t5, c_t5 = text_states_t5.shape
         text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
 
-        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
+        padding = ops.cast_to_input(self.text_embedding_padding, text_states)
 
         text_states[:, -self.text_len :] = torch.where(
             text_states_mask[:, -self.text_len :].unsqueeze(2),
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index dc4ff87a8..d43ba22fa 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -357,7 +357,7 @@ class RMSNorm(torch.nn.Module):
             self.register_parameter("weight", None)
 
     def forward(self, x):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+        return common_dit.rms_norm(x, self.weight, self.eps)