mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-05 11:40:53 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
56d83f6359
@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
|
||||
operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if self.use_additional_t_cond:
|
||||
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
|
||||
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
default_ref_method="index",
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
use_additional_t_cond=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
use_additional_t_cond=use_additional_t_cond,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
t_len = t
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
if t_len > 1:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||
else:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
additional_t_cond=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
index = 0
|
||||
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||
negative_ref_method = ref_method == "negative_index"
|
||||
timestep_zero = ref_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif negative_ref_method:
|
||||
index -= 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
||||
@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
input_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
output_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
|
||||
@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
|
||||
dit_config["use_additional_t_cond"] = True
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
|
||||
36
comfy/sd.py
36
comfy/sd.py
@ -321,6 +321,7 @@ class VAE:
|
||||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
@ -435,6 +436,7 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
@ -546,7 +548,9 @@ class VAE:
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||
self.pad_channel_value = 1.0
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
@ -582,6 +586,7 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
@ -690,17 +695,28 @@ class VAE:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if not self.crop_input:
|
||||
return pixels
|
||||
if self.crop_input:
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
if pixels.shape[-1] > self.output_channels:
|
||||
pixels = pixels[..., :self.output_channels]
|
||||
elif pixels.shape[-1] < self.output_channels:
|
||||
if self.pad_channel_value is not None:
|
||||
if isinstance(self.pad_channel_value, str):
|
||||
mode = self.pad_channel_value
|
||||
value = None
|
||||
else:
|
||||
mode = "constant"
|
||||
value = self.pad_channel_value
|
||||
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
|
||||
@ -5,6 +5,7 @@ import nodes
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import logging
|
||||
import math
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode):
|
||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentCutToBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentCutToBatch",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("dim", options=["t", "x", "y"]),
|
||||
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
|
||||
if "x" in dim:
|
||||
dim = s1.ndim - 1
|
||||
elif "y" in dim:
|
||||
dim = s1.ndim - 2
|
||||
elif "t" in dim:
|
||||
dim = s1.ndim - 3
|
||||
|
||||
if dim < 2:
|
||||
return io.NodeOutput(samples)
|
||||
|
||||
s = s1.movedim(dim, 1)
|
||||
if s.shape[1] < slice_size:
|
||||
slice_size = s.shape[1]
|
||||
elif s.shape[1] % slice_size != 0:
|
||||
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
|
||||
new_shape = [-1, slice_size] + list(s.shape[2:])
|
||||
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension):
|
||||
LatentInterpolate,
|
||||
LatentConcat,
|
||||
LatentCut,
|
||||
LatentCutToBatch,
|
||||
LatentBatch,
|
||||
LatentBatchSeedBehavior,
|
||||
LatentApplyOperation,
|
||||
|
||||
4
nodes.py
4
nodes.py
@ -343,7 +343,7 @@ class VAEEncode:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
t = vae.encode(pixels)
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeTiled:
|
||||
@ -361,7 +361,7 @@ class VAEEncodeTiled:
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
return ({"samples": t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.34.9
|
||||
comfyui-workflow-templates==0.7.59
|
||||
comfyui-workflow-templates==0.7.60
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Loading…
Reference in New Issue
Block a user