diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f59999af6..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(ChromaRadiance): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. + """ + pass diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 77d1abc97..9e432d5c0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder def invert_slices(slices, length): @@ -858,3 +859,267 @@ class NextDiT(nn.Module): img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) + self.mlp = nn.Sequential( + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. + + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + # Project backbone hidden state → per-patch conditioning + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. + """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) + adaln_input = t + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + N = pixel_patches.shape[1] + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] + + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] + + # prepend zero cap placeholder so unpatchify indexing works unchanged + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns neg_x0 = -x0 (negated decoder output). + # + # Reference inference (working_inference_reference.py): + # out = _forward(img, t) # = -x0 + # pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual] + # img += (t_prev - t_curr) * pred # Euler step + # + # ComfyUI's Euler sampler does the same: + # x_next = x + (sigma_next - sigma) * model_output + # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2e990dd75..59283b300 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1264,6 +1264,11 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ada3d73a1..1ee6597dc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config - if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 @@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if sig_weight is not None: dit_config["siglip_feat_dim"] = sig_weight.shape[0] + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 @@ -533,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config - if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 - + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1 dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] @@ -1061,6 +1083,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) + elif 'noise_refiner.0.attention.norm_k.weight' in state_dict: + n_layers = count_blocks(state_dict, 'layers.{}.') + dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0] + sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix) + for k in state_dict: # For zeta chroma + if k not in sd_map: + sd_map[k] = k elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') diff --git a/comfy/model_management.py b/comfy/model_management.py index c817d43b5..0e0e96672 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -32,9 +32,6 @@ import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.torch -import comfy_aimdo.model_vbar - class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -1206,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): - if hasattr(weight, "_v"): - #Unexpected usage patterns. There is no reason these don't work but they - #have no testing and no callers do this. - assert r is None - assert stream is None - - cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) - - if dtype is None: - dtype = weight._model_dtype - - signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) - if signature is not None: - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): - v_tensor = weight._v_tensor - else: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - weight._v_tensor = v_tensor - weight._v_signature = signature - #Send it over - v_tensor.copy_(weight, non_blocking=non_blocking) - return v_tensor.to(dtype=dtype) - - r = torch.empty_like(weight, dtype=dtype, device=device) - - if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: - #Offloaded casting could skip this, however it would make the quantizations - #inconsistent between loaded and offloaded weights. So force the double casting - #that would happen in regular flow to make offload deterministic. - cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) - cast_buffer.copy_(weight, non_blocking=non_blocking) - weight = cast_buffer - r.copy_(weight, non_blocking=non_blocking) - - return r - if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3fc76d9db..e380e406b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1435,10 +1435,6 @@ class ModelPatcherDynamic(ModelPatcher): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): super().__init__(model, load_device, offload_device, size, weight_inplace_update) - #this is now way more dynamic and we dont support the same base model for both Dynamic - #and non-dynamic patchers. - if hasattr(self.model, "model_loaded_weight_memory"): - del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} self.non_dynamic_delegate_model = None @@ -1461,9 +1457,7 @@ class ModelPatcherDynamic(ModelPatcher): def loaded_size(self): vbar = self._vbar_get() - if vbar is None: - return 0 - return vbar.loaded_size() + return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory def get_free_memory(self, device): #NOTE: on high condition / batch counts, estimate should have already vacated @@ -1504,6 +1498,7 @@ class ModelPatcherDynamic(ModelPatcher): num_patches = 0 allocated_size = 0 + self.model.model_loaded_weight_memory = 0 with self.use_ejected(): self.unpatch_hooks() @@ -1512,10 +1507,6 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - #We force reserve VRAM for the non comfy-weight so we dont have to deal - #with pin and unpin syncrhonization which can be expensive for small weights - #with a high layer rate (e.g. autoregressive LLMs). - #prioritize the non-comfy weights (note the order reverse). loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) @@ -1558,6 +1549,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to) + weight, _, _ = get_key_weight(self.model, key) + if weight is not None: + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1583,21 +1577,15 @@ class ModelPatcherDynamic(ModelPatcher): for param in params: key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) - weight.seed_key = key - set_dirty(weight, dirty) - geometry = weight - model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype - geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) - weight_size = geometry.numel() * geometry.element_size() - if vbar is not None and not hasattr(weight, "_v"): - weight._v = vbar.alloc(weight_size) - weight._model_dtype = model_dtype - allocated_size += weight_size - vbar.set_watermark_limit(allocated_size) + if key not in self.backup: + self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) + comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() move_weight_functions(m, device_to) - logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") self.model.device = device_to self.model.current_weight_patches_uuid = self.patches_uuid @@ -1613,7 +1601,16 @@ class ModelPatcherDynamic(ModelPatcher): assert self.load_device != torch.device("cpu") vbar = self._vbar_get() - return 0 if vbar is None else vbar.free_memory(memory_to_free) + freed = 0 if vbar is None else vbar.free_memory(memory_to_free) + + if freed < memory_to_free: + for key in list(self.backup.keys()): + bk = self.backup.pop(key) + comfy.utils.set_attr_param(self.model, key, bk.weight) + freed += self.model.model_loaded_weight_memory + self.model.model_loaded_weight_memory = 0 + + return freed def partially_unload_ram(self, ram_to_unload): loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) @@ -1640,11 +1637,6 @@ class ModelPatcherDynamic(ModelPatcher): for m in self.model.modules(): move_weight_functions(m, device_to) - keys = list(self.backup.keys()) - for k in keys: - bk = self.backup[k] - comfy.utils.set_attr_param(self.model, k, bk.weight) - def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above with self.use_ejected(skip_and_inject_on_exit_only=True): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 51db3dba6..81d45203f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1118,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1730,6 +1744,6 @@ class RT_DETR_v4(supported_models_base.BASE): out = model_base.RT_DETR_v4(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 43df0512f..5d8d9bf6f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode): def vae_decode_audio(vae, samples, tile=None, overlap=None): if tile is not None: - audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1) + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) else: audio = vae.decode(samples["samples"]).movedim(-1, 1) diff --git a/main.py b/main.py index af701f8df..0f58d57b8 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() - if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -28,6 +23,11 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() + if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' diff --git a/requirements.txt b/requirements.txt index 71019c16f..608b0cfa6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.4 +comfyui-workflow-templates==0.9.5 comfyui-embedded-docs==0.4.3 torch torchsde