mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00

Fix absolute imports

This commit is contained in:
parent d2acf8a93f
commit 3f88282b6a
@@ -1,5 +1,7 @@
 import torch
-import comfy.ops
+
+from .. import ops
+
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
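Note on the change itself: "import comfy.ops" resolves through the top-level comfy package, while "from .. import ops" resolves against the parent package of the importing module, so the code keeps working if the tree is vendored or imported under a different top-level name. A minimal sketch of the layout this relies on (directory names here are illustrative, not taken from the diff):

    # hypothetical package layout assumed by "from .. import ops"
    #
    # mypkg/
    #     ops.py              # defines cast_to(...), cast_to_input(...)
    #     ldm/
    #         common_dit.py   # this file; "from .. import ops" -> mypkg.ops
    #
    # Relative imports only work when the module is imported as part of the
    # package (e.g. import mypkg.ldm.common_dit), not when run as a script.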
@@ -9,13 +11,13 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
 
 try:
-    rms_norm_torch = torch.nn.functional.rms_norm
+    rms_norm_torch = torch.nn.functional.rms_norm  # pylint: disable=no-member
 except:
     rms_norm_torch = None
 
 def rms_norm(x, weight, eps=1e-6):
     if rms_norm_torch is not None:
-        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+        return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
     else:
         rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+        return (x * rrms) * ops.cast_to(weight, dtype=x.dtype, device=x.device)
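The function touched here prefers the native torch.nn.functional.rms_norm when it exists and otherwise falls back to the manual formula. A standalone sketch checking that the two paths agree (a hedged example, not part of the commit; it assumes PyTorch 2.4+, where F.rms_norm is available):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 16)
    weight = torch.randn(16)
    eps = 1e-6

    # Fallback path, as written in the else branch above.
    rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    manual = (x * rrms) * weight

    # Native path (PyTorch >= 2.4).
    native = F.rms_norm(x, weight.shape, weight=weight, eps=eps)

    assert torch.allclose(manual, native, atol=1e-5)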
@@ -10,7 +10,7 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                      timestep_embedding)
 
 from .model import Flux
-import comfy.ldm.common_dit
+from .. import common_dit
 
 
 class ControlNetFlux(Flux):
@@ -119,13 +119,13 @@ class ControlNetFlux(Flux):
     def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
         patch_size = 2
         if self.latent_input:
-            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
             hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         else:
             hint = hint * 2.0 - 1.0
 
         bs, c, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
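For readers following the tensor shapes: pad_to_patch_size pads H and W up to a multiple of the patch size, and the rearrange call then folds each 2x2 spatial patch into the per-token channel dimension. A small shape-only sketch (sizes invented for illustration, not taken from the model):

    import torch
    from einops import rearrange

    patch_size = 2
    x = torch.randn(1, 16, 31, 45)   # b c h w, with h and w not divisible by 2

    # Pad H and W to the next multiple of patch_size, roughly what pad_to_patch_size does.
    pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
    pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
    x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
    print(x.shape)    # torch.Size([1, 16, 32, 46])

    # Fold each 2x2 patch into a token, as in the rearrange above.
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
    print(img.shape)  # torch.Size([1, 368, 64]): (32/2)*(46/2) tokens, 16*2*2 values each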
@@ -11,9 +11,9 @@ from diffusers.utils.import_utils import is_torch_version
 from einops import rearrange, repeat
 from torch import Tensor, nn
 
-import comfy.ldm.common_dit
-from comfy.ldm.flux.layers import (timestep_embedding)
-from comfy.ldm.flux.model import Flux
+from ...ldm import common_dit
+from .layers import timestep_embedding
+from .model import Flux
 
 if is_torch_version(">=", "2.1.0"):
     LayerNorm = nn.LayerNorm
@@ -285,7 +285,7 @@ class InstantXControlNetFlux(Flux):
     def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         height_control_image, width_control_image = hint.shape[2:]
         num_channels_latents = self.in_channels // 4
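A quick sanity check on the self.in_channels // 4 line for anyone following the shapes: with 2x2 patches each token carries c * 4 values, so dividing the model's input width by 4 recovers the latent channel count. The concrete numbers below are the usual Flux values, stated as an assumption rather than read from this diff:

    patch_size = 2
    latent_channels = 16                                       # assumed Flux latent channels
    in_channels = latent_channels * patch_size * patch_size    # 64 values per patchified token
    num_channels_latents = in_channels // 4                    # back to 16, as in forward() above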
@@ -1,24 +1,11 @@
 from typing import Any, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from torch.utils import checkpoint
 
-from comfy.ldm.modules.diffusionmodules.mmdit import (
-    Mlp,
-    TimestepEmbedder,
-    PatchEmbed,
-    RMSNorm,
-)
-from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
-from .poolers import AttentionPool
-
-import comfy.latent_formats
+from ... import ops
+from ...latent_formats import SDXL
 from .models import HunYuanDiTBlock, calc_rope
 from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+from .poolers import AttentionPool
+from ..modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed
 
 
 class HunYuanControlNet(nn.Module):
@@ -93,7 +80,7 @@ class HunYuanControlNet(nn.Module):
         self.use_style_cond = use_style_cond
         self.norm = norm
         self.dtype = dtype
-        self.latent_format = comfy.latent_formats.SDXL
+        self.latent_format = SDXL
 
         self.mlp_t5 = nn.Sequential(
             nn.Linear(
@@ -261,7 +248,7 @@ class HunYuanControlNet(nn.Module):
         b_t5, l_t5, c_t5 = text_states_t5.shape
         text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
 
-        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
+        padding = ops.cast_to_input(self.text_embedding_padding, text_states)
 
         text_states[:, -self.text_len :] = torch.where(
             text_states_mask[:, -self.text_len :].unsqueeze(2),
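ops.cast_to_input here (and cast_to in the earlier hunk) are ComfyUI helpers that move a weight or buffer onto the dtype and device of another tensor right before it is used. Their real implementation lives in the project's ops module and is not shown in this diff; a minimal stand-in with the same intent would look roughly like this:

    import torch

    def cast_to(weight, dtype=None, device=None):
        # Rough stand-in: convert the tensor to the requested dtype/device.
        return weight.to(dtype=dtype, device=device)

    def cast_to_input(weight, input_tensor):
        # Match the dtype and device of an existing activation tensor.
        return cast_to(weight, dtype=input_tensor.dtype, device=input_tensor.device)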
@@ -357,7 +357,7 @@ class RMSNorm(torch.nn.Module):
             self.register_parameter("weight", None)
 
     def forward(self, x):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+        return common_dit.rms_norm(x, self.weight, self.eps)
 
 