mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 22:30:50 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
76a80a65ea
@ -154,6 +154,15 @@ class FrontendManager:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
|
||||
if version.startswith("v"):
|
||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
||||
return expected_path
|
||||
|
||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
||||
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
|
||||
@ -2,8 +2,8 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Literal, Tuple, Union, Dict
|
||||
|
||||
from comfy.cli_args_types import Configuration
|
||||
from comfy.component_model.folder_path_types import FolderNames, SaveImagePathTuple
|
||||
from ..cli_args_types import Configuration
|
||||
from ..component_model.folder_path_types import FolderNames, SaveImagePathTuple
|
||||
|
||||
# Variables
|
||||
base_path: str
|
||||
|
||||
@ -60,7 +60,7 @@ class HeuristicPath(NamedTuple):
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ class StrengthType(Enum):
|
||||
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
@ -73,10 +73,6 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
@ -186,8 +182,8 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, ckpt_name: str = None):
|
||||
super().__init__(device)
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, ckpt_name: str = None):
|
||||
super().__init__()
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
@ -244,7 +240,7 @@ class ControlNet(ControlBase):
|
||||
to_concat.append(utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
@ -343,8 +339,8 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): # TODO? model_options
|
||||
ControlBase.__init__(self, device)
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): # TODO? model_options
|
||||
ControlBase.__init__(self)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.extra_conds += ["y"]
|
||||
@ -744,12 +740,15 @@ def load_controlnet(ckpt_path, model=None, model_options=None):
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||
super().__init__(device)
|
||||
super().__init__()
|
||||
self.t2i_model = t2i_model
|
||||
self.channels_in = channels_in
|
||||
self.control_input = None
|
||||
self.compression_ratio = compression_ratio
|
||||
self.upscale_algorithm = upscale_algorithm
|
||||
if device is None:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
|
||||
def scale_image_to(self, width, height):
|
||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||
|
||||
@ -41,6 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
|
||||
inf = torch.finfo(dtype)
|
||||
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
@ -165,6 +165,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST):
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@ -182,6 +184,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
@ -1119,7 +1144,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
@ -1148,7 +1172,6 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
@ -1182,7 +1205,6 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
@ -1230,4 +1252,4 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||
old_uncond_denoised = uncond_denoised
|
||||
return x
|
||||
return x
|
||||
|
||||
@ -175,3 +175,30 @@ class Flux(SD3):
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
|
||||
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
|
||||
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
|
||||
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
|
||||
0.959253732819592, 0.8244560132752793, 0.917259975397747,
|
||||
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
|
||||
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
self.latent_rgb_factors = None #TODO
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
@ -15,9 +15,15 @@ try:
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight, eps=1e-6):
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
return (x * rrms) * ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
541
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
541
comfy/ldm/genmo/joint_model/asymm_models_joint.py
Normal file
@ -0,0 +1,541 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
# from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
from .rope_mixed import (
|
||||
compute_mixed_rotation,
|
||||
create_position_matrix,
|
||||
)
|
||||
from .temporal_rope import apply_rotary_emb_qk_real
|
||||
from .utils import (
|
||||
AttentionPool,
|
||||
modulate,
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.ops
|
||||
|
||||
|
||||
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||
# Normalize and modulate
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
|
||||
x_modulated = x_normed * (1 + scale.unsqueeze(1))
|
||||
|
||||
return x_modulated
|
||||
|
||||
|
||||
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
|
||||
# Apply tanh to gate
|
||||
tanh_gate = torch.tanh(gate).unsqueeze(1)
|
||||
|
||||
# Normalize and apply gated scaling
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
|
||||
|
||||
# Apply residual connection
|
||||
output = x + x_normed
|
||||
|
||||
return output
|
||||
|
||||
class AsymmetricAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_x: int,
|
||||
dim_y: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
update_y: bool = True,
|
||||
out_bias: bool = True,
|
||||
attend_to_padding: bool = False,
|
||||
softmax_scale: Optional[float] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_x = dim_x
|
||||
self.dim_y = dim_y
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_x // num_heads
|
||||
self.attn_drop = attn_drop
|
||||
self.update_y = update_y
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.softmax_scale = softmax_scale
|
||||
if dim_x % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
|
||||
)
|
||||
|
||||
# Input layers.
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||
# Project text features to match visual features (dim_y -> dim_x)
|
||||
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
# Query and key normalization for stability.
|
||||
assert qk_norm
|
||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
|
||||
# Output layers. y features go back down from dim_x -> dim_y.
|
||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||
self.proj_y = (
|
||||
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
|
||||
if update_y
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor, # (B, N, dim_x)
|
||||
y: torch.Tensor, # (B, L, dim_y)
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
rope_sin = rope_rotation.get("rope_sin")
|
||||
# Pre-norm for visual features
|
||||
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
|
||||
|
||||
# Process visual features
|
||||
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
|
||||
# assert qkv_x.dtype == torch.bfloat16
|
||||
# qkv_x = all_to_all_collect_tokens(
|
||||
# qkv_x, self.num_heads
|
||||
# ) # (3, B, N, local_h, head_dim)
|
||||
|
||||
# Process text features
|
||||
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
|
||||
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||
|
||||
q_y = self.q_norm_y(q_y)
|
||||
k_y = self.k_norm_y(k_y)
|
||||
|
||||
# Split qkv_x into q, k, v
|
||||
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
|
||||
q_x = self.q_norm_x(q_x)
|
||||
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
|
||||
k_x = self.k_norm_x(k_x)
|
||||
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
|
||||
|
||||
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
|
||||
o[:, :y.shape[1]] = y
|
||||
|
||||
y = self.proj_y(o)
|
||||
# print("ox", x)
|
||||
# print("oy", y)
|
||||
return x, y
|
||||
|
||||
|
||||
class AsymmetricJointBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size_x: int,
|
||||
hidden_size_y: int,
|
||||
num_heads: int,
|
||||
*,
|
||||
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
|
||||
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
|
||||
update_y: bool = True, # Whether to update text tokens in this block.
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.update_y = update_y
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
|
||||
if self.update_y:
|
||||
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
|
||||
else:
|
||||
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
|
||||
|
||||
# Self-attention:
|
||||
self.attn = AsymmetricAttention(
|
||||
hidden_size_x,
|
||||
hidden_size_y,
|
||||
num_heads=num_heads,
|
||||
update_y=update_y,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
# MLP.
|
||||
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
|
||||
assert mlp_hidden_dim_x == int(1536 * 8)
|
||||
self.mlp_x = FeedForward(
|
||||
in_features=hidden_size_x,
|
||||
hidden_size=mlp_hidden_dim_x,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# MLP for text not needed in last block.
|
||||
if self.update_y:
|
||||
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
|
||||
self.mlp_y = FeedForward(
|
||||
in_features=hidden_size_y,
|
||||
hidden_size=mlp_hidden_dim_y,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
|
||||
Args:
|
||||
x: (B, N, dim) tensor of visual tokens
|
||||
c: (B, dim) tensor of conditioned features
|
||||
y: (B, L, dim) tensor of text tokens
|
||||
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
|
||||
|
||||
Returns:
|
||||
x: (B, N, dim) tensor of visual tokens after block
|
||||
y: (B, L, dim) tensor of text tokens after block
|
||||
"""
|
||||
N = x.size(1)
|
||||
|
||||
c = F.silu(c)
|
||||
mod_x = self.mod_x(c)
|
||||
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
|
||||
|
||||
mod_y = self.mod_y(c)
|
||||
if self.update_y:
|
||||
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
|
||||
else:
|
||||
scale_msa_y = mod_y
|
||||
|
||||
# Self-attention block.
|
||||
x_attn, y_attn = self.attn(
|
||||
x,
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
assert x_attn.size(1) == N
|
||||
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
|
||||
if self.update_y:
|
||||
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
|
||||
|
||||
# MLP block.
|
||||
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
|
||||
if self.update_y:
|
||||
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
|
||||
|
||||
return x, y
|
||||
|
||||
def ff_block_x(self, x, scale_x, gate_x):
|
||||
x_mod = modulated_rmsnorm(x, scale_x)
|
||||
x_res = self.mlp_x(x_mod)
|
||||
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||
return x
|
||||
|
||||
def ff_block_y(self, y, scale_y, gate_y):
|
||||
y_mod = modulated_rmsnorm(y, scale_y)
|
||||
y_res = self.mlp_y(y_mod)
|
||||
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||
return y
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
patch_size,
|
||||
out_channels,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
|
||||
)
|
||||
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
c = F.silu(c)
|
||||
shift, scale = self.mod(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class AsymmDiTJoint(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
|
||||
Ingests text embeddings instead of a label.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size_x=1152,
|
||||
hidden_size_y=1152,
|
||||
depth=48,
|
||||
num_heads=16,
|
||||
mlp_ratio_x=8.0,
|
||||
mlp_ratio_y=4.0,
|
||||
use_t5: bool = False,
|
||||
t5_feat_dim: int = 4096,
|
||||
t5_token_length: int = 256,
|
||||
learn_sigma=True,
|
||||
patch_embed_bias: bool = True,
|
||||
timestep_mlp_bias: bool = True,
|
||||
attend_to_padding: bool = False,
|
||||
timestep_scale: Optional[float] = None,
|
||||
use_extended_posenc: bool = False,
|
||||
posenc_preserve_area: bool = False,
|
||||
rope_theta: float = 10000.0,
|
||||
image_model=None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.head_dim = (
|
||||
hidden_size_x // num_heads
|
||||
) # Head dimension and count is determined by visual.
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.use_extended_posenc = use_extended_posenc
|
||||
self.posenc_preserve_area = posenc_preserve_area
|
||||
self.use_t5 = use_t5
|
||||
self.t5_token_length = t5_token_length
|
||||
self.t5_feat_dim = t5_feat_dim
|
||||
self.rope_theta = (
|
||||
rope_theta # Scaling factor for frequency computation for temporal RoPE.
|
||||
)
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size_x,
|
||||
bias=patch_embed_bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
# Conditionings
|
||||
# Timestep
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.use_t5:
|
||||
# Caption Pooling (T5)
|
||||
self.t5_y_embedder = AttentionPool(
|
||||
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# Dense Embedding Projection (T5)
|
||||
self.t5_yproj = operations.Linear(
|
||||
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Initialize pos_frequencies as an empty parameter.
|
||||
self.pos_frequencies = nn.Parameter(
|
||||
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
assert not self.attend_to_padding
|
||||
|
||||
# for depth 48:
|
||||
# b = 0: AsymmetricJointBlock, update_y=True
|
||||
# b = 1: AsymmetricJointBlock, update_y=True
|
||||
# ...
|
||||
# b = 46: AsymmetricJointBlock, update_y=True
|
||||
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
|
||||
blocks = []
|
||||
for b in range(depth):
|
||||
# Joint multi-modal block
|
||||
update_y = b < depth - 1
|
||||
block = AsymmetricJointBlock(
|
||||
hidden_size_x,
|
||||
hidden_size_y,
|
||||
num_heads,
|
||||
mlp_ratio_x=mlp_ratio_x,
|
||||
mlp_ratio_y=mlp_ratio_y,
|
||||
update_y=update_y,
|
||||
attend_to_padding=attend_to_padding,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C=12, T, H, W) tensor of visual tokens
|
||||
|
||||
Returns:
|
||||
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
|
||||
"""
|
||||
return self.x_embedder(x) # Convert BcTHW to BCN
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
t5_feat: torch.Tensor,
|
||||
t5_mask: torch.Tensor,
|
||||
):
|
||||
"""Prepare input and conditioning embeddings."""
|
||||
# Visual patch embeddings with positional encoding.
|
||||
T, H, W = x.shape[-3:]
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
|
||||
assert x.ndim == 3
|
||||
B = x.size(0)
|
||||
|
||||
|
||||
pH, pW = H // self.patch_size, W // self.patch_size
|
||||
N = T * pH * pW
|
||||
assert x.size(1) == N
|
||||
pos = create_position_matrix(
|
||||
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
|
||||
) # (N, 3)
|
||||
rope_cos, rope_sin = compute_mixed_rotation(
|
||||
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
|
||||
) # Each are (N, num_heads, dim // 2)
|
||||
|
||||
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
|
||||
|
||||
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
|
||||
|
||||
c = c_t + t5_y_pool
|
||||
|
||||
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
|
||||
|
||||
return x, c, y_feat, rope_cos, rope_sin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: List[torch.Tensor],
|
||||
attention_mask: List[torch.Tensor],
|
||||
num_tokens=256,
|
||||
packed_indices: Dict[str, torch.Tensor] = None,
|
||||
rope_cos: torch.Tensor = None,
|
||||
rope_sin: torch.Tensor = None,
|
||||
control=None, **kwargs
|
||||
):
|
||||
y_feat = context
|
||||
y_mask = attention_mask
|
||||
sigma = timestep
|
||||
"""Forward pass of DiT.
|
||||
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
sigma: (B,) tensor of noise standard deviations
|
||||
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
|
||||
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
|
||||
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
x, c, y_feat, rope_cos, rope_sin = self.prepare(
|
||||
x, sigma, y_feat, y_mask
|
||||
)
|
||||
del y_mask
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
x, y_feat = block(
|
||||
x,
|
||||
c,
|
||||
y_feat,
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
|
||||
T=T,
|
||||
hp=H // self.patch_size,
|
||||
wp=W // self.patch_size,
|
||||
p1=self.patch_size,
|
||||
p2=self.patch_size,
|
||||
c=self.out_channels,
|
||||
)
|
||||
|
||||
return -x
|
||||
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
164
comfy/ldm/genmo/joint_model/layers.py
Normal file
@ -0,0 +1,164 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from itertools import repeat
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
frequency_embedding_size: int = 256,
|
||||
*,
|
||||
bias: bool = True,
|
||||
timestep_scale: Optional[float] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.timestep_scale = timestep_scale
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
freqs.mul_(-math.log(max_period) / half).exp_()
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, out_dtype):
|
||||
if self.timestep_scale is not None:
|
||||
t = t * self.timestep_scale
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_size: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
# keep parameter count and computation constant compared to standard FFN
|
||||
hidden_size = int(2 * hidden_size / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_size = int(ffn_dim_multiplier * hidden_size)
|
||||
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.hidden_dim = hidden_size
|
||||
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.w1(x).chunk(2, dim=-1)
|
||||
x = self.w2(F.silu(x) * gate)
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
self.flatten = flatten
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = operations.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert norm_layer is None
|
||||
self.norm = (
|
||||
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, _C, T, H, W = x.shape
|
||||
if not self.dynamic_img_pad:
|
||||
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
else:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
|
||||
x = self.proj(x)
|
||||
|
||||
# Flatten temporal and spatial dimensions.
|
||||
if not self.flatten:
|
||||
raise NotImplementedError("Must flatten output.")
|
||||
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
88
comfy/ldm/genmo/joint_model/rope_mixed.py
Normal file
@ -0,0 +1,88 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# import functools
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def centers(start: float, stop, num, dtype=None, device=None):
|
||||
"""linspace through bin centers.
|
||||
|
||||
Args:
|
||||
start (float): Start of the range.
|
||||
stop (float): End of the range.
|
||||
num (int): Number of points.
|
||||
dtype (torch.dtype): Data type of the points.
|
||||
device (torch.device): Device of the points.
|
||||
|
||||
Returns:
|
||||
centers (Tensor): Centers of the bins. Shape: (num,).
|
||||
"""
|
||||
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
||||
return (edges[:-1] + edges[1:]) / 2
|
||||
|
||||
|
||||
# @functools.lru_cache(maxsize=1)
|
||||
def create_position_matrix(
|
||||
T: int,
|
||||
pH: int,
|
||||
pW: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
target_area: float = 36864,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
T: int - Temporal dimension
|
||||
pH: int - Height dimension after patchify
|
||||
pW: int - Width dimension after patchify
|
||||
|
||||
Returns:
|
||||
pos: [T * pH * pW, 3] - position matrix
|
||||
"""
|
||||
# Create 1D tensors for each dimension
|
||||
t = torch.arange(T, dtype=dtype)
|
||||
|
||||
# Positionally interpolate to area 36864.
|
||||
# (3072x3072 frame with 16x16 patches = 192x192 latents).
|
||||
# This automatically scales rope positions when the resolution changes.
|
||||
# We use a large target area so the model is more sensitive
|
||||
# to changes in the learned pos_frequencies matrix.
|
||||
scale = math.sqrt(target_area / (pW * pH))
|
||||
w = centers(-pW * scale / 2, pW * scale / 2, pW)
|
||||
h = centers(-pH * scale / 2, pH * scale / 2, pH)
|
||||
|
||||
# Use meshgrid to create 3D grids
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
||||
|
||||
# Stack and reshape the grids.
|
||||
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
|
||||
pos = pos.view(-1, 3) # [T * pH * pW, 3]
|
||||
pos = pos.to(dtype=dtype, device=device)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def compute_mixed_rotation(
|
||||
freqs: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
|
||||
|
||||
Args:
|
||||
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
|
||||
pos: [N, 3] - position of each token
|
||||
num_heads: int
|
||||
|
||||
Returns:
|
||||
freqs_cos: [N, num_heads, num_freqs] - cosine components
|
||||
freqs_sin: [N, num_heads, num_freqs] - sine components
|
||||
"""
|
||||
assert freqs.ndim == 3
|
||||
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
|
||||
freqs_cos = torch.cos(freqs_sum)
|
||||
freqs_sin = torch.sin(freqs_sum)
|
||||
return freqs_cos, freqs_sin
|
||||
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
34
comfy/ldm/genmo/joint_model/temporal_rope.py
Normal file
@ -0,0 +1,34 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
|
||||
# Based on Llama3 Implementation.
|
||||
import torch
|
||||
|
||||
|
||||
def apply_rotary_emb_qk_real(
|
||||
xqk: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
|
||||
|
||||
Args:
|
||||
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
|
||||
Can be either just query or just key, or both stacked along some batch or * dim.
|
||||
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
|
||||
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor with rotary embeddings applied.
|
||||
"""
|
||||
# Split the last dimension into even and odd parts
|
||||
xqk_even = xqk[..., 0::2]
|
||||
xqk_odd = xqk[..., 1::2]
|
||||
|
||||
# Apply rotation
|
||||
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
|
||||
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
|
||||
|
||||
# Interleave the results back into the original shape
|
||||
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
|
||||
return out
|
||||
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
102
comfy/ldm/genmo/joint_model/utils.py
Normal file
@ -0,0 +1,102 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
|
||||
"""
|
||||
Pool tokens in x using mask.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Args:
|
||||
x: (B, L, D) tensor of tokens.
|
||||
mask: (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
Returns:
|
||||
pooled: (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
|
||||
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
|
||||
mask = mask[:, :, None].to(dtype=x.dtype)
|
||||
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
return pooled
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
output_dim: int = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
spatial_dim (int): Number of tokens in sequence length.
|
||||
embed_dim (int): Dimensionality of input tokens.
|
||||
num_heads (int): Number of attention heads.
|
||||
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
|
||||
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (B, L, D) tensor of input tokens.
|
||||
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
D = x.size(2)
|
||||
|
||||
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
|
||||
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
|
||||
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
|
||||
|
||||
# Average non-padding token features. These will be used as the query.
|
||||
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
|
||||
|
||||
# Concat pooled features to input sequence.
|
||||
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
|
||||
|
||||
# Compute queries, keys, values. Only the mean token is used to create a query.
|
||||
kv = self.to_kv(x) # (B, L+1, 2 * D)
|
||||
q = self.to_q(x[:, 0]) # (B, D)
|
||||
|
||||
# Extract heads.
|
||||
head_dim = D // self.num_heads
|
||||
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
|
||||
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
|
||||
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
|
||||
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
|
||||
q = q.unsqueeze(2) # (B, H, 1, head_dim)
|
||||
|
||||
# Compute attention.
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=0.0
|
||||
) # (B, H, 1, head_dim)
|
||||
|
||||
# Concatenate heads and run output.
|
||||
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||
x = self.to_out(x)
|
||||
return x
|
||||
480
comfy/ldm/genmo/vae/model.py
Normal file
480
comfy/ldm/genmo/vae/model.py
Normal file
@ -0,0 +1,480 @@
|
||||
#original code from https://github.com/genmoai/models under apache 2.0 license
|
||||
#adapted to ComfyUI
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# import mochi_preview.dit.joint_model.context_parallel as cp
|
||||
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
class GroupNormSpatial(ops.GroupNorm):
|
||||
"""
|
||||
GroupNorm applied per-frame.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
|
||||
B, C, T, H, W = x.shape
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
# Run group norm in chunks.
|
||||
output = torch.empty_like(x)
|
||||
for b in range(0, B * T, chunk_size):
|
||||
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
|
||||
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||
|
||||
class PConv3d(ops.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
causal: bool = True,
|
||||
context_parallel: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.causal = causal
|
||||
self.context_parallel = context_parallel
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
stride = cast_tuple(stride, 3)
|
||||
height_pad = (kernel_size[1] - 1) // 2
|
||||
width_pad = (kernel_size[2] - 1) // 2
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(1, 1, 1),
|
||||
padding=(0, height_pad, width_pad),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Compute padding amounts.
|
||||
context_size = self.kernel_size[0] - 1
|
||||
if self.causal:
|
||||
pad_front = context_size
|
||||
pad_back = 0
|
||||
else:
|
||||
pad_front = context_size // 2
|
||||
pad_back = context_size - pad_front
|
||||
|
||||
# Apply padding.
|
||||
assert self.padding_mode == "replicate" # DEBUG
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class Conv1x1(ops.Linear):
|
||||
"""*1x1 Conv implemented with a linear layer."""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
|
||||
super().__init__(in_features, out_features, *args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, *] or [B, *, C].
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
|
||||
"""
|
||||
x = x.movedim(1, -1)
|
||||
x = super().forward(x)
|
||||
x = x.movedim(-1, 1)
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceTime(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
temporal_expansion: int,
|
||||
spatial_expansion: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# When printed, this module should show the temporal and spatial expansion factors.
|
||||
def extra_repr(self):
|
||||
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
|
||||
Returns:
|
||||
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
|
||||
"""
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
|
||||
st=self.temporal_expansion,
|
||||
sh=self.spatial_expansion,
|
||||
sw=self.spatial_expansion,
|
||||
)
|
||||
|
||||
# cp_rank, _ = cp.get_cp_rank_size()
|
||||
if self.temporal_expansion > 1: # and cp_rank == 0:
|
||||
# Drop the first self.temporal_expansion - 1 frames.
|
||||
# This is because we always want the 3x3x3 conv filter to only apply
|
||||
# to the first frame, and the first frame doesn't need to be repeated.
|
||||
assert all(x.shape)
|
||||
x = x[:, :, self.temporal_expansion - 1 :]
|
||||
assert all(x.shape)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def norm_fn(
|
||||
in_channels: int,
|
||||
affine: bool = True,
|
||||
):
|
||||
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block that preserves the spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
*,
|
||||
affine: bool = True,
|
||||
attn_block: Optional[nn.Module] = None,
|
||||
padding_mode: str = "replicate",
|
||||
causal: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
assert causal
|
||||
self.stack = nn.Sequential(
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=True,
|
||||
# causal=causal,
|
||||
),
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=True,
|
||||
# causal=causal,
|
||||
),
|
||||
)
|
||||
|
||||
self.attn_block = attn_block if attn_block else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
residual = x
|
||||
x = self.stack(x)
|
||||
x = x + residual
|
||||
del residual
|
||||
|
||||
return self.attn_block(x)
|
||||
|
||||
|
||||
class CausalUpsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
*,
|
||||
temporal_expansion: int = 2,
|
||||
spatial_expansion: int = 2,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
blocks = []
|
||||
for _ in range(num_res_blocks):
|
||||
blocks.append(block_fn(in_channels, **block_kwargs))
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# Change channels in the final convolution layer.
|
||||
self.proj = Conv1x1(
|
||||
in_channels,
|
||||
out_channels * temporal_expansion * (spatial_expansion**2),
|
||||
)
|
||||
|
||||
self.d2st = DepthToSpaceTime(
|
||||
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
x = self.proj(x)
|
||||
x = self.d2st(x)
|
||||
return x
|
||||
|
||||
|
||||
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
|
||||
assert has_attention is False #NOTE: if this is ever true add back the attention code.
|
||||
|
||||
attn_block = None #AttentionBlock(channels) if has_attention else None
|
||||
|
||||
return ResBlock(
|
||||
channels, affine=True, attn_block=attn_block, **block_kwargs
|
||||
)
|
||||
|
||||
|
||||
class DownsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks,
|
||||
*,
|
||||
temporal_reduction=2,
|
||||
spatial_reduction=2,
|
||||
**block_kwargs,
|
||||
):
|
||||
"""
|
||||
Downsample block for the VAE encoder.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
num_res_blocks: Number of residual blocks.
|
||||
temporal_reduction: Temporal reduction factor.
|
||||
spatial_reduction: Spatial reduction factor.
|
||||
"""
|
||||
super().__init__()
|
||||
layers = []
|
||||
|
||||
# Change the channel count in the strided convolution.
|
||||
# This lets the ResBlock have uniform channel count,
|
||||
# as in ConvNeXt.
|
||||
assert in_channels != out_channels
|
||||
layers.append(
|
||||
PConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
padding_mode="replicate",
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(block_fn(out_channels, **block_kwargs))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
|
||||
num_freqs = (stop - start) // step
|
||||
assert inputs.ndim == 5
|
||||
C = inputs.size(1)
|
||||
|
||||
# Create Base 2 Fourier features.
|
||||
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
|
||||
assert num_freqs == len(freqs)
|
||||
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
|
||||
C = inputs.shape[1]
|
||||
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
|
||||
|
||||
# Interleaved repeat of input channels to match w.
|
||||
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
return torch.cat(
|
||||
[
|
||||
inputs,
|
||||
torch.sin(h),
|
||||
torch.cos(h),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
|
||||
super().__init__()
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Add Fourier features to inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input tensor. Shape: [B, C, T, H, W]
|
||||
|
||||
Returns:
|
||||
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
|
||||
"""
|
||||
return add_fourier_features(inputs, self.start, self.stop, self.step)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int = 3,
|
||||
latent_dim: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
temporal_expansions: Optional[List[int]] = None,
|
||||
spatial_expansions: Optional[List[int]] = None,
|
||||
has_attention: List[bool],
|
||||
output_norm: bool = True,
|
||||
nonlinearity: str = "silu",
|
||||
output_nonlinearity: str = "silu",
|
||||
causal: bool = True,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_channels = latent_dim
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.output_nonlinearity = output_nonlinearity
|
||||
assert nonlinearity == "silu"
|
||||
assert causal
|
||||
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
self.num_up_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == self.num_up_blocks + 2
|
||||
|
||||
blocks = []
|
||||
|
||||
first_block = [
|
||||
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||
] # Input layer.
|
||||
# First set of blocks preserve channel count.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
first_block.append(
|
||||
block_fn(
|
||||
ch[-1],
|
||||
has_attention=has_attention[-1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*first_block))
|
||||
|
||||
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
|
||||
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
|
||||
|
||||
upsample_block_fn = CausalUpsampleBlock
|
||||
|
||||
for i in range(self.num_up_blocks):
|
||||
block = upsample_block_fn(
|
||||
ch[-i - 1],
|
||||
ch[-i - 2],
|
||||
num_res_blocks=num_res_blocks[-i - 2],
|
||||
has_attention=has_attention[-i - 2],
|
||||
temporal_expansion=temporal_expansions[-i - 1],
|
||||
spatial_expansion=spatial_expansions[-i - 1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
assert not output_norm
|
||||
|
||||
# Last block. Preserve channel count.
|
||||
last_block = []
|
||||
for _ in range(num_res_blocks[0]):
|
||||
last_block.append(
|
||||
block_fn(
|
||||
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*last_block))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.output_proj = Conv1x1(ch[0], out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
|
||||
|
||||
Returns:
|
||||
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
|
||||
T + 1 = (t - 1) * 4.
|
||||
H = h * 16, W = w * 16.
|
||||
"""
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.output_nonlinearity == "silu":
|
||||
x = F.silu(x, inplace=not self.training)
|
||||
else:
|
||||
assert (
|
||||
not self.output_nonlinearity
|
||||
) # StyleGAN3 omits the to-RGB nonlinearity.
|
||||
|
||||
return self.output_proj(x).contiguous()
|
||||
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = None #TODO once the model releases
|
||||
self.decoder = Decoder(
|
||||
out_channels=3,
|
||||
base_channels=128,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
temporal_expansions=[1, 2, 3],
|
||||
spatial_expansions=[2, 2, 2],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
has_attention=[False, False, False, False, False],
|
||||
padding_mode="replicate",
|
||||
output_norm=False,
|
||||
nonlinearity="silu",
|
||||
output_nonlinearity="silu",
|
||||
causal=True,
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from ..attention import optimized_attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
from .... import ops
|
||||
@ -98,7 +98,7 @@ class PatchEmbed(nn.Module):
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# B, C, H, W = x.shape
|
||||
# if self.img_size is not None:
|
||||
# if self.strict_img_size:
|
||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
@ -267,8 +267,6 @@ def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
@ -327,9 +325,9 @@ class SelfAttention(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
qkv = self.pre_attention(x)
|
||||
q, k, v = self.pre_attention(x)
|
||||
x = optimized_attention(
|
||||
qkv, num_heads=self.num_heads
|
||||
q, k, v, heads=self.num_heads
|
||||
)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
@ -418,6 +416,7 @@ class DismantledBlock(nn.Module):
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -441,6 +440,24 @@ class DismantledBlock(nn.Module):
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = operations.LayerNorm(
|
||||
@ -467,7 +484,11 @@ class DismantledBlock(nn.Module):
|
||||
multiple_of=256,
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
@ -528,14 +549,64 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return qkv, qkv2, (
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
)
|
||||
|
||||
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
if self.x_block_self_attn:
|
||||
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
|
||||
attn, _ = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
attn2, _ = optimized_attention(
|
||||
qkv2[0], qkv2[1], qkv2[2],
|
||||
num_heads=self.attn2.num_heads,
|
||||
)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
@ -550,7 +621,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
@ -558,8 +632,8 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
qkv = tuple(o)
|
||||
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=x_block.attn.num_heads,
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@ -571,7 +645,14 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
|
||||
else:
|
||||
context = None
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
if x_block.x_block_self_attn:
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
return context, x
|
||||
|
||||
|
||||
@ -586,8 +667,13 @@ class JointBlock(nn.Module):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
@ -643,7 +729,7 @@ class SelfAttentionContext(nn.Module):
|
||||
def forward(self, x):
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.dim_head)
|
||||
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
||||
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
|
||||
return self.proj(x)
|
||||
|
||||
class ContextProcessorBlock(nn.Module):
|
||||
@ -702,14 +788,19 @@ class MMDiT(nn.Module):
|
||||
qk_norm: Optional[str] = None,
|
||||
qkv_bias: bool = True,
|
||||
context_processor_layers = None,
|
||||
x_block_self_attn: bool = False,
|
||||
x_block_self_attn_layers=None,
|
||||
context_size = 4096,
|
||||
num_blocks = None,
|
||||
final_layer = True,
|
||||
skip_blocks = False,
|
||||
dtype = None, #TODO
|
||||
device = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
if x_block_self_attn_layers is None:
|
||||
x_block_self_attn_layers = []
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
@ -719,6 +810,7 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||
|
||||
# hidden_size = default(hidden_size, 64 * depth)
|
||||
# num_heads = default(num_heads, hidden_size // 64)
|
||||
@ -776,26 +868,28 @@ class MMDiT(nn.Module):
|
||||
self.pos_embed = None
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
if not skip_blocks:
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
self.hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=(i == num_blocks - 1) and final_layer,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
@ -858,9 +952,11 @@ class MMDiT(nn.Module):
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options=None,
|
||||
) -> torch.Tensor:
|
||||
if self.compile_core:
|
||||
return self.forward_core_with_concat_compiled(x, c_mod, context)
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
@ -872,14 +968,25 @@ class MMDiT(nn.Module):
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
blocks = len(self.joint_blocks)
|
||||
for i in range(blocks):
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
context, x = self.joint_blocks[i](
|
||||
context,
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
@ -897,6 +1004,7 @@ class MMDiT(nn.Module):
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options=None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
@ -905,6 +1013,8 @@ class MMDiT(nn.Module):
|
||||
y: (N,) tensor of class labels
|
||||
"""
|
||||
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
if self.context_processor is not None:
|
||||
context = self.context_processor(context)
|
||||
|
||||
@ -918,7 +1028,7 @@ class MMDiT(nn.Module):
|
||||
if context is not None:
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context, control)
|
||||
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x[:,:,:hw[-2],:hw[-1]]
|
||||
@ -932,7 +1042,10 @@ class OpenAISignatureMMDITWrapper(MMDiT):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
control = None,
|
||||
transformer_options=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control)
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
|
||||
|
||||
|
||||
@ -315,6 +315,10 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
|
||||
key_map[key_lora] = to
|
||||
|
||||
|
||||
if isinstance(model, model_base.AuraFlow): # Diffusers lora AuraFlow
|
||||
diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
@ -415,7 +419,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype),)
|
||||
v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
patch_type = ""
|
||||
if len(v) == 1:
|
||||
|
||||
@ -34,6 +34,7 @@ from .ldm.aura.mmdit import MMDiT as AuraMMDiT
|
||||
from .ldm.cascade.stage_b import StageB
|
||||
from .ldm.cascade.stage_c import StageC
|
||||
from .ldm.flux import model as flux_model
|
||||
from .ldm.genmo.joint_model.asymm_models_joint import AsymmDiTJoint
|
||||
from .ldm.hydit.models import HunYuanDiT
|
||||
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@ -108,7 +109,8 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
@ -258,6 +260,10 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@ -746,3 +752,18 @@ class Flux(BaseModel):
|
||||
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=AsymmDiTJoint)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
@ -69,6 +69,11 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||
if context_processor in state_dict_keys:
|
||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["x_block_self_attn_layers"] = []
|
||||
for key in state_dict_keys:
|
||||
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
|
||||
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
|
||||
unet_config["x_block_self_attn_layers"].append(int(layer))
|
||||
return unet_config
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
@ -144,6 +149,34 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "mochi_preview"
|
||||
dit_config["depth"] = 48
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["hidden_size_x"] = 3072
|
||||
dit_config["hidden_size_y"] = 1536
|
||||
dit_config["mlp_ratio_x"] = 4.0
|
||||
dit_config["mlp_ratio_y"] = 4.0
|
||||
dit_config["learn_sigma"] = False
|
||||
dit_config["in_channels"] = 12
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["out_bias"] = True
|
||||
dit_config["attn_drop"] = 0.0
|
||||
dit_config["patch_embed_bias"] = True
|
||||
dit_config["posenc_preserve_area"] = True
|
||||
dit_config["timestep_mlp_bias"] = True
|
||||
dit_config["attend_to_padding"] = False
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
dit_config["use_t5"] = True
|
||||
dit_config["t5_feat_dim"] = 4096
|
||||
dit_config["t5_token_length"] = 256
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
return dit_config
|
||||
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@ -285,9 +318,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
model_config = supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
|
||||
if scaled_fp8_weight is not None:
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
|
||||
@ -747,6 +747,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
|
||||
pass
|
||||
|
||||
if fp8_dtype is not None:
|
||||
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||
return fp8_dtype
|
||||
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
@ -955,29 +958,21 @@ def force_channels_last():
|
||||
# TODO
|
||||
return False
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
with model_management_lock:
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
device_supports_cast = True
|
||||
elif tensor.dtype == torch.bfloat16:
|
||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
||||
device_supports_cast = True
|
||||
elif is_intel_xpu():
|
||||
device_supports_cast = True
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
non_blocking = device_should_use_non_blocking(device)
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
if tensor.device == device:
|
||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
|
||||
FLASH_ATTENTION_ENABLED = False
|
||||
|
||||
@ -104,6 +104,31 @@ class LowVramPatch:
|
||||
return lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
convert_func = None
|
||||
op_keys = key.rsplit('.', 1)
|
||||
if len(op_keys) < 2:
|
||||
weight = utils.get_attr(model, key)
|
||||
else:
|
||||
op = utils.get_attr(model, op_keys[0])
|
||||
try:
|
||||
set_func = getattr(op, "set_{}".format(op_keys[1]))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
weight = getattr(op, op_keys[1])
|
||||
if convert_func is not None:
|
||||
weight = utils.get_attr(model, key)
|
||||
|
||||
return weight, set_func, convert_func
|
||||
|
||||
|
||||
class ModelPatcher(ModelManageable):
|
||||
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||
self.size = size
|
||||
@ -321,14 +346,16 @@ class ModelPatcher(ModelManageable):
|
||||
if not k.startswith(filter_prefix):
|
||||
continue
|
||||
bk: torch.nn.Module | None = self.backup.get(k, None)
|
||||
weight, set_func, convert_func = get_key_weight(self.model, k)
|
||||
if bk is not None:
|
||||
weight = bk.weight
|
||||
else:
|
||||
weight = model_sd[k]
|
||||
if convert_func is None:
|
||||
convert_func = lambda a, **kwargs: a
|
||||
|
||||
if k in self.patches:
|
||||
p[k] = [weight] + self.patches[k]
|
||||
p[k] = [(weight, convert_func)] + self.patches[k]
|
||||
else:
|
||||
p[k] = (weight,)
|
||||
p[k] = [(weight, convert_func)]
|
||||
return p
|
||||
|
||||
def model_state_dict(self, filter_prefix=None):
|
||||
@ -344,8 +371,7 @@ class ModelPatcher(ModelManageable):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
weight = utils.get_attr(self.model, key)
|
||||
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
@ -355,12 +381,18 @@ class ModelPatcher(ModelManageable):
|
||||
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
out_weight = lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
out_weight = stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
if set_func is None:
|
||||
out_weight = stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, key, out_weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, key, out_weight)
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
mem_counter = 0
|
||||
|
||||
@ -280,7 +280,10 @@ class VAEDecode:
|
||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||
|
||||
def decode(self, vae, samples):
|
||||
return (vae.decode(samples["samples"]), )
|
||||
images = vae.decode(samples["samples"])
|
||||
if len(images.shape) == 5: #Combine batches
|
||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||
return (images, )
|
||||
|
||||
class VAEDecodeTiled:
|
||||
@classmethod
|
||||
@ -915,7 +918,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -930,6 +933,8 @@ class CLIPLoader:
|
||||
clip_type = sd.CLIPType.SD3
|
||||
elif type == "stable_audio":
|
||||
clip_type = sd.CLIPType.STABLE_AUDIO
|
||||
elif type == "mochi":
|
||||
clip_type = comfy.sd.CLIPType.MOCHI
|
||||
else:
|
||||
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
|
||||
|
||||
@ -1211,10 +1216,10 @@ class LatentUpscale:
|
||||
|
||||
if width == 0:
|
||||
height = max(64, height)
|
||||
width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
|
||||
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
|
||||
elif height == 0:
|
||||
width = max(64, width)
|
||||
height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
|
||||
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
|
||||
else:
|
||||
width = max(64, width)
|
||||
height = max(64, height)
|
||||
@ -1236,8 +1241,8 @@ class LatentUpscaleBy:
|
||||
|
||||
def upscale(self, samples, upscale_method, scale_by):
|
||||
s = samples.copy()
|
||||
width = round(samples["samples"].shape[3] * scale_by)
|
||||
height = round(samples["samples"].shape[2] * scale_by)
|
||||
width = round(samples["samples"].shape[-1] * scale_by)
|
||||
height = round(samples["samples"].shape[-2] * scale_by)
|
||||
s["samples"] = utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
|
||||
return (s,)
|
||||
|
||||
|
||||
107
comfy/ops.py
107
comfy/ops.py
@ -22,22 +22,13 @@ import torch
|
||||
from . import model_management
|
||||
from .cli_args import args
|
||||
from .execution_context import current_execution_context
|
||||
from .float import stochastic_rounding
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
cast_to = model_management.cast_to # TODO: remove once no more references
|
||||
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
@ -53,12 +44,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
non_blocking = model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = s.bias_function is not None
|
||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
bias = s.bias_function(bias)
|
||||
|
||||
has_function = s.weight_function is not None
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
weight = model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
@ -308,19 +299,28 @@ def fp8_linear(self, input):
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
if len(input.shape) == 3:
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||
w = w.t()
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
if scale_input is None:
|
||||
scale_input = scale_weight
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
@ -330,7 +330,11 @@ def fp8_linear(self, input):
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input.shape[0], -1)
|
||||
|
||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -350,17 +354,68 @@ class fp8_ops(manual_cast):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, inference_mode: Optional[bool] = None):
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
|
||||
if weight.numel() < input.numel(): # TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||
weight = stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, inference_mode: Optional[bool] = None):
|
||||
if inference_mode is None:
|
||||
# todo: check a context here, since this isn't being used by any callers yet
|
||||
inference_mode = current_execution_context().inference_mode
|
||||
if model_management.supports_fp8_compute(load_device):
|
||||
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||
return fp8_ops
|
||||
fp8_compute = model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||
|
||||
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||
return fp8_ops
|
||||
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
# disable_weight_init seems to interact poorly with some other optimization code
|
||||
return disable_weight_init if inference_mode else skip_init
|
||||
if args.fast and not disable_fast_fp8:
|
||||
if model_management.supports_fp8_compute(load_device):
|
||||
return fp8_ops
|
||||
|
||||
return manual_cast
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
@ -370,11 +370,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||
|
||||
sigs = []
|
||||
last_t = -1
|
||||
for t in ts:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
if t != last_t:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
last_t = t
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
|
||||
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
|
||||
if steps == 1:
|
||||
sigma_schedule = [1.0, 0.0]
|
||||
else:
|
||||
if linear_steps is None:
|
||||
linear_steps = steps // 2
|
||||
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||
threshold_noise_step_diff = linear_steps - threshold_noise * steps
|
||||
quadratic_steps = steps - linear_steps
|
||||
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
|
||||
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
|
||||
const = quadratic_coef * (linear_steps ** 2)
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i ** 2) + linear_coef * i + const
|
||||
for i in range(linear_steps, steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
||||
|
||||
def get_mask_aabb(masks):
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||
@ -759,6 +783,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "linear_quadratic":
|
||||
sigmas = linear_quadratic_schedule(model_sampling, steps)
|
||||
|
||||
if sigmas is None:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
|
||||
63
comfy/sd.py
63
comfy/sd.py
@ -23,6 +23,7 @@ from . import utils
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.genmo.vae.model import VideoVAE
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .model_management import load_models_gpu
|
||||
from .t2i_adapter import adapter
|
||||
@ -34,6 +35,7 @@ from .text_encoders import long_clipl
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
from .text_encoders import genmo
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
@ -253,6 +255,13 @@ class VAE:
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
|
||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||
sd = utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||
self.first_stage_model = VideoVAE()
|
||||
self.latent_channels = 12
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -308,6 +317,10 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return self.process_output(utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
|
||||
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -326,6 +339,7 @@ class VAE:
|
||||
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
pixel_samples = None
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
@ -333,16 +347,21 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x + batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
if len(samples_in.shape) == 3:
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
else:
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in)
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
|
||||
return pixel_samples
|
||||
@ -412,6 +431,7 @@ class CLIPType(Enum):
|
||||
STABLE_AUDIO = 4
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
MOCHI = 7
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -458,16 +478,14 @@ def detect_te_model(sd):
|
||||
return None
|
||||
|
||||
|
||||
def t5xxl_weight_dtype(clip_data):
|
||||
def t5xxl_detect(clip_data):
|
||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||
|
||||
dtype_t5 = None
|
||||
for sd in clip_data:
|
||||
weight = sd.get(weight_name, None)
|
||||
if weight is not None:
|
||||
dtype_t5 = weight.dtype
|
||||
break
|
||||
return dtype_t5
|
||||
if weight_name in sd:
|
||||
return sd3_clip.t5_xxl_detect(sd)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, textmodel_json_config=None):
|
||||
@ -501,8 +519,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif te_model == TEModel.T5_XXL:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=t5xxl_weight_dtype(clip_data))
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = genmo.MochiT5Tokenizer
|
||||
elif te_model == TEModel.T5_XL:
|
||||
clip_target.clip = aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = aura_t5.AuraT5Tokenizer
|
||||
@ -519,19 +541,19 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, dtype_t5=t5xxl_weight_dtype(clip_data))
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||
clip_target.clip = hydit.HyditModel
|
||||
clip_target.tokenizer = hydit.HyditTokenizer
|
||||
elif clip_type == CLIPType.FLUX:
|
||||
clip_target.clip = flux.flux_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
|
||||
clip_target.clip = flux.flux_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = flux.FluxTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = sd3_clip.sd3_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
|
||||
clip_target.clip = sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
|
||||
parameters = 0
|
||||
@ -621,7 +643,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return None
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
@ -691,6 +713,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
|
||||
sd = temp_sd
|
||||
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
weight_dtype = utils.weight_dtype(sd)
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
|
||||
@ -717,8 +740,12 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None and model_config.scaled_fp8 is None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
if dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
|
||||
@ -112,11 +112,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
|
||||
if operations is None:
|
||||
operations = ops.manual_cast
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
self.max_length = max_length
|
||||
|
||||
@ -10,6 +10,7 @@ from .text_encoders import sa_t5
|
||||
from .text_encoders import aura_t5
|
||||
from .text_encoders import hydit
|
||||
from .text_encoders import flux
|
||||
from .text_encoders import genmo
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -529,12 +530,11 @@ class SD3(supported_models_base.BASE):
|
||||
clip_l = True
|
||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_g = True
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
if "dtype_t5" in t5_detect:
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@ -653,13 +653,8 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
dtype_t5 = None
|
||||
if t5_key in state_dict:
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
else:
|
||||
dtype_t5 = None
|
||||
return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.flux_clip(dtype_t5=dtype_t5))
|
||||
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.flux_clip(**t5_detect))
|
||||
|
||||
class FluxSchnell(Flux):
|
||||
unet_config = {
|
||||
@ -676,7 +671,36 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Mochi
|
||||
|
||||
memory_usage_factor = 2.0 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.GenmoMochi(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -49,6 +49,7 @@ class BASE:
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -3,19 +3,11 @@ import copy
|
||||
import torch
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
from .t5 import T5
|
||||
from .sd3_clip import T5XXLModel
|
||||
from .. import sd1_clip, model_management
|
||||
from ..component_model import files
|
||||
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
@ -83,11 +75,14 @@ class FluxClipModel(torch.nn.Module):
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
|
||||
def flux_clip(dtype_t5=None):
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
return FluxClipModel_
|
||||
|
||||
38
comfy/text_encoders/genmo.py
Normal file
38
comfy/text_encoders/genmo.py
Normal file
@ -0,0 +1,38 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["attention_mask"] = True
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class MochiT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class MochiTEModel_(MochiT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return MochiTEModel_
|
||||
@ -15,9 +15,26 @@ class T5XXLModel(sd1_clip.SDClipModel):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
|
||||
|
||||
|
||||
def t5_xxl_detect(state_dict, prefix=""):
|
||||
out = {}
|
||||
t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
|
||||
if t5_key in state_dict:
|
||||
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
return out
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
@ -154,10 +171,13 @@ class SD3ClipModel(torch.nn.Module):
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False):
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
||||
@ -128,7 +128,7 @@ def weight_dtype(sd, prefix=""):
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
|
||||
|
||||
if len(dtypes) == 0:
|
||||
return None
|
||||
@ -769,9 +769,14 @@ def lanczos(samples, width, height):
|
||||
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
orig_shape = tuple(samples.shape)
|
||||
if len(orig_shape) > 4:
|
||||
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
|
||||
samples = samples.movedim(2, 1)
|
||||
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_width = samples.shape[-1]
|
||||
old_height = samples.shape[-2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
@ -780,16 +785,22 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:, :, y:old_height - y, x:old_width - x]
|
||||
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
|
||||
else:
|
||||
s = samples
|
||||
|
||||
if upscale_method == "bislerp":
|
||||
return bislerp(s, width, height)
|
||||
out = bislerp(s, width, height)
|
||||
elif upscale_method == "lanczos":
|
||||
return lanczos(s, width, height)
|
||||
out = lanczos(s, width, height)
|
||||
else:
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
if len(orig_shape) == 4:
|
||||
return out
|
||||
|
||||
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
|
||||
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
|
||||
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
@ -801,7 +812,27 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
|
||||
dims = len(tile)
|
||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||
|
||||
if not (isinstance(upscale_amount, (tuple, list))):
|
||||
upscale_amount = [upscale_amount] * dims
|
||||
|
||||
if not (isinstance(overlap, (tuple, list))):
|
||||
overlap = [overlap] * dims
|
||||
|
||||
def get_upscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
if callable(up):
|
||||
return up(val)
|
||||
else:
|
||||
return up * val
|
||||
|
||||
def mult_list_upscale(a):
|
||||
out = []
|
||||
for i in range(len(a)):
|
||||
out.append(round(get_upscale(i, a[i])))
|
||||
return out
|
||||
|
||||
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b + 1]
|
||||
@ -812,27 +843,27 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d + 2], tile[d] - overlap) if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
|
||||
positions = [range(0, s.shape[d + 2], tile[d] - overlap[d]) if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
for it in itertools.product(*positions):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
||||
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(pos * upscale_amount))
|
||||
upscaled.append(round(get_upscale(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
|
||||
for t in range(feather):
|
||||
for d in range(2, dims + 2):
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_upscale(d - 2, overlap[d - 2]))
|
||||
for t in range(feather):
|
||||
a = (t + 1) / feather
|
||||
mask.narrow(d, t, 1).mul_(a)
|
||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, bJ as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bK as script$1, A as createBaseVNode, x as createBlock, M as Fragment, N as renderList, ak as toDisplayString, an as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-CFrRuGBA.js";
|
||||
import { s as script, a as script$2, b as script$3 } from "./index-CN90wNx3.js";
|
||||
import "./index-CaUteDIK.js";
|
||||
import { d as defineComponent, bQ as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bR as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a4 as toDisplayString, au as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-BNX_XOqh.js";
|
||||
import { s as script, a as script$2, b as script$3 } from "./index-B_uZlOM8.js";
|
||||
import "./index-nMMCMbCV.js";
|
||||
const _hoisted_1 = { class: "extension-panel" };
|
||||
const _hoisted_2 = { class: "mt-4" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
@ -100,4 +100,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=ExtensionPanel-DSQ2O8Z9.js.map
|
||||
//# sourceMappingURL=ExtensionPanel-BNXC3_Y5.js.map
|
||||
@ -1 +1 @@
|
||||
{"version":3,"file":"ExtensionPanel-DSQ2O8Z9.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
{"version":3,"file":"ExtensionPanel-BNXC3_Y5.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
1321
comfy/web/assets/GraphView-vodtEF4p.js → comfy/web/assets/GraphView-BumSctau.js
generated
vendored
1321
comfy/web/assets/GraphView-vodtEF4p.js → comfy/web/assets/GraphView-BumSctau.js
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/GraphView-BumSctau.js.map
generated
vendored
Normal file
1
comfy/web/assets/GraphView-BumSctau.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
61
comfy/web/assets/GraphView-DCFfls4D.css → comfy/web/assets/GraphView-DI5ePCmV.css
generated
vendored
61
comfy/web/assets/GraphView-DCFfls4D.css → comfy/web/assets/GraphView-DI5ePCmV.css
generated
vendored
@ -1,13 +1,13 @@
|
||||
|
||||
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
|
||||
.group-title-editor.node-title-editor[data-v-8a100d5a] {
|
||||
z-index: 9999;
|
||||
padding: 0.25rem;
|
||||
}
|
||||
[data-v-fc3f26e3] .editable-text {
|
||||
[data-v-8a100d5a] .editable-text {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
[data-v-fc3f26e3] .editable-text input {
|
||||
[data-v-8a100d5a] .editable-text input {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
/* Override the default font size */
|
||||
@ -45,7 +45,7 @@
|
||||
--sidebar-icon-size: 1rem;
|
||||
}
|
||||
|
||||
.side-tool-bar-container[data-v-aa14277f] {
|
||||
.side-tool-bar-container[data-v-37fd2fa4] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
@ -58,40 +58,39 @@
|
||||
background-color: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
}
|
||||
.side-tool-bar-end[data-v-aa14277f] {
|
||||
.side-tool-bar-end[data-v-37fd2fa4] {
|
||||
align-self: flex-end;
|
||||
margin-top: auto;
|
||||
}
|
||||
.sidebar-content-container[data-v-aa14277f] {
|
||||
height: 100%;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.p-splitter-gutter {
|
||||
[data-v-b49f20b1] .p-splitter-gutter {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.gutter-hidden {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
.side-bar-panel[data-v-b9df3042] {
|
||||
.side-bar-panel[data-v-b49f20b1] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.splitter-overlay[data-v-b9df3042] {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
background-color: transparent;
|
||||
.bottom-panel[data-v-b49f20b1] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.splitter-overlay[data-v-b49f20b1] {
|
||||
pointer-events: none;
|
||||
border-style: none;
|
||||
background-color: transparent;
|
||||
}
|
||||
.splitter-overlay-root[data-v-b49f20b1] {
|
||||
position: absolute;
|
||||
top: 0px;
|
||||
left: 0px;
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
|
||||
/* Set it the same as the ComfyUI menu */
|
||||
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
|
||||
999 should be sufficient to make sure splitter overlays on node's DOM
|
||||
widgets */
|
||||
z-index: 999;
|
||||
border: none;
|
||||
}
|
||||
|
||||
[data-v-37f672ab] .highlight {
|
||||
@ -175,6 +174,14 @@
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.comfy-menu-hamburger[data-v-eac4cb71] {
|
||||
pointer-events: auto;
|
||||
position: fixed;
|
||||
top: 0px;
|
||||
right: 0px;
|
||||
z-index: 9999;
|
||||
}
|
||||
|
||||
[data-v-84e785b8] .p-togglebutton::before {
|
||||
display: none
|
||||
}
|
||||
@ -255,7 +262,7 @@
|
||||
display: none;
|
||||
}
|
||||
|
||||
.comfyui-menu[data-v-b13fdc92] {
|
||||
.comfyui-menu[data-v-221bd572] {
|
||||
width: 100vw;
|
||||
background: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
@ -267,13 +274,13 @@
|
||||
grid-column: 1/-1;
|
||||
max-height: 90vh;
|
||||
}
|
||||
.comfyui-menu.dropzone[data-v-b13fdc92] {
|
||||
.comfyui-menu.dropzone[data-v-221bd572] {
|
||||
background: var(--p-highlight-background);
|
||||
}
|
||||
.comfyui-menu.dropzone-active[data-v-b13fdc92] {
|
||||
.comfyui-menu.dropzone-active[data-v-221bd572] {
|
||||
background: var(--p-highlight-background-focus);
|
||||
}
|
||||
.comfyui-logo[data-v-b13fdc92] {
|
||||
.comfyui-logo[data-v-221bd572] {
|
||||
font-size: 1.2em;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
1
comfy/web/assets/GraphView-vodtEF4p.js.map
generated
vendored
1
comfy/web/assets/GraphView-vodtEF4p.js.map
generated
vendored
File diff suppressed because one or more lines are too long
@ -1,8 +1,8 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, M as Fragment, N as renderList, i as createVNode, y as withCtx, an as createTextVNode, ak as toDisplayString, z as unref, ar as script, j as createCommentVNode, r as ref, bG as FilterMatchMode, K as useKeybindingStore, F as useCommandStore, aA as watchEffect, aY as useToast, t as resolveDirective, bH as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, ad as script$4, b1 as withModifiers, ay as script$6, v as withDirectives, bx as KeyComboImpl, bI as KeybindingImpl, _ as _export_sfc } from "./index-CFrRuGBA.js";
|
||||
import { s as script$1, a as script$3, b as script$5 } from "./index-CN90wNx3.js";
|
||||
import "./index-CaUteDIK.js";
|
||||
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, au as createTextVNode, a4 as toDisplayString, z as unref, ay as script, j as createCommentVNode, r as ref, bN as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aH as watchEffect, b8 as useToast, t as resolveDirective, bO as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, am as script$4, bd as withModifiers, aF as script$6, v as withDirectives, bJ as KeyComboImpl, bP as KeybindingImpl, _ as _export_sfc } from "./index-BNX_XOqh.js";
|
||||
import { s as script$1, a as script$3, b as script$5 } from "./index-B_uZlOM8.js";
|
||||
import "./index-nMMCMbCV.js";
|
||||
const _hoisted_1$1 = {
|
||||
key: 0,
|
||||
class: "px-2"
|
||||
@ -260,4 +260,4 @@ const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "d
|
||||
export {
|
||||
KeybindingPanel as default
|
||||
};
|
||||
//# sourceMappingURL=KeybindingPanel-Cwwh2R-I.js.map
|
||||
//# sourceMappingURL=KeybindingPanel-DU4DXDix.js.map
|
||||
File diff suppressed because one or more lines are too long
64629
comfy/web/assets/index-CFrRuGBA.js → comfy/web/assets/index-BNX_XOqh.js
generated
vendored
64629
comfy/web/assets/index-CFrRuGBA.js → comfy/web/assets/index-BNX_XOqh.js
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-BNX_XOqh.js.map
generated
vendored
Normal file
1
comfy/web/assets/index-BNX_XOqh.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
6
comfy/web/assets/index-CN90wNx3.js → comfy/web/assets/index-B_uZlOM8.js
generated
vendored
6
comfy/web/assets/index-CN90wNx3.js → comfy/web/assets/index-B_uZlOM8.js
generated
vendored
@ -1,7 +1,7 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { aQ as script$s, g as openBlock, h as createElementBlock, m as mergeProps, A as createBaseVNode, B as BaseStyle, P as script$t, ak as toDisplayString, aj as Ripple, t as resolveDirective, v as withDirectives, x as createBlock, O as resolveDynamicComponent, bL as script$u, l as resolveComponent, C as normalizeClass, am as createSlots, y as withCtx, be as script$v, b4 as script$w, M as Fragment, N as renderList, an as createTextVNode, aW as setAttribute, aU as normalizeProps, p as renderSlot, j as createCommentVNode, aa as equals, aR as script$x, bM as script$y, bN as getFirstFocusableElement, a3 as OverlayEventBus, W as getVNodeProp, a1 as resolveFieldData, bO as invokeElementMethod, bP as getAttribute, bQ as getNextElementSibling, U as getOuterWidth, bR as getPreviousElementSibling, D as script$z, ag as script$A, bS as script$B, aT as script$D, Z as isNotEmpty, b1 as withModifiers, V as getOuterHeight, $ as UniqueComponentId, bT as _default, a0 as ZIndex, a2 as focus, a5 as addStyle, a7 as absolutePosition, a8 as ConnectedOverlayScrollHandler, a9 as isTouchDevice, bU as FilterOperator, af as script$E, bV as FocusTrap, i as createVNode, al as Transition, bW as withKeys, bX as getIndex, s as script$G, bY as isClickable, bZ as clearSelection, b_ as localeComparator, b$ as sort, c0 as FilterService, bG as FilterMatchMode, ac as findSingle, c1 as findIndexInList, c2 as find, c3 as exportCSV, c4 as getOffset, c5 as getHiddenElementOuterWidth, c6 as getHiddenElementOuterHeight, c7 as reorderArray, c8 as removeClass, c9 as addClass, a4 as isEmpty, ae as script$H, ah as script$I, Y as vShow } from "./index-CFrRuGBA.js";
|
||||
import { s as script$C, a as script$F } from "./index-CaUteDIK.js";
|
||||
import { b0 as script$s, g as openBlock, h as createElementBlock, m as mergeProps, A as createBaseVNode, B as BaseStyle, P as script$t, a4 as toDisplayString, $ as Ripple, t as resolveDirective, v as withDirectives, x as createBlock, J as resolveDynamicComponent, bS as script$u, l as resolveComponent, C as normalizeClass, at as createSlots, y as withCtx, bq as script$v, bg as script$w, N as Fragment, O as renderList, au as createTextVNode, b6 as setAttribute, b4 as normalizeProps, p as renderSlot, j as createCommentVNode, a2 as equals, b1 as script$x, bT as script$y, bU as getFirstFocusableElement, ae as OverlayEventBus, a6 as getVNodeProp, ad as resolveFieldData, bV as invokeElementMethod, a0 as getAttribute, bW as getNextElementSibling, W as getOuterWidth, bX as getPreviousElementSibling, D as script$z, ap as script$A, Z as script$B, b3 as script$D, aa as isNotEmpty, bd as withModifiers, U as getOuterHeight, ab as UniqueComponentId, bY as _default, ac as ZIndex, a1 as focus, ag as addStyle, ai as absolutePosition, aj as ConnectedOverlayScrollHandler, ak as isTouchDevice, bZ as FilterOperator, ao as script$E, b_ as FocusTrap, i as createVNode, as as Transition, b$ as withKeys, c0 as getIndex, s as script$G, c1 as isClickable, c2 as clearSelection, c3 as localeComparator, c4 as sort, c5 as FilterService, bN as FilterMatchMode, R as findSingle, c6 as findIndexInList, c7 as find, c8 as exportCSV, V as getOffset, c9 as getHiddenElementOuterWidth, ca as getHiddenElementOuterHeight, cb as reorderArray, cc as removeClass, cd as addClass, af as isEmpty, an as script$H, aq as script$I, a9 as vShow } from "./index-BNX_XOqh.js";
|
||||
import { s as script$C, a as script$F } from "./index-nMMCMbCV.js";
|
||||
var script$r = {
|
||||
name: "ArrowDownIcon",
|
||||
"extends": script$s
|
||||
@ -8930,4 +8930,4 @@ export {
|
||||
script as b,
|
||||
script$2 as s
|
||||
};
|
||||
//# sourceMappingURL=index-CN90wNx3.js.map
|
||||
//# sourceMappingURL=index-B_uZlOM8.js.map
|
||||
2
comfy/web/assets/index-CN90wNx3.js.map → comfy/web/assets/index-B_uZlOM8.js.map
generated
vendored
2
comfy/web/assets/index-CN90wNx3.js.map → comfy/web/assets/index-B_uZlOM8.js.map
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-CFrRuGBA.js.map
generated
vendored
1
comfy/web/assets/index-CFrRuGBA.js.map
generated
vendored
File diff suppressed because one or more lines are too long
221
comfy/web/assets/index-DTOGNau5.js → comfy/web/assets/index-DNRGG-ix.js
generated
vendored
221
comfy/web/assets/index-DTOGNau5.js → comfy/web/assets/index-DNRGG-ix.js
generated
vendored
@ -1,7 +1,7 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { bt as ComfyDialog, bu as $el, bv as ComfyApp, c as app, k as LiteGraph, aN as LGraphCanvas, bw as DraggableList, aZ as useToastStore, av as useNodeDefStore, bp as api, L as LGraphGroup, bx as KeyComboImpl, K as useKeybindingStore, F as useCommandStore, e as LGraphNode, by as ComfyWidgets, bz as applyTextReplacements, at as NodeSourceType, bA as NodeBadgeMode, u as useSettingStore, q as computed, bB as getColorPalette, w as watch, bC as BadgePosition, aP as LGraphBadge, bD as _, bE as defaultColorPalette } from "./index-CFrRuGBA.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs-B4bHTYzE.js";
|
||||
import { bF as ComfyDialog, bG as $el, bH as ComfyApp, c as app, k as LiteGraph, a_ as LGraphCanvas, bI as DraggableList, b9 as useToastStore, aC as useNodeDefStore, bB as api, L as LGraphGroup, bJ as KeyComboImpl, M as useKeybindingStore, F as useCommandStore, e as LGraphNode, bK as ComfyWidgets, bL as applyTextReplacements } from "./index-BNX_XOqh.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs-B62b6cfl.js";
|
||||
class ClipspaceDialog extends ComfyDialog {
|
||||
static {
|
||||
__name(this, "ClipspaceDialog");
|
||||
@ -37,7 +37,9 @@ class ClipspaceDialog extends ComfyDialog {
|
||||
...self.createButtons()
|
||||
]);
|
||||
if (self.element) {
|
||||
self.element.removeChild(self.element.firstChild);
|
||||
if (self.element.firstChild) {
|
||||
self.element.removeChild(self.element.firstChild);
|
||||
}
|
||||
self.element.appendChild(children);
|
||||
} else {
|
||||
self.element = $el("div.comfy-modal", { parent: document.body }, [
|
||||
@ -76,7 +78,7 @@ class ClipspaceDialog extends ComfyDialog {
|
||||
return buttons;
|
||||
}
|
||||
createImgSettings() {
|
||||
if (ComfyApp.clipspace.imgs) {
|
||||
if (ComfyApp.clipspace?.imgs) {
|
||||
const combo_items = [];
|
||||
const imgs = ComfyApp.clipspace.imgs;
|
||||
for (let i = 0; i < imgs.length; i++) {
|
||||
@ -87,8 +89,10 @@ class ClipspaceDialog extends ComfyDialog {
|
||||
{
|
||||
id: "clipspace_img_selector",
|
||||
onchange: /* @__PURE__ */ __name((event) => {
|
||||
ComfyApp.clipspace["selectedIndex"] = event.target.selectedIndex;
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
if (event.target && ComfyApp.clipspace) {
|
||||
ComfyApp.clipspace["selectedIndex"] = event.target.selectedIndex;
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
}, "onchange")
|
||||
},
|
||||
combo_items
|
||||
@ -102,7 +106,9 @@ class ClipspaceDialog extends ComfyDialog {
|
||||
{
|
||||
id: "clipspace_img_paste_mode",
|
||||
onchange: /* @__PURE__ */ __name((event) => {
|
||||
ComfyApp.clipspace["img_paste_mode"] = event.target.value;
|
||||
if (event.target && ComfyApp.clipspace) {
|
||||
ComfyApp.clipspace["img_paste_mode"] = event.target.value;
|
||||
}
|
||||
}, "onchange")
|
||||
},
|
||||
[
|
||||
@ -127,7 +133,7 @@ class ClipspaceDialog extends ComfyDialog {
|
||||
}
|
||||
}
|
||||
createImgPreview() {
|
||||
if (ComfyApp.clipspace.imgs) {
|
||||
if (ComfyApp.clipspace?.imgs) {
|
||||
return $el("img", { id: "clipspace_preview", ondragstart: /* @__PURE__ */ __name(() => false, "ondragstart") });
|
||||
} else return [];
|
||||
}
|
||||
@ -154,7 +160,7 @@ app.registerExtension({
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.clipspace = window.comfyAPI.clipspace || {};
|
||||
window.comfyAPI.clipspace.ClipspaceDialog = ClipspaceDialog;
|
||||
const ext$2 = {
|
||||
const ext$1 = {
|
||||
name: "Comfy.ContextMenuFilter",
|
||||
init() {
|
||||
const ctxMenu = LiteGraph.ContextMenu;
|
||||
@ -173,9 +179,9 @@ const ext$2 = {
|
||||
requestAnimationFrame(() => {
|
||||
const currentNode = LGraphCanvas.active_canvas.current_node;
|
||||
const clickedComboValue = currentNode.widgets?.filter(
|
||||
(w) => w.type === "combo" && w.options.values.length === values.length
|
||||
(w) => w.type === "combo" && w.options.values?.length === values.length
|
||||
).find(
|
||||
(w) => w.options.values.every((v, i) => v === values[i])
|
||||
(w) => w.options.values?.every((v, i) => v === values[i])
|
||||
)?.value;
|
||||
let selectedIndex = clickedComboValue ? values.findIndex((v) => v === clickedComboValue) : 0;
|
||||
if (selectedIndex < 0) {
|
||||
@ -244,7 +250,7 @@ const ext$2 = {
|
||||
filter.addEventListener("input", () => {
|
||||
const term = filter.value.toLocaleLowerCase();
|
||||
displayedItems = items.filter((item) => {
|
||||
const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term);
|
||||
const isVisible = !term || item.textContent?.toLocaleLowerCase().includes(term);
|
||||
item.style.display = isVisible ? "block" : "none";
|
||||
return isVisible;
|
||||
});
|
||||
@ -278,7 +284,7 @@ const ext$2 = {
|
||||
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
|
||||
}
|
||||
};
|
||||
app.registerExtension(ext$2);
|
||||
app.registerExtension(ext$1);
|
||||
function stripComments(str) {
|
||||
return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g, "");
|
||||
}
|
||||
@ -338,7 +344,7 @@ app.registerExtension({
|
||||
if (text[start] === "(") openCount++;
|
||||
if (text[start] === ")") closeCount++;
|
||||
}
|
||||
if (start < 0) return false;
|
||||
if (start < 0) return null;
|
||||
openCount = 0;
|
||||
closeCount = 0;
|
||||
while (end < text.length) {
|
||||
@ -347,7 +353,7 @@ app.registerExtension({
|
||||
if (text[end] === ")") closeCount++;
|
||||
end++;
|
||||
}
|
||||
if (end === text.length) return false;
|
||||
if (end === text.length) return null;
|
||||
return { start: start + 1, end };
|
||||
}
|
||||
__name(findNearestEnclosure, "findNearestEnclosure");
|
||||
@ -1224,7 +1230,7 @@ class GroupNodeConfig {
|
||||
checkPrimitiveConnection(link, inputName, inputs) {
|
||||
const sourceNode = this.nodeData.nodes[link[0]];
|
||||
if (sourceNode.type === "PrimitiveNode") {
|
||||
const [sourceNodeId, _2, targetNodeId, __] = link;
|
||||
const [sourceNodeId, _, targetNodeId, __] = link;
|
||||
const primitiveDef = this.primitiveDefs[sourceNodeId];
|
||||
const targetWidget = inputs[inputName];
|
||||
const primitiveConfig = primitiveDef.input.required.value;
|
||||
@ -1619,7 +1625,7 @@ class GroupNodeHandler {
|
||||
return newNodes;
|
||||
};
|
||||
const getExtraMenuOptions = this.node.getExtraMenuOptions;
|
||||
this.node.getExtraMenuOptions = function(_2, options) {
|
||||
this.node.getExtraMenuOptions = function(_, options) {
|
||||
getExtraMenuOptions?.apply(this, arguments);
|
||||
let optionIndex = options.findIndex((o) => o.content === "Outputs");
|
||||
if (optionIndex === -1) optionIndex = options.length;
|
||||
@ -1793,7 +1799,7 @@ class GroupNodeHandler {
|
||||
} else if (innerNode.type === "Reroute") {
|
||||
const rerouteLinks = this.groupData.linksFrom[old.node.index];
|
||||
if (rerouteLinks) {
|
||||
for (const [_2, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
|
||||
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
|
||||
const node = this.innerNodes[targetNodeId];
|
||||
const input = node.inputs[targetSlot];
|
||||
if (input.widget) {
|
||||
@ -2024,10 +2030,10 @@ function manageGroupNodes() {
|
||||
new ManageGroupDialog(app).show();
|
||||
}
|
||||
__name(manageGroupNodes, "manageGroupNodes");
|
||||
const id$3 = "Comfy.GroupNode";
|
||||
const id$2 = "Comfy.GroupNode";
|
||||
let globalDefs;
|
||||
const ext$1 = {
|
||||
name: id$3,
|
||||
const ext = {
|
||||
name: id$2,
|
||||
commands: [
|
||||
{
|
||||
id: "Comfy.GroupNode.ConvertSelectedNodesToGroupNode",
|
||||
@ -2097,7 +2103,7 @@ const ext$1 = {
|
||||
}
|
||||
}
|
||||
};
|
||||
app.registerExtension(ext$1);
|
||||
app.registerExtension(ext);
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.groupNode = window.comfyAPI.groupNode || {};
|
||||
window.comfyAPI.groupNode.GroupNodeConfig = GroupNodeConfig;
|
||||
@ -2323,9 +2329,9 @@ app.registerExtension({
|
||||
};
|
||||
}
|
||||
});
|
||||
const id$2 = "Comfy.InvertMenuScrolling";
|
||||
const id$1 = "Comfy.InvertMenuScrolling";
|
||||
app.registerExtension({
|
||||
name: id$2,
|
||||
name: id$1,
|
||||
init() {
|
||||
const ctxMenu = LiteGraph.ContextMenu;
|
||||
const replace = /* @__PURE__ */ __name(() => {
|
||||
@ -2341,7 +2347,7 @@ app.registerExtension({
|
||||
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
|
||||
}, "replace");
|
||||
app.ui.settings.addSetting({
|
||||
id: id$2,
|
||||
id: id$1,
|
||||
category: ["Comfy", "Graph", "InvertMenuScrolling"],
|
||||
name: "Invert Context Menu Scrolling",
|
||||
type: "boolean",
|
||||
@ -2397,35 +2403,6 @@ app.registerExtension({
|
||||
window.addEventListener("keydown", keybindListener);
|
||||
}
|
||||
});
|
||||
const id$1 = "Comfy.LinkRenderMode";
|
||||
const ext = {
|
||||
name: id$1,
|
||||
async setup(app2) {
|
||||
app2.ui.settings.addSetting({
|
||||
id: id$1,
|
||||
category: ["Comfy", "Graph", "LinkRenderMode"],
|
||||
name: "Link Render Mode",
|
||||
defaultValue: 2,
|
||||
type: "combo",
|
||||
options: [
|
||||
{ value: LiteGraph.STRAIGHT_LINK, text: "Straight" },
|
||||
{ value: LiteGraph.LINEAR_LINK, text: "Linear" },
|
||||
{ value: LiteGraph.SPLINE_LINK, text: "Spline" },
|
||||
{ value: LiteGraph.HIDDEN_LINK, text: "Hidden" }
|
||||
],
|
||||
onChange(value) {
|
||||
app2.canvas.links_render_mode = +value;
|
||||
app2.canvas.setDirty(
|
||||
/* fg */
|
||||
false,
|
||||
/* bg */
|
||||
true
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
app.registerExtension(ext);
|
||||
function dataURLToBlob(dataURL) {
|
||||
const parts = dataURL.split(";base64,");
|
||||
const contentType = parts[0].split(":")[1];
|
||||
@ -3934,7 +3911,7 @@ app.registerExtension({
|
||||
};
|
||||
this.isVirtualNode = true;
|
||||
}
|
||||
getExtraMenuOptions(_2, options) {
|
||||
getExtraMenuOptions(_, options) {
|
||||
options.unshift(
|
||||
{
|
||||
content: (this.properties.showOutputText ? "Hide" : "Show") + " Type",
|
||||
@ -4157,7 +4134,7 @@ app.registerExtension({
|
||||
slot_types_default_in: {},
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
|
||||
var nodeId = nodeData.name;
|
||||
const inputs = nodeData["input"]["required"];
|
||||
const inputs = nodeData["input"]?.["required"];
|
||||
for (const inputKey in inputs) {
|
||||
var input = inputs[inputKey];
|
||||
if (typeof input[0] !== "string") continue;
|
||||
@ -4179,19 +4156,19 @@ app.registerExtension({
|
||||
nodeType.comfyClass
|
||||
);
|
||||
}
|
||||
var outputs = nodeData["output"];
|
||||
for (const key in outputs) {
|
||||
var type = outputs[key];
|
||||
if (!(type in this.slot_types_default_in)) {
|
||||
this.slot_types_default_in[type] = ["Reroute"];
|
||||
var outputs = nodeData["output"] ?? [];
|
||||
for (const el of outputs) {
|
||||
const type2 = el;
|
||||
if (!(type2 in this.slot_types_default_in)) {
|
||||
this.slot_types_default_in[type2] = ["Reroute"];
|
||||
}
|
||||
this.slot_types_default_in[type].push(nodeId);
|
||||
if (!(type in LiteGraph.registered_slot_out_types)) {
|
||||
LiteGraph.registered_slot_out_types[type] = { nodes: [] };
|
||||
this.slot_types_default_in[type2].push(nodeId);
|
||||
if (!(type2 in LiteGraph.registered_slot_out_types)) {
|
||||
LiteGraph.registered_slot_out_types[type2] = { nodes: [] };
|
||||
}
|
||||
LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass);
|
||||
if (!LiteGraph.slot_types_out.includes(type)) {
|
||||
LiteGraph.slot_types_out.push(type);
|
||||
LiteGraph.registered_slot_out_types[type2].nodes.push(nodeType.comfyClass);
|
||||
if (!LiteGraph.slot_types_out.includes(type2)) {
|
||||
LiteGraph.slot_types_out.push(type2);
|
||||
}
|
||||
}
|
||||
var maxNum = this.suggestionsNumber.value;
|
||||
@ -4276,7 +4253,7 @@ app.registerExtension({
|
||||
} else {
|
||||
w = node.size[0];
|
||||
h = node.size[1];
|
||||
let titleMode = node.constructor.title_mode;
|
||||
const titleMode = node.constructor.title_mode;
|
||||
if (titleMode !== LiteGraph.TRANSPARENT_TITLE && titleMode !== LiteGraph.NO_TITLE) {
|
||||
h += LiteGraph.NODE_TITLE_HEIGHT;
|
||||
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
|
||||
@ -4343,7 +4320,7 @@ app.registerExtension({
|
||||
});
|
||||
app.registerExtension({
|
||||
name: "Comfy.UploadImage",
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
|
||||
beforeRegisterNodeDef(nodeType, nodeData) {
|
||||
if (nodeData?.input?.required?.image?.[1]?.image_upload === true) {
|
||||
nodeData.input.required.upload = ["IMAGEUPLOAD"];
|
||||
}
|
||||
@ -4627,108 +4604,4 @@ app.registerExtension({
|
||||
};
|
||||
}
|
||||
});
|
||||
function getNodeSource(node) {
|
||||
const nodeDef = node.constructor.nodeData;
|
||||
if (!nodeDef) {
|
||||
return null;
|
||||
}
|
||||
const nodeDefStore = useNodeDefStore();
|
||||
return nodeDefStore.nodeDefsByName[nodeDef.name]?.nodeSource ?? null;
|
||||
}
|
||||
__name(getNodeSource, "getNodeSource");
|
||||
function isCoreNode(node) {
|
||||
return getNodeSource(node)?.type === NodeSourceType.Core;
|
||||
}
|
||||
__name(isCoreNode, "isCoreNode");
|
||||
function badgeTextVisible(node, badgeMode) {
|
||||
return badgeMode === NodeBadgeMode.None || isCoreNode(node) && badgeMode === NodeBadgeMode.HideBuiltIn;
|
||||
}
|
||||
__name(badgeTextVisible, "badgeTextVisible");
|
||||
function getNodeIdBadgeText(node, nodeIdBadgeMode) {
|
||||
return badgeTextVisible(node, nodeIdBadgeMode) ? "" : `#${node.id}`;
|
||||
}
|
||||
__name(getNodeIdBadgeText, "getNodeIdBadgeText");
|
||||
function getNodeSourceBadgeText(node, nodeSourceBadgeMode) {
|
||||
const nodeSource = getNodeSource(node);
|
||||
return badgeTextVisible(node, nodeSourceBadgeMode) ? "" : nodeSource?.badgeText ?? "";
|
||||
}
|
||||
__name(getNodeSourceBadgeText, "getNodeSourceBadgeText");
|
||||
function getNodeLifeCycleBadgeText(node, nodeLifeCycleBadgeMode) {
|
||||
let text = "";
|
||||
const nodeDef = node.constructor.nodeData;
|
||||
if (!nodeDef) {
|
||||
return "";
|
||||
}
|
||||
if (nodeDef.deprecated) {
|
||||
text = "[DEPR]";
|
||||
}
|
||||
if (nodeDef.experimental) {
|
||||
text = "[BETA]";
|
||||
}
|
||||
return badgeTextVisible(node, nodeLifeCycleBadgeMode) ? "" : text;
|
||||
}
|
||||
__name(getNodeLifeCycleBadgeText, "getNodeLifeCycleBadgeText");
|
||||
class NodeBadgeExtension {
|
||||
static {
|
||||
__name(this, "NodeBadgeExtension");
|
||||
}
|
||||
constructor(nodeIdBadgeMode = null, nodeSourceBadgeMode = null, nodeLifeCycleBadgeMode = null, colorPalette = null) {
|
||||
this.nodeIdBadgeMode = nodeIdBadgeMode;
|
||||
this.nodeSourceBadgeMode = nodeSourceBadgeMode;
|
||||
this.nodeLifeCycleBadgeMode = nodeLifeCycleBadgeMode;
|
||||
this.colorPalette = colorPalette;
|
||||
}
|
||||
name = "Comfy.NodeBadge";
|
||||
init(app2) {
|
||||
const settingStore = useSettingStore();
|
||||
this.nodeSourceBadgeMode = computed(
|
||||
() => settingStore.get("Comfy.NodeBadge.NodeSourceBadgeMode")
|
||||
);
|
||||
this.nodeIdBadgeMode = computed(
|
||||
() => settingStore.get("Comfy.NodeBadge.NodeIdBadgeMode")
|
||||
);
|
||||
this.nodeLifeCycleBadgeMode = computed(
|
||||
() => settingStore.get(
|
||||
"Comfy.NodeBadge.NodeLifeCycleBadgeMode"
|
||||
)
|
||||
);
|
||||
this.colorPalette = computed(
|
||||
() => getColorPalette(settingStore.get("Comfy.ColorPalette"))
|
||||
);
|
||||
watch(this.nodeSourceBadgeMode, () => {
|
||||
app2.graph.setDirtyCanvas(true, true);
|
||||
});
|
||||
watch(this.nodeIdBadgeMode, () => {
|
||||
app2.graph.setDirtyCanvas(true, true);
|
||||
});
|
||||
watch(this.nodeLifeCycleBadgeMode, () => {
|
||||
app2.graph.setDirtyCanvas(true, true);
|
||||
});
|
||||
}
|
||||
nodeCreated(node, app2) {
|
||||
node.badgePosition = BadgePosition.TopRight;
|
||||
node.badge_enabled = true;
|
||||
const badge = computed(
|
||||
() => new LGraphBadge({
|
||||
text: _.truncate(
|
||||
[
|
||||
getNodeIdBadgeText(node, this.nodeIdBadgeMode.value),
|
||||
getNodeLifeCycleBadgeText(
|
||||
node,
|
||||
this.nodeLifeCycleBadgeMode.value
|
||||
),
|
||||
getNodeSourceBadgeText(node, this.nodeSourceBadgeMode.value)
|
||||
].filter((s) => s.length > 0).join(" "),
|
||||
{
|
||||
length: 31
|
||||
}
|
||||
),
|
||||
fgColor: this.colorPalette.value.colors.litegraph_base?.BADGE_FG_COLOR || defaultColorPalette.colors.litegraph_base.BADGE_FG_COLOR,
|
||||
bgColor: this.colorPalette.value.colors.litegraph_base?.BADGE_BG_COLOR || defaultColorPalette.colors.litegraph_base.BADGE_BG_COLOR
|
||||
})
|
||||
);
|
||||
node.badges.push(() => badge.value);
|
||||
}
|
||||
}
|
||||
app.registerExtension(new NodeBadgeExtension());
|
||||
//# sourceMappingURL=index-DTOGNau5.js.map
|
||||
//# sourceMappingURL=index-DNRGG-ix.js.map
|
||||
1
comfy/web/assets/index-DNRGG-ix.js.map
generated
vendored
Normal file
1
comfy/web/assets/index-DNRGG-ix.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-DTOGNau5.js.map
generated
vendored
1
comfy/web/assets/index-DTOGNau5.js.map
generated
vendored
File diff suppressed because one or more lines are too long
221
comfy/web/assets/index-CKl4cvVy.css → comfy/web/assets/index-HT1vecxT.css
generated
vendored
221
comfy/web/assets/index-CKl4cvVy.css → comfy/web/assets/index-HT1vecxT.css
generated
vendored
@ -1,160 +1,35 @@
|
||||
|
||||
:root {
|
||||
--red-600: #dc3545;
|
||||
}
|
||||
|
||||
.comfy-missing-nodes[data-v-0a88b934] {
|
||||
font-family: monospace;
|
||||
color: var(--red-600);
|
||||
padding: 1.5rem;
|
||||
.no-results-placeholder[data-v-a1e982e0] .p-card {
|
||||
background-color: var(--surface-ground);
|
||||
border-radius: var(--border-radius);
|
||||
box-shadow: var(--card-shadow);
|
||||
text-align: center;
|
||||
box-shadow: unset;
|
||||
}
|
||||
.warning-title[data-v-0a88b934] {
|
||||
margin-top: 0;
|
||||
.no-results-placeholder h3[data-v-a1e982e0] {
|
||||
color: var(--text-color);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.no-results-placeholder p[data-v-a1e982e0] {
|
||||
color: var(--text-color-secondary);
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.warning-description[data-v-0a88b934] {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.missing-nodes-list[data-v-0a88b934] {
|
||||
|
||||
.comfy-missing-nodes[data-v-05a7c5eb] {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.missing-nodes-list.maximized[data-v-0a88b934] {
|
||||
max-height: unset;
|
||||
}
|
||||
.missing-node-item[data-v-0a88b934] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0.5rem;
|
||||
}
|
||||
.node-type[data-v-0a88b934] {
|
||||
font-weight: 600;
|
||||
color: var(--text-color);
|
||||
}
|
||||
.node-hint[data-v-0a88b934] {
|
||||
.node-hint[data-v-05a7c5eb] {
|
||||
margin-left: 0.5rem;
|
||||
font-style: italic;
|
||||
color: var(--text-color-secondary);
|
||||
}
|
||||
[data-v-0a88b934] .p-button {
|
||||
[data-v-05a7c5eb] .p-button {
|
||||
margin-left: auto;
|
||||
}
|
||||
.added-nodes-warning[data-v-0a88b934] {
|
||||
margin-top: 1rem;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
:root {
|
||||
--red-600: #dc3545;
|
||||
--green-500: #28a745;
|
||||
}
|
||||
|
||||
.comfy-missing-models[data-v-d0515260] {
|
||||
font-family: monospace;
|
||||
color: var(--red-600);
|
||||
padding: 1.5rem;
|
||||
background-color: var(--surface-ground);
|
||||
border-radius: var(--border-radius);
|
||||
box-shadow: var(--card-shadow);
|
||||
}
|
||||
.warning-title[data-v-d0515260] {
|
||||
margin-top: 0;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.warning-description[data-v-d0515260] {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.warning-options[data-v-d0515260] {
|
||||
color: var(--fg-color);
|
||||
}
|
||||
.missing-models-list[data-v-d0515260] {
|
||||
.comfy-missing-models[data-v-936032d2] {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.missing-models-list.maximized[data-v-d0515260] {
|
||||
max-height: unset;
|
||||
}
|
||||
.missing-model-item[data-v-d0515260] {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
padding: 0.5rem;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
width: 100%;
|
||||
}
|
||||
.missing-model-item[data-v-d0515260]::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
height: 100%;
|
||||
width: var(--progress);
|
||||
background-color: var(--green-500);
|
||||
opacity: 0.2;
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
.model-info[data-v-d0515260] {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
z-index: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin-right: 1rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
.model-details[data-v-d0515260] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.model-type[data-v-d0515260] {
|
||||
font-weight: 600;
|
||||
color: var(--text-color);
|
||||
margin-right: 0.5rem;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.model-hint[data-v-d0515260] {
|
||||
font-style: italic;
|
||||
color: var(--text-color-secondary);
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.model-error[data-v-d0515260] {
|
||||
color: var(--red-600);
|
||||
font-size: 0.8rem;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
.model-action[data-v-d0515260] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: flex-end;
|
||||
z-index: 1;
|
||||
}
|
||||
.model-action-button[data-v-d0515260] {
|
||||
min-width: 80px;
|
||||
}
|
||||
.download-progress[data-v-d0515260],
|
||||
.download-complete[data-v-d0515260],
|
||||
.download-error[data-v-d0515260] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-width: 80px;
|
||||
}
|
||||
.progress-text[data-v-d0515260] {
|
||||
font-size: 0.8rem;
|
||||
color: var(--text-color);
|
||||
}
|
||||
.download-complete i[data-v-d0515260],
|
||||
.download-error i[data-v-d0515260] {
|
||||
font-size: 1.2rem;
|
||||
}
|
||||
|
||||
.setting-input[data-v-04f094d4] .input-slider .p-inputnumber input,
|
||||
.setting-input[data-v-04f094d4] .input-slider .slider-part {
|
||||
@ -217,24 +92,6 @@
|
||||
border: none !important;
|
||||
}
|
||||
|
||||
.no-results-placeholder[data-v-c19e9e10] {
|
||||
height: 100%;
|
||||
padding: 2rem;
|
||||
}
|
||||
.no-results-placeholder[data-v-c19e9e10] .p-card {
|
||||
background-color: var(--surface-ground);
|
||||
text-align: center;
|
||||
box-shadow: unset;
|
||||
}
|
||||
.no-results-placeholder h3[data-v-c19e9e10] {
|
||||
color: var(--text-color);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.no-results-placeholder p[data-v-c19e9e10] {
|
||||
color: var(--text-color-secondary);
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.settings-tab-panels {
|
||||
padding-top: 0px !important;
|
||||
}
|
||||
@ -706,7 +563,7 @@
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.model-lib-model-icon-container[data-v-70b69131] {
|
||||
.model-lib-model-icon-container[data-v-be871f15] {
|
||||
display: inline-block;
|
||||
position: relative;
|
||||
left: 0;
|
||||
@ -714,7 +571,7 @@
|
||||
vertical-align: top;
|
||||
width: 0px;
|
||||
}
|
||||
.model-lib-model-icon[data-v-70b69131] {
|
||||
.model-lib-model-icon[data-v-be871f15] {
|
||||
background-size: cover;
|
||||
background-position: center;
|
||||
display: inline-block;
|
||||
@ -725,7 +582,7 @@
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
[data-v-db5d3bcc] .pi-fake-spacer {
|
||||
[data-v-2f4635f2] .pi-fake-spacer {
|
||||
height: 1px;
|
||||
width: 16px;
|
||||
}
|
||||
@ -1808,6 +1665,9 @@ cursor: pointer;
|
||||
max-width: 3200px;
|
||||
}
|
||||
}
|
||||
.pointer-events-none{
|
||||
pointer-events: none;
|
||||
}
|
||||
.pointer-events-auto{
|
||||
pointer-events: auto;
|
||||
}
|
||||
@ -2049,9 +1909,18 @@ cursor: pointer;
|
||||
margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse)));
|
||||
margin-bottom: calc(0.5rem * var(--tw-space-y-reverse));
|
||||
}
|
||||
.justify-self-end{
|
||||
justify-self: end;
|
||||
}
|
||||
.overflow-hidden{
|
||||
overflow: hidden;
|
||||
}
|
||||
.overflow-y-auto{
|
||||
overflow-y: auto;
|
||||
}
|
||||
.overflow-x-hidden{
|
||||
overflow-x: hidden;
|
||||
}
|
||||
.truncate{
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
@ -2066,6 +1935,9 @@ cursor: pointer;
|
||||
.whitespace-pre-line{
|
||||
white-space: pre-line;
|
||||
}
|
||||
.whitespace-pre-wrap{
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
.text-wrap{
|
||||
text-wrap: wrap;
|
||||
}
|
||||
@ -2129,9 +2001,15 @@ cursor: pointer;
|
||||
.p-2{
|
||||
padding: 0.5rem;
|
||||
}
|
||||
.p-3{
|
||||
padding: 0.75rem;
|
||||
}
|
||||
.p-4{
|
||||
padding: 1rem;
|
||||
}
|
||||
.p-8{
|
||||
padding: 2rem;
|
||||
}
|
||||
.px-0{
|
||||
padding-left: 0px;
|
||||
padding-right: 0px;
|
||||
@ -2140,6 +2018,10 @@ cursor: pointer;
|
||||
padding-left: 0.5rem;
|
||||
padding-right: 0.5rem;
|
||||
}
|
||||
.px-4{
|
||||
padding-left: 1rem;
|
||||
padding-right: 1rem;
|
||||
}
|
||||
.py-0{
|
||||
padding-top: 0px;
|
||||
padding-bottom: 0px;
|
||||
@ -2436,25 +2318,6 @@ body {
|
||||
margin: 3px 3px 3px 4px;
|
||||
}
|
||||
|
||||
.comfy-menu-hamburger {
|
||||
position: fixed;
|
||||
top: 10px;
|
||||
z-index: 9999;
|
||||
right: 10px;
|
||||
width: 30px;
|
||||
display: none;
|
||||
gap: 8px;
|
||||
flex-direction: column;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.comfy-menu-hamburger div {
|
||||
height: 3px;
|
||||
width: 100%;
|
||||
border-radius: 20px;
|
||||
background-color: white;
|
||||
}
|
||||
|
||||
.comfy-menu {
|
||||
font-size: 15px;
|
||||
position: absolute;
|
||||
4
comfy/web/assets/index-CaUteDIK.js → comfy/web/assets/index-nMMCMbCV.js
generated
vendored
4
comfy/web/assets/index-CaUteDIK.js → comfy/web/assets/index-nMMCMbCV.js
generated
vendored
@ -1,6 +1,6 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { aQ as script$2, g as openBlock, h as createElementBlock, m as mergeProps, A as createBaseVNode } from "./index-CFrRuGBA.js";
|
||||
import { b0 as script$2, g as openBlock, h as createElementBlock, m as mergeProps, A as createBaseVNode } from "./index-BNX_XOqh.js";
|
||||
var script$1 = {
|
||||
name: "BarsIcon",
|
||||
"extends": script$2
|
||||
@ -43,4 +43,4 @@ export {
|
||||
script as a,
|
||||
script$1 as s
|
||||
};
|
||||
//# sourceMappingURL=index-CaUteDIK.js.map
|
||||
//# sourceMappingURL=index-nMMCMbCV.js.map
|
||||
2
comfy/web/assets/index-CaUteDIK.js.map → comfy/web/assets/index-nMMCMbCV.js.map
generated
vendored
2
comfy/web/assets/index-CaUteDIK.js.map → comfy/web/assets/index-nMMCMbCV.js.map
generated
vendored
@ -1 +1 @@
|
||||
{"version":3,"file":"index-CaUteDIK.js","sources":["../../../../../node_modules/@primevue/icons/bars/index.mjs","../../../../../node_modules/@primevue/icons/plus/index.mjs"],"sourcesContent":["import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'BarsIcon',\n \"extends\": BaseIcon\n};\n\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _cache[0] || (_cache[0] = [createElementVNode(\"path\", {\n \"fill-rule\": \"evenodd\",\n \"clip-rule\": \"evenodd\",\n d: \"M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z\",\n fill: \"currentColor\"\n }, null, -1)]), 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n","import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'PlusIcon',\n \"extends\": BaseIcon\n};\n\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _cache[0] || (_cache[0] = [createElementVNode(\"path\", {\n d: \"M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z\",\n fill: \"currentColor\"\n }, null, -1)]), 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n"],"names":["script","BaseIcon","render","createElementVNode"],"mappings":";;;AAGG,IAACA,WAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWC;AACb;AAEA,SAASC,SAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,KAAK,GAAG,OAAO,CAAC,MAAM,OAAO,CAAC,IAAI,CAACC,gBAAmB,QAAQ;AAAA,IACpE,aAAa;AAAA,IACb,aAAa;AAAA,IACb,GAAG;AAAA,IACH,MAAM;AAAA,EACP,GAAE,MAAM,EAAE,CAAC,IAAI,EAAE;AACpB;AAbSD;AAeTF,SAAO,SAASE;ACpBb,IAAC,SAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWD;AACb;AAEA,SAAS,OAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,KAAK,GAAG,OAAO,CAAC,MAAM,OAAO,CAAC,IAAI,CAACE,gBAAmB,QAAQ;AAAA,IACpE,GAAG;AAAA,IACH,MAAM;AAAA,EACP,GAAE,MAAM,EAAE,CAAC,IAAI,EAAE;AACpB;AAXS;AAaT,OAAO,SAAS;","x_google_ignoreList":[0,1]}
|
||||
{"version":3,"file":"index-nMMCMbCV.js","sources":["../../../../../node_modules/@primevue/icons/bars/index.mjs","../../../../../node_modules/@primevue/icons/plus/index.mjs"],"sourcesContent":["import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'BarsIcon',\n \"extends\": BaseIcon\n};\n\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _cache[0] || (_cache[0] = [createElementVNode(\"path\", {\n \"fill-rule\": \"evenodd\",\n \"clip-rule\": \"evenodd\",\n d: \"M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z\",\n fill: \"currentColor\"\n }, null, -1)]), 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n","import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'PlusIcon',\n \"extends\": BaseIcon\n};\n\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _cache[0] || (_cache[0] = [createElementVNode(\"path\", {\n d: \"M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z\",\n fill: \"currentColor\"\n }, null, -1)]), 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n"],"names":["script","BaseIcon","render","createElementVNode"],"mappings":";;;AAGG,IAACA,WAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWC;AACb;AAEA,SAASC,SAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,KAAK,GAAG,OAAO,CAAC,MAAM,OAAO,CAAC,IAAI,CAACC,gBAAmB,QAAQ;AAAA,IACpE,aAAa;AAAA,IACb,aAAa;AAAA,IACb,GAAG;AAAA,IACH,MAAM;AAAA,EACP,GAAE,MAAM,EAAE,CAAC,IAAI,EAAE;AACpB;AAbSD;AAeTF,SAAO,SAASE;ACpBb,IAAC,SAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWD;AACb;AAEA,SAAS,OAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,KAAK,GAAG,OAAO,CAAC,MAAM,OAAO,CAAC,IAAI,CAACE,gBAAmB,QAAQ;AAAA,IACpE,GAAG;AAAA,IACH,MAAM;AAAA,EACP,GAAE,MAAM,EAAE,CAAC,IAAI,EAAE;AACpB;AAXS;AAaT,OAAO,SAAS;","x_google_ignoreList":[0,1]}
|
||||
@ -1,6 +1,6 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { bp as api, bu as $el } from "./index-CFrRuGBA.js";
|
||||
import { bB as api, bG as $el } from "./index-BNX_XOqh.js";
|
||||
function createSpinner() {
|
||||
const div = document.createElement("div");
|
||||
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
|
||||
@ -126,4 +126,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
||||
export {
|
||||
UserSelectionScreen
|
||||
};
|
||||
//# sourceMappingURL=userSelection-vhU1ykfH.js.map
|
||||
//# sourceMappingURL=userSelection-C7IbQlVC.js.map
|
||||
File diff suppressed because one or more lines are too long
@ -1,6 +1,6 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { e as LGraphNode, c as app, bz as applyTextReplacements, by as ComfyWidgets, bF as addValueControlWidgets, k as LiteGraph } from "./index-CFrRuGBA.js";
|
||||
import { e as LGraphNode, c as app, bL as applyTextReplacements, bK as ComfyWidgets, bM as addValueControlWidgets, k as LiteGraph } from "./index-BNX_XOqh.js";
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
const VALID_TYPES = [
|
||||
"STRING",
|
||||
@ -753,4 +753,4 @@ export {
|
||||
mergeIfValid,
|
||||
setWidgetConfig
|
||||
};
|
||||
//# sourceMappingURL=widgetInputs-B4bHTYzE.js.map
|
||||
//# sourceMappingURL=widgetInputs-B62b6cfl.js.map
|
||||
File diff suppressed because one or more lines are too long
4
comfy/web/index.html
vendored
4
comfy/web/index.html
vendored
@ -6,8 +6,8 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||
<script type="module" crossorigin src="./assets/index-CFrRuGBA.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-CKl4cvVy.css">
|
||||
<script type="module" crossorigin src="./assets/index-BNX_XOqh.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-HT1vecxT.css">
|
||||
</head>
|
||||
<body class="litegraph grid">
|
||||
<div id="vue-app"></div>
|
||||
|
||||
@ -2,6 +2,7 @@ import torch
|
||||
|
||||
import comfy.utils
|
||||
from comfy.component_model.tensor_types import Latent
|
||||
from .nodes_post_processing import gaussian_kernel
|
||||
|
||||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
@ -191,6 +192,138 @@ class LatentAddNoiseChannels:
|
||||
return (s,)
|
||||
|
||||
|
||||
class LatentApplyOperation:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT",),
|
||||
"operation": ("LATENT_OPERATION",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, samples, operation):
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
samples_out["samples"] = operation(latent=s1)
|
||||
return (samples_out,)
|
||||
|
||||
|
||||
class LatentApplyOperationCFG:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
"operation": ("LATENT_OPERATION",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model, operation):
|
||||
m = model.clone()
|
||||
|
||||
def pre_cfg_function(args):
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) == 2:
|
||||
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
|
||||
else:
|
||||
conds_out[0] = operation(latent=conds_out[0])
|
||||
return conds_out
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return (m,)
|
||||
|
||||
|
||||
class LatentOperationTonemapReinhard:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, multiplier):
|
||||
def tonemap_reinhard(latent, **kwargs):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:, None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
|
||||
mean = torch.mean(latent_vector_magnitude, dim=(1, 2, 3), keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=(1, 2, 3), keepdim=True)
|
||||
|
||||
top = (std * 5 + mean) * multiplier
|
||||
|
||||
# reinhard
|
||||
latent_vector_magnitude *= (1.0 / top)
|
||||
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
|
||||
new_magnitude *= top
|
||||
|
||||
return normalized_latent * new_magnitude
|
||||
|
||||
return (tonemap_reinhard,)
|
||||
|
||||
|
||||
class LatentOperationSharpen:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"sharpen_radius": ("INT", {
|
||||
"default": 9,
|
||||
"min": 1,
|
||||
"max": 31,
|
||||
"step": 1
|
||||
}),
|
||||
"sigma": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.1,
|
||||
"max": 10.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
"alpha": ("FLOAT", {
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 5.0,
|
||||
"step": 0.01
|
||||
}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, sharpen_radius, sigma, alpha):
|
||||
def sharpen(latent, **kwargs):
|
||||
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:, None]
|
||||
normalized_latent = latent / luminance
|
||||
channels = latent.shape[1]
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=luminance.device)
|
||||
center = kernel_size // 2
|
||||
|
||||
kernel *= alpha * -10
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
|
||||
padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), 'reflect')
|
||||
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||
|
||||
return luminance * sharpened
|
||||
|
||||
return (sharpen,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
@ -199,4 +332,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LatentBatch": LatentBatch,
|
||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||
"LatentAddNoiseChannels": LatentAddNoiseChannels,
|
||||
"LatentApplyOperation": LatentApplyOperation,
|
||||
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
||||
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
||||
"LatentOperationSharpen": LatentOperationSharpen,
|
||||
}
|
||||
|
||||
@ -88,10 +88,9 @@ class LoraSave:
|
||||
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {"model_diff": ("MODEL",),
|
||||
"text_encoder_diff": ("CLIP",)},
|
||||
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
||||
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
@ -121,3 +120,7 @@ class LoraSave:
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LoraSave": LoraSave
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoraSave": "Extract and Save Lora"
|
||||
}
|
||||
|
||||
@ -103,10 +103,34 @@ class ModelMergeFlux1(ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeSD35_Large(ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embed."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["context_embedder."] = argument
|
||||
arg_dict["y_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
|
||||
for i in range(38):
|
||||
arg_dict["joint_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
"ModelMergeSDXL": ModelMergeSDXL,
|
||||
"ModelMergeSD3_2B": ModelMergeSD3_2B,
|
||||
"ModelMergeFlux1": ModelMergeFlux1,
|
||||
"ModelMergeSD35_Large": ModelMergeSD35_Large,
|
||||
}
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.sd
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.model_downloader import get_or_download, get_filename_list_with_downloadable, KNOWN_CLIP_MODELS
|
||||
from comfy.nodes import base_nodes as nodes
|
||||
@ -36,6 +40,7 @@ class EmptySD3LatentImage:
|
||||
return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
@ -91,23 +96,86 @@ class CLIPTextEncodeSD3:
|
||||
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"control_net": ("CONTROL_NET", ),
|
||||
"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"control_net": ("CONTROL_NET",),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
DEPRECATED = True
|
||||
|
||||
|
||||
class SkipLayerGuidanceSD3:
|
||||
'''
|
||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||
Experimental implementation by Dango233@StabilityAI.
|
||||
'''
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "skip_guidance"
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
|
||||
if layers == "" or layers == None:
|
||||
return (model,)
|
||||
|
||||
# check if layer is comma separated integers
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def post_cfg_function(args):
|
||||
model = args["model"]
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
x = args["input"]
|
||||
model_options = args["model_options"].copy()
|
||||
|
||||
for layer in layers:
|
||||
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||
model_sampling.percent_to_sigma(start_percent)
|
||||
|
||||
sigma_ = sigma[0].item()
|
||||
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||
return cfg_result
|
||||
|
||||
layers = re.findall(r'\d+', layers)
|
||||
layers = [int(i) for i in layers]
|
||||
m = model.clone()
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return (m,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripleCLIPLoader": TripleCLIPLoader,
|
||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
||||
"ControlNetApplySD3": ControlNetApplySD3,
|
||||
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
|
||||
26
comfy_extras/nodes_mochi.py
Normal file
26
comfy_extras/nodes_mochi.py
Normal file
@ -0,0 +1,26 @@
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
class EmptyMochiLatentVideo:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/mochi"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
|
||||
return ({"samples":latent}, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user