mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 22:30:19 +08:00
Fix absolute imports
This commit is contained in:
parent
d2acf8a93f
commit
3f88282b6a
@ -1,5 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.ops
|
|
||||||
|
from .. import ops
|
||||||
|
|
||||||
|
|
||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
@ -9,13 +11,13 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
|
||||||
except:
|
except:
|
||||||
rms_norm_torch = None
|
rms_norm_torch = None
|
||||||
|
|
||||||
def rms_norm(x, weight, eps=1e-6):
|
def rms_norm(x, weight, eps=1e-6):
|
||||||
if rms_norm_torch is not None:
|
if rms_norm_torch is not None:
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
else:
|
else:
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
return (x * rrms) * ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|||||||
timestep_embedding)
|
timestep_embedding)
|
||||||
|
|
||||||
from .model import Flux
|
from .model import Flux
|
||||||
import comfy.ldm.common_dit
|
from .. import common_dit
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
@ -119,13 +119,13 @@ class ControlNetFlux(Flux):
|
|||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
else:
|
else:
|
||||||
hint = hint * 2.0 - 1.0
|
hint = hint * 2.0 - 1.0
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
|||||||
@ -11,9 +11,9 @@ from diffusers.utils.import_utils import is_torch_version
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
import comfy.ldm.common_dit
|
from ...ldm import common_dit
|
||||||
from comfy.ldm.flux.layers import (timestep_embedding)
|
from .layers import timestep_embedding
|
||||||
from comfy.ldm.flux.model import Flux
|
from .model import Flux
|
||||||
|
|
||||||
if is_torch_version(">=", "2.1.0"):
|
if is_torch_version(">=", "2.1.0"):
|
||||||
LayerNorm = nn.LayerNorm
|
LayerNorm = nn.LayerNorm
|
||||||
@ -285,7 +285,7 @@ class InstantXControlNetFlux(Flux):
|
|||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
|
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
height_control_image, width_control_image = hint.shape[2:]
|
height_control_image, width_control_image = hint.shape[2:]
|
||||||
num_channels_latents = self.in_channels // 4
|
num_channels_latents = self.in_channels // 4
|
||||||
|
|||||||
@ -1,24 +1,11 @@
|
|||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from torch.utils import checkpoint
|
from ... import ops
|
||||||
|
from ...latent_formats import SDXL
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
|
||||||
Mlp,
|
|
||||||
TimestepEmbedder,
|
|
||||||
PatchEmbed,
|
|
||||||
RMSNorm,
|
|
||||||
)
|
|
||||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
|
||||||
from .poolers import AttentionPool
|
|
||||||
|
|
||||||
import comfy.latent_formats
|
|
||||||
from .models import HunYuanDiTBlock, calc_rope
|
from .models import HunYuanDiTBlock, calc_rope
|
||||||
|
from .poolers import AttentionPool
|
||||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
from ..modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed
|
||||||
|
|
||||||
|
|
||||||
class HunYuanControlNet(nn.Module):
|
class HunYuanControlNet(nn.Module):
|
||||||
@ -93,7 +80,7 @@ class HunYuanControlNet(nn.Module):
|
|||||||
self.use_style_cond = use_style_cond
|
self.use_style_cond = use_style_cond
|
||||||
self.norm = norm
|
self.norm = norm
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.latent_format = comfy.latent_formats.SDXL
|
self.latent_format = SDXL
|
||||||
|
|
||||||
self.mlp_t5 = nn.Sequential(
|
self.mlp_t5 = nn.Sequential(
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
@ -261,7 +248,7 @@ class HunYuanControlNet(nn.Module):
|
|||||||
b_t5, l_t5, c_t5 = text_states_t5.shape
|
b_t5, l_t5, c_t5 = text_states_t5.shape
|
||||||
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
||||||
|
|
||||||
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
|
padding = ops.cast_to_input(self.text_embedding_padding, text_states)
|
||||||
|
|
||||||
text_states[:, -self.text_len :] = torch.where(
|
text_states[:, -self.text_len :] = torch.where(
|
||||||
text_states_mask[:, -self.text_len :].unsqueeze(2),
|
text_states_mask[:, -self.text_len :].unsqueeze(2),
|
||||||
|
|||||||
@ -357,7 +357,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
return common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user