diff --git a/README.md b/README.md
index b15f58430..6b5b6bf30 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
 - Asynchronous Queue system
diff --git a/comfy/api_server/services/terminal_service.py b/comfy/api_server/services/terminal_service.py
index 8ae6b5549..482d016b7 100644
--- a/comfy/api_server/services/terminal_service.py
+++ b/comfy/api_server/services/terminal_service.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 
 from ...app.logger import on_flush
 
@@ -11,15 +12,33 @@ class TerminalService:
         self.subscriptions = set()
         on_flush(self.send_messages)
 
+    def get_terminal_size(self):
+        try:
+            size = os.get_terminal_size()
+            return (size.columns, size.lines)
+        except OSError:
+            try:
+                size = shutil.get_terminal_size()
+                return (size.columns, size.lines)
+            except OSError:
+                return (80, 24) # fallback to 80x24
+
     def update_size(self):
-        sz = os.get_terminal_size()
+        columns, lines = self.get_terminal_size()
         changed = False
-        if sz.columns != self.cols:
-            self.cols = sz.columns
+
+        if columns != self.cols:
+            self.cols = columns
             changed = True
 
-        if sz.lines != self.rows:
-            self.rows = sz.lines
+        if lines != self.rows:
+            self.rows = lines
             changed = True
 
         if changed:
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 047c5da7d..ae9a0e16e 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
 
 ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
                "gelu": torch.nn.functional.gelu,
+               "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
 }
 
 class CLIPMLP(torch.nn.Module):
@@ -140,27 +141,35 @@ class CLIPTextModel(torch.nn.Module):
 
 class CLIPVisionEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
         super().__init__()
-        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+        num_patches = (image_size // patch_size) ** 2
+        if model_type == "siglip_vision_model":
+            self.class_embedding = None
+            patch_bias = True
+        else:
+            num_patches = num_patches + 1
+            self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+            patch_bias = False
 
         self.patch_embedding = operations.Conv2d(
             in_channels=num_channels,
             out_channels=embed_dim,
             kernel_size=patch_size,
             stride=patch_size,
-            bias=False,
+            bias=patch_bias,
             dtype=dtype,
device=device ) - num_patches = (image_size // patch_size) ** 2 - num_positions = num_patches + 1 - self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device) def forward(self, pixel_values): embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) - return torch.cat([ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + ops.cast_to_input(self.position_embedding.weight, embeds) + if self.class_embedding is not None: + embeds = torch.cat([ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + return embeds + ops.cast_to_input(self.position_embedding.weight, embeds) class CLIPVision(torch.nn.Module): @@ -171,9 +180,15 @@ class CLIPVision(torch.nn.Module): heads = config_dict["num_attention_heads"] intermediate_size = config_dict["intermediate_size"] intermediate_activation = config_dict["hidden_act"] + model_type = config_dict["model_type"] - self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations) - self.pre_layrnorm = operations.LayerNorm(embed_dim) + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) + if model_type == "siglip_vision_model": + self.pre_layrnorm = lambda a: a + self.output_layernorm = True + else: + self.pre_layrnorm = operations.LayerNorm(embed_dim) + self.output_layernorm = False self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.post_layernorm = operations.LayerNorm(embed_dim) @@ -182,14 +197,21 @@ class CLIPVision(torch.nn.Module): x = self.pre_layrnorm(x) #TODO: attention_mask? 
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output) - pooled_output = self.post_layernorm(x[:, 0, :]) + if self.output_layernorm: + x = self.post_layernorm(x) + pooled_output = x + else: + pooled_output = self.post_layernorm(x[:, 0, :]) return x, i, pooled_output class CLIPVisionModelProjection(torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() self.vision_model = CLIPVision(config_dict, dtype, device, operations) - self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) + if "projection_dim" in config_dict: + self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) + else: + self.visual_projection = lambda a: a def forward(self, *args, **kwargs): x = self.vision_model(*args, **kwargs) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index b92b7b69e..a2e5262bb 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -17,9 +17,9 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224): - mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) - std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]): + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) image = image.movedim(-1, 1) if not (image.shape[2] == size and image.shape[3] == size): scale = (size / min(image.shape[2], image.shape[3])) @@ -44,6 +44,8 @@ class ClipVisionModel(): raise ValueError(f"json_config had invalid value={json_config}") self.image_size = config.get("image_size", 224) + self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) + self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) self.load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() self.dtype = model_management.text_encoder_dtype(self.load_device) @@ -59,7 +61,7 @@ class ClipVisionModel(): def encode_image(self, image): load_models_gpu([self.patcher]) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float() + pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float() out = self.model(pixel_values=pixel_values, intermediate_output=-2) outputs = Output() @@ -102,7 +104,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = files.get_path_as_dict(None, "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: - if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json") else: json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json") diff --git a/comfy/clip_vision_siglip_384.json b/comfy/clip_vision_siglip_384.json new file mode 
100644 index 000000000..532e03ac1 --- /dev/null +++ b/comfy/clip_vision_siglip_384.json @@ -0,0 +1,13 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f9fd16d8c..73f40de95 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,5 +1,6 @@ import torch + class LatentFormat: scale_factor = 1.0 latent_channels = 4 @@ -13,33 +14,36 @@ class LatentFormat: def process_out(self, latent): return latent / self.scale_factor + class SD15(LatentFormat): def __init__(self, scale_factor=0.18215): self.scale_factor = scale_factor self.latent_rgb_factors = [ - # R G B - [ 0.3512, 0.2297, 0.3227], - [ 0.3250, 0.4974, 0.2350], - [-0.2829, 0.1762, 0.2721], - [-0.2120, -0.2616, -0.7177] - ] + # R G B + [0.3512, 0.2297, 0.3227], + [0.3250, 0.4974, 0.2350], + [-0.2829, 0.1762, 0.2721], + [-0.2120, -0.2616, -0.7177] + ] self.taesd_decoder_name = "taesd_decoder" + class SDXL(LatentFormat): scale_factor = 0.13025 def __init__(self): self.latent_rgb_factors = [ - # R G B - [ 0.3651, 0.4232, 0.4341], - [-0.2533, -0.0042, 0.1068], - [ 0.1076, 0.1111, -0.0362], - [-0.3165, -0.2492, -0.2188] - ] - self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011] + # R G B + [0.3651, 0.4232, 0.4341], + [-0.2533, -0.0042, 0.1068], + [0.1076, 0.1111, -0.0362], + [-0.3165, -0.2492, -0.2188] + ] + self.latent_rgb_factors_bias = [0.1084, -0.0175, -0.0011] self.taesd_decoder_name = "taesdxl_decoder" + class SDXL_Playground_2_5(LatentFormat): def __init__(self): self.scale_factor = 0.5 @@ -47,12 +51,12 @@ class SDXL_Playground_2_5(LatentFormat): self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1) self.latent_rgb_factors = [ - # R G B - [ 0.3920, 0.4054, 0.4549], - [-0.2634, -0.0196, 0.0653], - [ 0.0568, 0.1687, -0.0755], - [-0.3112, -0.2359, -0.2076] - ] + # R G B + [0.3920, 0.4054, 0.4549], + [-0.2634, -0.0196, 0.0653], + [0.0568, 0.1687, -0.0755], + [-0.3112, -0.2359, -0.2076] + ] self.taesd_decoder_name = "taesdxl_decoder" def process_in(self, latent): @@ -71,63 +75,68 @@ class SD_X4(LatentFormat): self.scale_factor = 0.08333 self.latent_rgb_factors = [ [-0.2340, -0.3863, -0.3257], - [ 0.0994, 0.0885, -0.0908], + [0.0994, 0.0885, -0.0908], [-0.2833, -0.2349, -0.3741], - [ 0.2523, -0.0055, -0.1651] + [0.2523, -0.0055, -0.1651] ] + class SC_Prior(LatentFormat): latent_channels = 16 + def __init__(self): self.scale_factor = 1.0 self.latent_rgb_factors = [ [-0.0326, -0.0204, -0.0127], - [-0.1592, -0.0427, 0.0216], - [ 0.0873, 0.0638, -0.0020], - [-0.0602, 0.0442, 0.1304], - [ 0.0800, -0.0313, -0.1796], + [-0.1592, -0.0427, 0.0216], + [0.0873, 0.0638, -0.0020], + [-0.0602, 0.0442, 0.1304], + [0.0800, -0.0313, -0.1796], [-0.0810, -0.0638, -0.1581], - [ 0.1791, 0.1180, 0.0967], - [ 0.0740, 0.1416, 0.0432], + [0.1791, 0.1180, 0.0967], + [0.0740, 0.1416, 0.0432], [-0.1745, -0.1888, -0.1373], - [ 0.2412, 0.1577, 0.0928], - [ 0.1908, 0.0998, 0.0682], - [ 0.0209, 0.0365, -0.0092], - [ 0.0448, -0.0650, -0.1728], + [0.2412, 0.1577, 0.0928], + [0.1908, 0.0998, 0.0682], + [0.0209, 0.0365, -0.0092], + [0.0448, -0.0650, -0.1728], [-0.1658, -0.1045, -0.1308], - [ 0.0542, 0.1545, 0.1325], + [0.0542, 0.1545, 0.1325], [-0.0352, -0.1672, -0.2541] ] + class 
SC_B(LatentFormat): def __init__(self): self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ - [ 0.1121, 0.2006, 0.1023], + [0.1121, 0.2006, 0.1023], [-0.2093, -0.0222, -0.0195], - [-0.3087, -0.1535, 0.0366], - [ 0.0290, -0.1574, -0.4078] + [-0.3087, -0.1535, 0.0366], + [0.0290, -0.1574, -0.4078] ] + class SD3(LatentFormat): latent_channels = 16 + def __init__(self): self.scale_factor = 1.5305 self.shift_factor = 0.0609 self.latent_rgb_factors = [ - [-0.0922, -0.0175, 0.0749], - [ 0.0311, 0.0633, 0.0954], - [ 0.1994, 0.0927, 0.0458], - [ 0.0856, 0.0339, 0.0902], - [ 0.0587, 0.0272, -0.0496], - [-0.0006, 0.1104, 0.0309], - [ 0.0978, 0.0306, 0.0427], - [-0.0042, 0.1038, 0.1358], - [-0.0194, 0.0020, 0.0669], - [-0.0488, 0.0130, -0.0268], - [ 0.0922, 0.0988, 0.0951], - [-0.0278, 0.0524, -0.0542], - [ 0.0332, 0.0456, 0.0895], + [-0.0922, -0.0175, 0.0749], + [0.0311, 0.0633, 0.0954], + [0.1994, 0.0927, 0.0458], + [0.0856, 0.0339, 0.0902], + [0.0587, 0.0272, -0.0496], + [-0.0006, 0.1104, 0.0309], + [0.0978, 0.0306, 0.0427], + [-0.0042, 0.1038, 0.1358], + [-0.0194, 0.0020, 0.0669], + [-0.0488, 0.0130, -0.0268], + [0.0922, 0.0988, 0.0951], + [-0.0278, 0.0524, -0.0542], + [0.0332, 0.0456, 0.0895], [-0.0069, -0.0030, -0.0810], [-0.0596, -0.0465, -0.0293], [-0.1448, -0.1463, -0.1189] @@ -141,28 +150,31 @@ class SD3(LatentFormat): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor + class StableAudio1(LatentFormat): latent_channels = 64 + class Flux(SD3): latent_channels = 16 + def __init__(self): self.scale_factor = 0.3611 self.shift_factor = 0.1159 - self.latent_rgb_factors =[ - [-0.0346, 0.0244, 0.0681], - [ 0.0034, 0.0210, 0.0687], - [ 0.0275, -0.0668, -0.0433], - [-0.0174, 0.0160, 0.0617], - [ 0.0859, 0.0721, 0.0329], - [ 0.0004, 0.0383, 0.0115], - [ 0.0405, 0.0861, 0.0915], + self.latent_rgb_factors = [ + [-0.0346, 0.0244, 0.0681], + [0.0034, 0.0210, 0.0687], + [0.0275, -0.0668, -0.0433], + [-0.0174, 0.0160, 0.0617], + [0.0859, 0.0721, 0.0329], + [0.0004, 0.0383, 0.0115], + [0.0405, 0.0861, 0.0915], [-0.0236, -0.0185, -0.0259], - [-0.0245, 0.0250, 0.1180], - [ 0.1008, 0.0755, -0.0421], - [-0.0515, 0.0201, 0.0011], - [ 0.0428, -0.0012, -0.0036], - [ 0.0817, 0.0765, 0.0749], + [-0.0245, 0.0250, 0.1180], + [0.1008, 0.0755, -0.0421], + [-0.0515, 0.0201, 0.0011], + [0.0428, -0.0012, -0.0036], + [0.0817, 0.0765, 0.0749], [-0.1264, -0.0522, -0.1103], [-0.0280, -0.0881, -0.0499], [-0.1262, -0.0982, -0.0778] @@ -176,6 +188,7 @@ class Flux(SD3): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor + class Mochi(LatentFormat): latent_channels = 12 @@ -190,22 +203,22 @@ class Mochi(LatentFormat): 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1) - self.latent_rgb_factors =[ - [-0.0069, -0.0045, 0.0018], - [ 0.0154, -0.0692, -0.0274], - [ 0.0333, 0.0019, 0.0206], - [-0.1390, 0.0628, 0.1678], - [-0.0725, 0.0134, -0.1898], - [ 0.0074, -0.0270, -0.0209], + self.latent_rgb_factors = [ + [-0.0069, -0.0045, 0.0018], + [0.0154, -0.0692, -0.0274], + [0.0333, 0.0019, 0.0206], + [-0.1390, 0.0628, 0.1678], + [-0.0725, 0.0134, -0.1898], + [0.0074, -0.0270, -0.0209], [-0.0176, -0.0277, -0.0221], - [ 0.5294, 0.5204, 0.3852], + [0.5294, 0.5204, 0.3852], [-0.0326, -0.0446, -0.0143], - [-0.0659, 0.0153, -0.0153], - [ 0.0185, -0.0217, 0.0014], + [-0.0659, 0.0153, -0.0153], + [0.0185, -0.0217, 0.0014], [-0.0396, -0.0495, -0.0281] ] 
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453] - self.taesd_decoder_name = None #TODO + self.taesd_decoder_name = None # TODO def process_in(self, latent): latents_mean = self.latents_mean.to(latent.device, latent.dtype) @@ -216,3 +229,7 @@ class Mochi(LatentFormat): latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + + +class LTXV(LatentFormat): + latent_channels = 128 diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index fd16d0a36..c38a8830f 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module): return_info = False, **kwargs ): + patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device + context = kwargs["context"] info = { "hidden_states": [], @@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module): if self.use_sinusoidal_emb or self.use_abs_pos_emb: x = x + self.pos_emb(x) + blocks_replace = patches_replace.get("dit", {}) # Iterate over the transformer layers - for layer in self.layers: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + for i, layer in enumerate(self.layers): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + return out + + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + x = out["img"] + else: + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: @@ -876,10 +888,7 @@ class AudioDiffusionTransformer(nn.Module): mask=None, return_info=False, control=None, - transformer_options=None, **kwargs): - if transformer_options is None: - transformer_options = {} return self._forward( x, timestep, diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 286aa398d..cbe15edc7 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -20,6 +20,7 @@ from .. 
import common_dit @dataclass class FluxParams: in_channels: int + out_channels: int vec_in_dim: int context_in_dim: int hidden_size: int @@ -29,6 +30,7 @@ class FluxParams: depth_single_blocks: int axes_dim: list theta: int + patch_size: int qkv_bias: bool guidance_embed: bool @@ -44,8 +46,9 @@ class Flux(nn.Module): self.dtype = dtype params = FluxParams(**kwargs) self.params = params - self.in_channels = params.in_channels * 2 * 2 - self.out_channels = self.in_channels + self.patch_size = params.patch_size + self.in_channels = params.in_channels * params.patch_size * params.patch_size + self.out_channels = params.out_channels * params.patch_size * params.patch_size if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -166,7 +169,7 @@ class Flux(nn.Module): def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape - patch_size = 2 + patch_size = self.patch_size x = common_dit.pad_to_patch_size(x, (patch_size, patch_size)) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) diff --git a/comfy/ldm/flux/redux.py b/comfy/ldm/flux/redux.py new file mode 100644 index 000000000..527e83164 --- /dev/null +++ b/comfy/ldm/flux/redux.py @@ -0,0 +1,25 @@ +import torch +import comfy.ops + +ops = comfy.ops.manual_cast + +class ReduxImageEncoder(torch.nn.Module): + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.redux_dim = redux_dim + self.device = device + self.dtype = dtype + + self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype) + self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) + + def forward(self, sigclip_embeds) -> torch.Tensor: + projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds))) + return projected_x diff --git a/comfy/ldm/lightricks/__init__.py b/comfy/ldm/lightricks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py new file mode 100644 index 000000000..7e32276be --- /dev/null +++ b/comfy/ldm/lightricks/model.py @@ -0,0 +1,503 @@ +import torch +from torch import nn + +from ..common_dit import rms_norm +from ..genmo.joint_model.layers import RMSNorm +from einops import rearrange +import math +from typing import Dict, Optional, Tuple + +from .symmetric_patchifier import SymmetricPatchifier +from ..modules.attention import optimized_attention, optimized_attention_masked + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. 
+ max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + dtype=None, device=None, operations=None, + ): + super().__init__() + + self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device) + + if cond_proj_dim is not None: + self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device) + else: + self.cond_proj = None + + self.act = nn.SiLU() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + + if post_act_fn is None: + self.post_act = None + # else: + # self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. 
+ + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + ) + + self.silu = nn.SiLU() + self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
+ + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class GELU_approx(nn.Module): + def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device) + + def forward(self, x): + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + super().__init__() + inner_dim = int(dim * mult) + project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + ) + + def forward(self, x): + return self.net(x) + + +def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... 
(d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + super().__init__() + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + self.attn_precision = attn_precision + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device) + self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device) + + self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + + self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + + def forward(self, x, context=None, mask=None, pe=None): + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe) + + if mask is None: + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + else: + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + super().__init__() + + self.attn_precision = attn_precision + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) + + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + + self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + + x += self.attn1(rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa + + x += self.attn2(x, context=context, mask=attention_mask) + + y = rms_norm(x) * (1 + scale_mlp) + shift_mlp + x += self.ff(y) * gate_mlp + + return x + +def get_fractional_positions(indices_grid, max_pos): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / max_pos[i] + for i in range(3) + ], + dim=-1, + ) + return fractional_positions + + +def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): + dtype = torch.float32 #self.dtype + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + + start = 1 + end = theta + device = fractional_positions.device + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = 
indices.to(dtype=dtype) + + indices = indices * math.pi / 2 + + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + + +class LTXVModel(torch.nn.Module): + def __init__(self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + + caption_channels=4096, + num_layers=28, + + + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + ) + + # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) + + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + # attn_precision=attn_precision, + dtype=dtype, device=device, operations=operations + ) + for d in range(num_layers) + ] + ) + + self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) + self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + + self.patchifier = SymmetricPatchifier(1) + + def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs): + indices_grid = self.patchifier.get_grid( + orig_num_frames=x.shape[2], + orig_height=x.shape[3], + orig_width=x.shape[4], + batch_size=x.shape[0], + scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate + device=x.device, + ) + + if guiding_latent is not None: + ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype) + input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1)) + ts *= input_ts + ts[:, :, 0] = 0.0 + timestep = self.patchifier.patchify(ts) + input_x = x.clone() + x[:, :, 0] = guiding_latent[:, :, 0] + + orig_shape = list(x.shape) + + x = self.patchifier.patchify(x) + + x = self.patchify_proj(x) + timestep = timestep * 1000.0 + + attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) + attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this + # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype) + + pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype) + + batch_size = x.shape[0] + timestep, 
embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=x.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = x.shape[0] + context = self.caption_projection(context) + context = context.view( + batch_size, -1, x.shape[-1] + ) + + for block in self.transformer_blocks: + x = block( + x, + context=context, + attention_mask=attention_mask, + timestep=timestep, + pe=pe + ) + + # 3. Output + scale_shift_values = ( + self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) + # Modulation + x = x * (1 + scale) + shift + x = self.proj_out(x) + + x = self.patchifier.unpatchify( + latents=x, + output_height=orig_shape[3], + output_width=orig_shape[4], + output_num_frames=orig_shape[2], + out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), + ) + + if guiding_latent is not None: + x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0] + + # print("res", x) + return x diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py new file mode 100644 index 000000000..51ce50589 --- /dev/null +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -0,0 +1,105 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from einops import rearrange +from torch import Tensor + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Patchifier(ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify( + self, latents: Tensor, frame_rates: Tensor, scale_grid: bool + ) -> Tuple[Tensor, Tensor]: + pass + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + output_num_frames: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_grid( + self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device + ): + f = orig_num_frames // self._patch_size[0] + h = orig_height // self._patch_size[1] + w = orig_width // self._patch_size[2] + grid_h = torch.arange(h, dtype=torch.float32, device=device) + grid_w = torch.arange(w, dtype=torch.float32, device=device) + grid_f = torch.arange(f, dtype=torch.float32, device=device) + grid = torch.meshgrid(grid_f, grid_h, grid_w) + grid = torch.stack(grid, dim=0) + grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + if scale_grid is not None: + for i in range(3): + if isinstance(scale_grid[i], Tensor): + scale = append_dims(scale_grid[i], grid.ndim - 1) + else: + scale = scale_grid[i] + grid[:, i, ...] = grid[:, i, ...] 
* scale * self._patch_size[i] + + grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size) + return grid + + +class SymmetricPatchifier(Patchifier): + def patchify( + self, + latents: Tensor, + ) -> Tuple[Tensor, Tensor]: + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + output_num_frames: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q) ", + f=output_num_frames, + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/comfy/ldm/lightricks/vae/__init__.py b/comfy/ldm/lightricks/vae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py new file mode 100644 index 000000000..dbb852218 --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -0,0 +1,63 @@ +from typing import Tuple, Union + +import torch +import torch.nn as nn +from ....ops import disable_weight_init as ops + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = ops.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode="zeros", + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + last_frame_pad = x[:, :, -1:, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py new file mode 100644 index 000000000..33b2c2d4f --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -0,0 +1,698 @@ +import torch +from torch import nn +from functools import partial +import math +from einops import rearrange +from typing import Any, Mapping, Optional, Tuple, Union, List +from .conv_nd_factory import make_conv_nd, make_linear_nd +from .pixel_norm import PixelNorm + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. 
+ out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks=[("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = 
LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, output_channel, conv_out_channels, 3, padding=1, causal=True + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. 
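+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.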
+ """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks=[("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=dims, in_channels=input_channel, stride=(2, 1, 1) + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, in_channels=input_channel, stride=(1, 2, 2) + ) + elif block_name == "compress_all": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, output_channel, out_channels, 3, padding=1, causal=True + ) + + self.gradient_checkpointing = False + + # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + # assert target_shape is not None, "target_shape must be provided" + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = unpatchify(sample, patch_size_hw=self.patch_size, 
patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, hidden_states: torch.FloatTensor, causal: bool = True + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states, causal=causal) + + return hidden_states + + +class DepthToSpaceUpsample(nn.Module): + def __init__(self, dims, in_channels, stride, residual=False): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + ) + self.residual = residual + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
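+        norm_layer (`str`, *optional*, defaults to `group_norm`):
+            The normalization layer to use. Can be either `group_norm`, `pixel_norm` or `layer_norm`.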
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + input_tensor = self.norm3(input_tensor) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +def patchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + +class processor(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("std-of-means", torch.empty(128)) + self.register_buffer("mean-of-means", torch.empty(128)) + self.register_buffer("mean-of-stds", torch.empty(128)) + self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) + self.register_buffer("channel", torch.empty(128)) + + 
def un_normalize(self, x): + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) + + def normalize(self, x): + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) + +class VideoVAE(nn.Module): + def __init__(self): + super().__init__() + config = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + } + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + + self.encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + ) + + self.decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + ) + + self.per_channel_statistics = processor() + + def encode(self, x): + means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + return self.per_channel_statistics.normalize(means) + + def decode(self, x): + return self.decoder(self.per_channel_statistics.un_normalize(x)) + diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py new file mode 100644 index 000000000..228b83620 --- /dev/null +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -0,0 +1,82 @@ +from typing import Tuple, Union + +import torch + +from .dual_conv3d import DualConv3d +from .causal_conv3d import CausalConv3d +from ....ops import disable_weight_init as ops + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, +): + if dims == 2: + return ops.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + return ops.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: 
int, + bias=True, +): + if dims == 2: + return ops.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + elif dims == 3 or dims == (2, 1): + return ops.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/comfy/ldm/lightricks/vae/dual_conv3d.py b/comfy/ldm/lightricks/vae/dual_conv3d.py new file mode 100644 index 000000000..6bd54c0a6 --- /dev/null +++ b/comfy/ldm/lightricks/vae/dual_conv3d.py @@ -0,0 +1,195 @@ +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." + ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + 
self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." 
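The split used by `DualConv3d.forward_with_2d` treats a k×k×k convolution as a 1×k×k spatial pass followed by a k×1×1 temporal pass, which is why `test_dual_conv3d_consistency` can compare it against stacked `F.conv3d` calls. Below is a minimal standalone sketch of that equivalence using plain `torch.nn.functional` ops and hypothetical weight tensors; it mirrors the rearrange-and-squeeze steps above but is not part of the patch itself.

```python
import torch
import torch.nn.functional as F
from einops import rearrange


def factorized_conv3d(x, w_spatial, w_temporal):
    """Spatial-then-temporal factorization, mirroring DualConv3d.forward_with_2d."""
    # x: (b, c, d, h, w); w_spatial: (c_mid, c, 1, k, k); w_temporal: (c_out, c_mid, k, 1, 1)
    b = x.shape[0]
    # Fold depth into the batch and run the spatial kernel as a 2D conv.
    x = rearrange(x, "b c d h w -> (b d) c h w")
    x = F.conv2d(x, w_spatial.squeeze(2), padding=1)
    _, _, h, w = x.shape
    # Fold the spatial grid into the batch and run the temporal kernel as a 1D conv over depth.
    x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
    x = F.conv1d(x, w_temporal.squeeze(-1).squeeze(-1), padding=1)
    return rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)


x = torch.randn(1, 3, 8, 16, 16)
w_s = torch.randn(5, 3, 1, 3, 3)   # 1 x 3 x 3 spatial kernel
w_t = torch.randn(5, 5, 3, 1, 1)   # 3 x 1 x 1 temporal kernel

# Pushing the same weights through F.conv3d in two steps gives a matching result.
reference = F.conv3d(F.conv3d(x, w_s, padding=(0, 1, 1)), w_t, padding=(1, 0, 0))
assert torch.allclose(factorized_conv3d(x, w_s, w_t), reference, atol=1e-4)
```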
diff --git a/comfy/ldm/lightricks/vae/pixel_norm.py b/comfy/ldm/lightricks/vae/pixel_norm.py new file mode 100644 index 000000000..9bc3ea60e --- /dev/null +++ b/comfy/ldm/lightricks/vae/pixel_norm.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class PixelNorm(nn.Module): + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 28172d1a3..6fd9daf12 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -311,7 +311,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape if len(mask.shape) == 2: s1 += mask[i:end] else: - s1 += mask[:, i:end] + if mask.shape[1] == 1: + s1 += mask + else: + s1 += mask[:, i:end] s2 = s1.softmax(dim=-1).to(v.dtype) del s1 @@ -373,10 +376,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh ) if mask is not None: - pad = 8 - q.shape[1] % 8 - mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device) - mask_out[:, :, :mask.shape[-1]] = mask - mask = mask_out[:, :, :mask.shape[-1]] + pad = 8 - mask.shape[-1] % 8 + mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device) + mask_out[..., :mask.shape[-1]] = mask + mask = mask_out[..., :mask.shape[-1]] out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) # pylint: disable=possibly-used-before-assignment diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 9b6b085fd..56e795379 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -234,6 +234,8 @@ def efficient_dot_product_attention( def get_mask_chunk(chunk_idx: int) -> Tensor: if mask is None: return None + if mask.shape[1] == 1: + return mask chunk = min(query_chunk_size, q_tokens) return mask[:,chunk_idx:chunk_idx + chunk] diff --git a/comfy/lora.py b/comfy/lora.py index 295f7d1d8..f1a106656 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -52,6 +52,15 @@ def load_lora(lora, to_load) -> PatchDict: dora_scale = lora[dora_scale_name] loaded_keys.add(dora_scale_name) + reshape_name = "{}.reshape_weight".format(x) + reshape = None + if reshape_name in lora.keys(): + try: + reshape = lora[reshape_name].tolist() + loaded_keys.add(reshape_name) + except: + pass + regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) diffusers2_lora = "{}.lora_B.weight".format(x) @@ -82,7 +91,7 @@ def load_lora(lora, to_load) -> PatchDict: if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -191,6 +200,12 @@ def load_lora(lora, to_load) -> PatchDict: patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) loaded_keys.add(diff_bias_name) + set_weight_name = "{}.set_weight".format(x) + set_weight = lora.get(set_weight_name, None) + if set_weight is not None: + patch_dict[to_load[x]] = ("set", (set_weight,)) + loaded_keys.add(set_weight_name) + for x in lora.keys(): if x not 
in loaded_keys: logging.warning("lora key not loaded: {}".format(x)) @@ -285,11 +300,14 @@ def model_lora_keys_unet(model, key_map=None): sdk = sd.keys() for k in sdk: - if k.startswith("diffusion_model.") and k.endswith(".weight"): - key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = k - key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora: TODO put lora key prefix in the model config - key_map["{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names + if k.startswith("diffusion_model."): + if k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = k + key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora: TODO put lora key prefix in the model config + key_map["{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names + else: + key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: @@ -445,10 +463,17 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) else: weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype)) + elif patch_type == "set": + weight.copy_(v[0]) elif patch_type == "lora": # lora/locon mat1 = model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat2 = model_management.cast_to_device(v[1], weight.device, intermediate_dtype) dora_scale = v[4] + reshape = v[5] + + if reshape is not None: + weight = pad_tensor_to_shape(weight, reshape) + if v[2] is not None: alpha = v[2] / mat2.shape[0] else: diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py new file mode 100644 index 000000000..05032c690 --- /dev/null +++ b/comfy/lora_convert.py @@ -0,0 +1,17 @@ +import torch + + +def convert_lora_bfl_control(sd): #BFL loras for Flux + sd_out = {} + for k in sd: + k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")) + sd_out[k_to] = sd[k] + + sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]]) + return sd_out + + +def convert_lora(sd): + if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: + return convert_lora_bfl_control(sd) + return sd diff --git a/comfy/model_base.py b/comfy/model_base.py index 6e72522dd..3df551ef6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -34,6 +34,7 @@ from .ldm.aura.mmdit import MMDiT as AuraMMDiT from .ldm.cascade.stage_b import StageB from .ldm.cascade.stage_c import StageC from .ldm.flux import model as flux_model +from .ldm.lightricks.model import LTXVModel from .ldm.genmo.joint_model.asymm_models_joint import AsymmDiTJoint from .ldm.hydit.models import HunYuanDiT from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper @@ -170,8 +171,7 @@ class BaseModel(torch.nn.Module): def encode_adm(self, **kwargs): return None - def extra_conds(self, **kwargs): - out = {} + def concat_cond(self, **kwargs): if len(self.concat_keys) > 0: cond_concat = [] denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) @@ -210,7 +210,14 @@ class BaseModel(torch.nn.Module): elif 
ck == "masked_image": cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) - out['c_concat'] = conds.CONDNoiseShape(data) + return data + return None + + def extra_conds(self, **kwargs): + out = {} + concat_cond = self.concat_cond(**kwargs) + if concat_cond is not None: + out['c_concat'] = conds.CONDNoiseShape(concat_cond) # pylint: disable=assignment-from-none adm = self.encode_adm(**kwargs) @@ -554,9 +561,7 @@ class IP2P(BaseModel): def process_ip2p_image_in(self, image): return None - def extra_conds(self, **kwargs): - out = {} - + def concat_cond(self, **kwargs): image = kwargs.get("concat_latent_image", None) noise = kwargs.get("noise", None) device = kwargs["device"] @@ -568,14 +573,8 @@ class IP2P(BaseModel): image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.resize_to_batch_size(image, noise.shape[0]) + return self.process_ip2p_image_in(image) - out['c_concat'] = conds.CONDNoiseShape(self.process_ip2p_image_in(image)) - - # pylint: disable=assignment-from-none - adm = self.encode_adm(**kwargs) - if adm is not None: - out['y'] = conds.CONDRegular(adm) - return out class SD15_instructpix2pix(IP2P, BaseModel): @@ -746,6 +745,38 @@ class Flux(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux) + def concat_cond(self, **kwargs): + num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size) + out_channels = self.model_config.unet_config["out_channels"] + + if num_channels <= out_channels: + return None + + image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if image is None: + image = torch.zeros_like(noise) + + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.resize_to_batch_size(image, noise.shape[0]) + image = self.process_latent_in(image) + if num_channels <= out_channels * 2: + return image + + #inpaint model + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.ones_like(noise)[:, :1] + + mask = torch.mean(mask, dim=1, keepdim=True) + print(mask.shape) + mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center") + mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + return torch.cat((image, mask), dim=1) + def encode_adm(self, **kwargs): return kwargs["pooled_output"] @@ -771,3 +802,23 @@ class GenmoMochi(BaseModel): if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) return out + +class LTXV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=LTXVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDRegular(cross_attn) + + guiding_latent = kwargs.get("guiding_latent", None) + 
if guiding_latent is not None: + out['guiding_latent'] = conds.CONDRegular(guiding_latent) + + out['frame_rate'] = conds.CONDConstant(kwargs.get("frame_rate", 25)) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4e959238e..a27070a16 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -141,6 +141,12 @@ def detect_unet_config(state_dict, key_prefix): dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 + patch_size = 2 + dit_config["patch_size"] = patch_size + in_key = "{}img_in.weight".format(key_prefix) + if in_key in state_dict_keys: + dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) + dit_config["out_channels"] = 16 dit_config["vec_in_dim"] = 768 dit_config["context_in_dim"] = 4096 dit_config["hidden_size"] = 3072 @@ -181,6 +187,11 @@ def detect_unet_config(state_dict, key_prefix): dit_config["rope_theta"] = 10000.0 return dit_config + if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv + dit_config = {} + dit_config["image_model"] = "ltxv" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d45a360e6..e8967e9a3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -419,14 +419,23 @@ class ModelPatcher(ModelManageable): lowvram_counter = 0 loading = [] for n, m in self.model.named_modules(): - if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"): - loading.append((model_management.module_size(m), n, m)) + params = [] + skip = False + for name, param in m.named_parameters(recurse=False): + params.append(name) + for name, param in m.named_parameters(recurse=True): + if name not in params: + skip = True # skip random weights in non leaf modules + break + if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): + loading.append((model_management.module_size(m), n, m, params)) load_completely = [] loading.sort(reverse=True) for x in loading: n = x[1] m = x[2] + params = x[3] module_mem = x[0] lowvram_weight = False @@ -462,22 +471,22 @@ class ModelPatcher(ModelManageable): if m.comfy_cast_weights: wipe_lowvram_weight(m) - if hasattr(m, "weight"): + if full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem - load_completely.append((module_mem, n, m)) + load_completely.append((module_mem, n, m, params)) load_completely.sort(reverse=True) for x in load_completely: n = x[1] m = x[2] - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + params = x[3] if hasattr(m, "comfy_patched_weights"): if m.comfy_patched_weights == True: continue - self.patch_weight_to_device(weight_key, device_to=device_to) - self.patch_weight_to_device(bias_key, device_to=device_to) + for param in params: + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + logger.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index e95e9e45f..7b384fb5e 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -300,7 +300,8 @@ class VAEDecodeTiled: def decode(self, vae, samples, tile_size, overlap=64): if tile_size < overlap * 4: overlap = tile_size // 4 - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8) + compression = 
vae.spacial_compression_decode() + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -381,6 +382,7 @@ class InpaintModelConditioning: "vae": ("VAE", ), "pixels": ("IMAGE", ), "mask": ("MASK", ), + "noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}), }} RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") @@ -389,7 +391,7 @@ class InpaintModelConditioning: CATEGORY = "conditioning/inpaint" - def encode(self, positive, negative, pixels, vae, mask): + def encode(self, positive, negative, pixels, vae, mask, noise_mask): x = (pixels.shape[1] // 8) * 8 y = (pixels.shape[2] // 8) * 8 mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") @@ -413,7 +415,8 @@ class InpaintModelConditioning: out_latent = {} out_latent["samples"] = orig_latent - out_latent["noise_mask"] = mask + if noise_mask: + out_latent["noise_mask"] = mask out = [] for conditioning in [positive, negative]: @@ -924,7 +927,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -943,6 +946,8 @@ class CLIPLoader: clip_type = sd.CLIPType.STABLE_AUDIO elif type == "mochi": clip_type = sd.CLIPType.MOCHI + elif type == "ltxv": + clip_type = comfy.sd.CLIPType.LTXV else: logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}") diff --git a/comfy/sd.py b/comfy/sd.py index 6afa65f50..07f1b4ab2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -6,13 +6,14 @@ import os.path from enum import Enum from typing import Any, Optional +import comfy.ldm.flux.redux +import comfy.text_encoders.lt import torch import yaml from . import clip_vision from . import diffusers_convert from . import gligen -from . import lora from . import model_detection from . import model_management from . import model_patcher @@ -23,19 +24,20 @@ from . 
import utils from .ldm.audio.autoencoder import AudioOobleckVAE from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder -from .ldm.genmo.vae.model import VideoVAE +from .ldm.genmo.vae import model as genmo +from .ldm.lightricks.vae import causal_video_autoencoder as lightricks from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .model_management import load_models_gpu from .t2i_adapter import adapter from .taesd import taesd from .text_encoders import aura_t5 from .text_encoders import flux +from .text_encoders import genmo from .text_encoders import hydit from .text_encoders import long_clipl from .text_encoders import sa_t5 from .text_encoders import sd2_clip from .text_encoders import sd3_clip -from .text_encoders import genmo def load_lora_for_models(model, clip, _lora, strength_model, strength_clip): @@ -45,6 +47,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip): if clip is not None: key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + lora = comfy.lora_convert.convert_lora(lora) loaded = lora.load_lora(_lora, key_map) if model is not None: new_modelpatcher = model.clone() @@ -257,18 +260,26 @@ class VAE: self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae + elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: # genmo mochi vae if "blocks.2.blocks.3.stack.5.weight" in sd: sd = utils.state_dict_prefix_replace(sd, {"": "decoder."}) if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd: sd = utils.state_dict_prefix_replace(sd, {"": "encoder."}) - self.first_stage_model = VideoVAE() + self.first_stage_model = genmo.VideoVAE() self.latent_channels = 12 self.latent_dim = 3 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) self.working_dtypes = [torch.float16, torch.float32] + elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: # lightricks ltxv + self.first_stage_model = lightricks.VideoVAE() + self.latent_channels = 128 + self.latent_dim = 3 + self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) + self.working_dtypes = [torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -359,7 +370,7 @@ class VAE: out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) - 
pixel_samples[x:x+batch_number] = out + pixel_samples[x:x + batch_number] = out except model_management.OOM_EXCEPTION as e: logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") dims = samples_in.ndim - 2 @@ -368,13 +379,15 @@ class VAE: elif dims == 2: pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: - pixel_samples = self.decode_tiled_3d(samples_in) + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1) return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None): - memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile + memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) # TODO: calculate mem required for tile load_models_gpu([self.patcher], memory_required=memory_used) dims = samples.ndim - 2 args = {} @@ -434,6 +447,12 @@ class VAE: def get_sd(self): return self.first_stage_model.state_dict() + def spacial_compression_decode(self): + try: + return self.upscale_ratio[-1] + except: + return self.upscale_ratio + class StyleModel: def __init__(self, model, device="cpu"): @@ -448,6 +467,8 @@ def load_style_model(ckpt_path): keys = model_data.keys() if "style_embedding" in keys: model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) + elif "redux_down.weight" in keys: + model = comfy.ldm.flux.redux.ReduxImageEncoder() else: raise Exception("invalid style model {}".format(ckpt_path)) model.load_state_dict(model_data) @@ -462,6 +483,7 @@ class CLIPType(Enum): HUNYUAN_DIT = 5 FLUX = 6 MOCHI = 7 + LTXV = 8 @dataclasses.dataclass @@ -552,7 +574,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if clip_type == CLIPType.SD3: clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data)) clip_target.tokenizer = sd3_clip.SD3Tokenizer - else: #CLIPType.MOCHI + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer + else: # CLIPType.MOCHI clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = genmo.MochiT5Tokenizer elif te_model == TEModel.T5_XL: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 597ef5e86..809b295ff 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -11,6 +11,7 @@ from .text_encoders import aura_t5 from .text_encoders import hydit from .text_encoders import flux from .text_encoders import genmo +from .text_encoders import lt from . import supported_models_base from . 
import latent_formats @@ -704,7 +705,34 @@ class GenmoMochi(supported_models_base.BASE): t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect)) +class LTXV(supported_models_base.BASE): + unet_config = { + "image_model": "ltxv", + } -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi] + sampling_settings = { + "shift": 2.37, + } + + unet_extra_config = {} + latent_format = latent_formats.LTXV + + memory_usage_factor = 2.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXV(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(lt.LTXVT5Tokenizer, lt.ltxv_te(**t5_detect)) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV] models += [SVD_img2vid] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py new file mode 100644 index 000000000..a18112019 --- /dev/null +++ b/comfy/text_encoders/lt.py @@ -0,0 +1,24 @@ +from transformers import T5TokenizerFast + +from comfy import sd1_clip +from .genmo import mochi_te +from ..component_model import files + + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) # pad to 128? 
+ + +class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + + +def ltxv_te(*args, **kwargs): + return mochi_te(*args, **kwargs) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py new file mode 100644 index 000000000..9d0639378 --- /dev/null +++ b/comfy_extras/nodes_lt.py @@ -0,0 +1,181 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.model_sampling +import math + +class EmptyLTXVLatentVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/video/ltxv" + + def generate(self, width, height, length, batch_size=1): + latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + return ({"samples": latent}, ) + + +class LTXVImgToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE",), + "image": ("IMAGE",), + "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + CATEGORY = "conditioning/video_models" + FUNCTION = "generate" + + def generate(self, positive, negative, image, vae, width, height, length, batch_size): + pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t}) + negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t}) + + latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + latent[:, :, :t.shape[2]] = t + return (positive, negative, {"samples": latent}, ) + + +class LTXVConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + FUNCTION = "append" + + CATEGORY = "conditioning/video_models" + + def append(self, positive, negative, frame_rate): + positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) + negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) + return (positive, negative) + + +class ModelSamplingLTXV: + @classmethod + def INPUT_TYPES(s): + return 
{"required": { "model": ("MODEL",), + "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), + "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), + }, + "optional": {"latent": ("LATENT",), } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, max_shift, base_shift, latent=None): + m = model.clone() + + if latent is None: + tokens = 4096 + else: + tokens = math.prod(latent["samples"].shape[2:]) + + x1 = 1024 + x2 = 4096 + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + shift = (tokens) * mm + b + + sampling_base = comfy.model_sampling.ModelSamplingFlux + sampling_type = comfy.model_sampling.CONST + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift=shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + + +class LTXVScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), + "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), + "stretch": ("BOOLEAN", { + "default": True, + "tooltip": "Stretch the sigmas to be in the range [terminal, 1]." + }), + "terminal": ( + "FLOAT", + { + "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01, + "tooltip": "The terminal value of the sigmas after stretching." + }, + ), + }, + "optional": {"latent": ("LATENT",), } + } + + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None): + if latent is None: + tokens = 4096 + else: + tokens = math.prod(latent["samples"].shape[2:]) + + sigmas = torch.linspace(1.0, 0.0, steps + 1) + + x1 = 1024 + x2 = 4096 + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + sigma_shift = (tokens) * mm + b + + power = 1 + sigmas = torch.where( + sigmas != 0, + math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), + 0, + ) + + # Stretch sigmas so that its final value matches the given terminal value. + if stretch: + non_zero_mask = sigmas != 0 + non_zero_sigmas = sigmas[non_zero_mask] + one_minus_z = 1.0 - non_zero_sigmas + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched = 1.0 - (one_minus_z / scale_factor) + sigmas[non_zero_mask] = stretched + + return (sigmas,) + + +NODE_CLASS_MAPPINGS = { + "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, + "LTXVImgToVideo": LTXVImgToVideo, + "ModelSamplingLTXV": ModelSamplingLTXV, + "LTXVConditioning": LTXVConditioning, + "LTXVScheduler": LTXVScheduler, +}