Fix absolute imports

This commit is contained in:
doctorpangloss 2024-08-29 18:38:58 -07:00
parent d2acf8a93f
commit 3f88282b6a
5 changed files with 20 additions and 31 deletions

View File

@ -1,5 +1,7 @@
import torch
import comfy.ops
from .. import ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@ -9,13 +11,13 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
except:
rms_norm_torch = None
def rms_norm(x, weight, eps=1e-6):
if rms_norm_torch is not None:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
return (x * rrms) * ops.cast_to(weight, dtype=x.dtype, device=x.device)

View File

@ -10,7 +10,7 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
from .. import common_dit
class ControlNetFlux(Flux):
@ -119,13 +119,13 @@ class ControlNetFlux(Flux):
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
else:
hint = hint * 2.0 - 1.0
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

View File

@ -11,9 +11,9 @@ from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (timestep_embedding)
from comfy.ldm.flux.model import Flux
from ...ldm import common_dit
from .layers import timestep_embedding
from .model import Flux
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
@ -285,7 +285,7 @@ class InstantXControlNetFlux(Flux):
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
height_control_image, width_control_image = hint.shape[2:]
num_channels_latents = self.in_channels // 4

View File

@ -1,24 +1,11 @@
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder,
PatchEmbed,
RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from ... import ops
from ...latent_formats import SDXL
from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .poolers import AttentionPool
from ..modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed
class HunYuanControlNet(nn.Module):
@ -93,7 +80,7 @@ class HunYuanControlNet(nn.Module):
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
self.latent_format = comfy.latent_formats.SDXL
self.latent_format = SDXL
self.mlp_t5 = nn.Sequential(
nn.Linear(
@ -261,7 +248,7 @@ class HunYuanControlNet(nn.Module):
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
padding = ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:, -self.text_len :] = torch.where(
text_states_mask[:, -self.text_len :].unsqueeze(2),

View File

@ -357,7 +357,7 @@ class RMSNorm(torch.nn.Module):
self.register_parameter("weight", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
return common_dit.rms_norm(x, self.weight, self.eps)