diff --git a/CODEOWNERS b/CODEOWNERS index 72a59effe..013ea8622 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,20 +5,20 @@ # Inlined the team members for now. # Maintainers -*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink +*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne # Python web server -/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata +/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne # Node developers -/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered -/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered +/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne diff --git a/README.md b/README.md index a99aca0e7..cf6df7e55 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/) - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) + - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - 
[Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) diff --git a/app/frontend_management.py b/app/frontend_management.py index c56ea86e0..7b7923b79 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -184,6 +184,27 @@ comfyui-frontend-package is not installed. ) sys.exit(-1) + @classmethod + def templates_path(cls) -> str: + try: + import comfyui_workflow_templates + + return str( + importlib.resources.files(comfyui_workflow_templates) / "templates" + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + @classmethod + def parse_version_string(cls, value: str) -> tuple[str, str, str]: + """ diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 3535966fb..42ed5174e 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -99,59 +99,59 @@ class InputTypeOptions(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes """ - default: bool | str | float | int | list | tuple + default: NotRequired[bool | str | float | int | list | tuple] """The default value of the widget""" - defaultInput: bool + defaultInput: NotRequired[bool] """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist. - defaultInput on required inputs should be dropped. - defaultInput on optional inputs should be replaced with forceInput. Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364 """ - forceInput: bool + forceInput: NotRequired[bool] """Forces the input to be an input slot rather than a widget even if a widget is available for the input type.""" - lazy: bool + lazy: NotRequired[bool] """Declares that this input uses lazy evaluation""" - rawLink: bool + rawLink: NotRequired[bool] """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`).
Designed for node expansion.""" - tooltip: str + tooltip: NotRequired[str] """Tooltip for the input (or widget), shown on pointer hover""" # class InputTypeNumber(InputTypeOptions): # default: float | int - min: float + min: NotRequired[float] """The minimum value of a number (``FLOAT`` | ``INT``)""" - max: float + max: NotRequired[float] """The maximum value of a number (``FLOAT`` | ``INT``)""" - step: float + step: NotRequired[float] """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)""" - round: float + round: NotRequired[float] """Floats are rounded by this value (``FLOAT``)""" # class InputTypeBoolean(InputTypeOptions): # default: bool - label_on: str + label_on: NotRequired[str] """The label to use in the UI when the bool is True (``BOOLEAN``)""" - label_off: str + label_off: NotRequired[str] """The label to use in the UI when the bool is False (``BOOLEAN``)""" # class InputTypeString(InputTypeOptions): # default: str - multiline: bool + multiline: NotRequired[bool] """Use a multiline text box (``STRING``)""" - placeholder: str + placeholder: NotRequired[str] """Placeholder text to display in the UI when empty (``STRING``)""" # Deprecated: # defaultVal: str - dynamicPrompts: bool + dynamicPrompts: NotRequired[bool] """Causes the front-end to evaluate dynamic prompts (``STRING``)""" # class InputTypeCombo(InputTypeOptions): - image_upload: bool + image_upload: NotRequired[bool] """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`.""" - image_folder: Literal["input", "output", "temp"] + image_folder: NotRequired[Literal["input", "output", "temp"]] """Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """ - remote: RemoteInputOptions + remote: NotRequired[RemoteInputOptions] """Specifies the configuration for a remote input. Available after ComfyUI frontend v1.9.7 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422""" - control_after_generate: bool + control_after_generate: NotRequired[bool] """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types.""" options: NotRequired[list[str | int | float]] """COMBO type only. Specifies the selectable options for the combo widget. @@ -169,15 +169,15 @@ class InputTypeOptions(TypedDict): class HiddenInputTypeDict(TypedDict): """Provides type hinting for the hidden entry of node INPUT_TYPES.""" - node_id: Literal["UNIQUE_ID"] + node_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - unique_id: Literal["UNIQUE_ID"] + unique_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - prompt: Literal["PROMPT"] + prompt: NotRequired[Literal["PROMPT"]] """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" - extra_pnginfo: Literal["EXTRA_PNGINFO"] + extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]] """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. 
Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" - dynprompt: Literal["DYNPROMPT"] + dynprompt: NotRequired[Literal["DYNPROMPT"]] """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" @@ -187,11 +187,11 @@ class InputTypeDict(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs """ - required: dict[str, tuple[IO, InputTypeOptions]] + required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes all inputs that must be connected for the node to execute.""" - optional: dict[str, tuple[IO, InputTypeOptions]] + optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes inputs which do not need to be connected.""" - hidden: HiddenInputTypeDict + hidden: NotRequired[HiddenInputTypeDict] """Offers advanced functionality and server-client communication. Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py new file mode 100644 index 000000000..fcb5a9c51 --- /dev/null +++ b/comfy/ldm/hidream/model.py @@ -0,0 +1,799 @@ +from typing import Optional, Tuple, List + +import torch +import torch.nn as nn +import einops +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +import torch.nn.functional as F + +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.flux.layers import LastLayer + +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management +import comfy.ldm.common_dit + + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class EmbedND(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + dtype=None, device=None, operations=None + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, latent): + latent = self.proj(latent) + return latent + + +class PooledEmbed(nn.Module): + def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, pooled_embed): + return self.pooled_embedder(pooled_embed) + + +class TimestepEmbed(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, timesteps, wdtype): + t_emb = 
self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + + +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) + + +class HiDreamAttnProcessor_flashattn: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + dtype = image_tokens.dtype + batch_size = image_tokens.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) + value_i = attn.to_v(image_tokens) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if image_tokens_masks is not None: + key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) + value_t = attn.to_v_t(text_tokens) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = torch.cat([query_i, query_t], dim=1) + key = torch.cat([key_i, key_t], dim=1) + value = torch.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == rope.shape[-3] * 2: + query, key = apply_rope(query, key, rope) + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, rope) + query = torch.cat([query_1, query_2], dim=-1) + key = torch.cat([key_1, key_2], dim=-1) + + hidden_states = attention(query, key, value) + + if not attn.single: + hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states + +class HiDreamAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor = None, + out_dim: int = None, + single: bool = False, + dtype=None, device=None, operations=None + ): + # super(Attention, self).__init__() + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = 
out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + linear_cls = operations.Linear + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + if not single: + self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + self.processor = processor + + def forward( + self, + norm_image_tokens: torch.FloatTensor, + image_tokens_masks: torch.FloatTensor = None, + norm_text_tokens: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.Tensor: + return self.processor( + self, + image_tokens = norm_image_tokens, + image_tokens_masks = image_tokens_masks, + text_tokens = norm_text_tokens, + rope = rope, + ) + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + dtype=None, device=None, operations=None + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = 'softmax' + self.alpha = aux_loss_alpha + self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device)) + self.reset_parameters() + + def reset_parameters(self) -> None: + pass + # import torch.nn.init as init + # init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, 
device=hidden_states.device), None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}') + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + dtype=None, device=None, operations=None + ): + super().__init__() + self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations) + self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)]) + self.gate = MoEGate( + embed_dim = dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + self.num_activated_experts = num_activated_experts + + def forward(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if True: # self.training: # TODO: check which branch performs faster + x = x.repeat_interleave(self.num_activated_experts, dim=0) + y = torch.empty_like(x, dtype=wtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape).to(dtype=wtype) + #y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + return expert_cache + + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +class BlockType: + TransformerBlock = 1 + SingleTransformerBlock = 2 + + +class HiDreamImageSingleTransformerBlock(nn.Module): + def __init__( 
self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device) + ) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = True, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_image_tokens, + image_tokens_masks, + rope = rope, + ) + image_tokens = gate_msa_i * attn_output_i + image_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype)) + image_tokens = ff_output_i + image_tokens + return image_tokens + + +class HiDreamImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device) + ) + # nn.init.zeros_(self.adaLN_modulation[1].weight) + # nn.init.zeros_(self.adaLN_modulation[1].bias) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = False, + dtype=dtype, device=device, operations=operations + ) + + # 3. 
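+ # Feed-forward stage of the dual-stream block: the image stream uses a routed + # mixture-of-experts SwiGLU when num_routed_experts > 0 (a dense SwiGLU otherwise), + # while the text stream (norm3_t/ff_t below) always uses a dense SwiGLU.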
Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ + shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_image_tokens, + image_tokens_masks, + norm_text_tokens, + rope = rope, + ) + + image_tokens = gate_msa_i * attn_output_i + image_tokens + text_tokens = gate_msa_t * attn_output_t + text_tokens + + # 2. 
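+ # Feed-forward: each stream is layer-normalized, modulated by its adaLN scale and + # shift, passed through its feed-forward network, and gated before the residual add.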
Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens) + ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) + image_tokens = ff_output_i + image_tokens + text_tokens = ff_output_t + text_tokens + return image_tokens, text_tokens + + +class HiDreamImageBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + block_type: BlockType = BlockType.TransformerBlock, + dtype=None, device=None, operations=None + ): + super().__init__() + block_classes = { + BlockType.TransformerBlock: HiDreamImageTransformerBlock, + BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock, + } + self.block = block_classes[block_type]( + dim, + num_attention_heads, + attention_head_dim, + num_routed_experts, + num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + return self.block( + image_tokens, + image_tokens_masks, + text_tokens, + adaln_input, + rope, + ) + + +class HiDreamImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + image_model=None, + dtype=None, device=None, operations=None + ): + self.patch_size = patch_size + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.num_layers = num_layers + self.num_single_layers = num_single_layers + + self.gradient_checkpointing = False + + super().__init__() + self.dtype = dtype + self.out_channels = out_channels or in_channels + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.llama_layers = llama_layers + + self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations) + self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.x_embedder = PatchEmbed( + patch_size = patch_size, + in_channels = in_channels, + out_channels = self.inner_dim, + dtype=dtype, device=device, operations=operations + ) + self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.TransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_layers) + ] + ) + + self.single_stream_blocks = 
nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.SingleTransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_single_layers) + ] + ) + + self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations)) + self.caption_projection = nn.ModuleList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + def expand_timesteps(self, timesteps, batch_size, device): + if not torch.is_tensor(timesteps): + is_mps = device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + return timesteps + + def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]: + x_arr = [] + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + x_arr.append( + einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', + p1=self.patch_size, p2=self.patch_size) + ) + x = torch.cat(x_arr, dim=0) + return x + + def patchify(self, x, max_seq, img_sizes=None): + pz2 = self.patch_size * self.patch_size + if isinstance(x, torch.Tensor): + B = x.shape[0] + device = x.device + dtype = x.dtype + else: + B = len(x) + device = x[0].device + dtype = x[0].dtype + x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) + + if img_sizes is not None: + for i, img_size in enumerate(img_sizes): + x_masks[i, 0:img_size[0] * img_size[1]] = 1 + x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2) + elif isinstance(x, torch.Tensor): + pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size) + img_sizes = [[pH, pW]] * B + x_masks = None + else: + raise NotImplementedError + return x, x_masks, img_sizes + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + control = None, + transformer_options = {}, + ) -> torch.Tensor: + bs, c, h, w = x.shape + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + timesteps = t + pooled_embeds = y + T5_encoder_hidden_states = context + + img_sizes = None + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # 0. 
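+ # Conditioning: the timestep embedding and the pooled text embedding are summed + # into a single adaLN vector (adaln_input) that modulates every transformer block.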
time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + adaln_input = timesteps + p_embedder + + hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if image_tokens_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + hidden_states = self.x_embedder(hidden_states) + + # T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, dtype=img_ids.dtype + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + rope = self.pe_embedder(ids) + + # 2. 
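+ # Dual-stream blocks update image and text tokens separately; each block also + # receives its own per-layer llama hidden state, appended to the shared text stream + # and trimmed off again after the block runs. The single-stream blocks that follow + # operate on the concatenated [image | text] sequence.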
Blocks + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states, initial_encoder_hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = cur_encoder_hidden_states, + adaln_input = adaln_input, + rope = rope, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if image_tokens_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=image_tokens_masks.device, dtype=image_tokens_masks.dtype + ) + image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states = block( + image_tokens=hidden_states, + image_tokens_masks=image_tokens_masks, + text_tokens=None, + adaln_input=adaln_input, + rope=rope, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
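+ # Only the image tokens are kept; the final layer projects them back to patch + # space, unpatchify restores the padded latent, and the result is cropped to the + # original height/width.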
+ output = self.final_layer(hidden_states, adaln_input) + output = self.unpatchify(output, img_sizes) + return -output[:, :, :h, :w] diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9b5e5332c..2a30497c5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context): + def forward(self, x, context, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context): + def forward(self, x, context, context_img_len): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - context_img = context[:, :257] - context = context[:, 257:] + context_img = context[:, :context_img_len] + context = context[:, context_img_len:] # compute query, key, value q = self.norm_q(self.q(x)) @@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, + context_img_len=257, ): r""" Args: @@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module): x = x + y * e[2] # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) x = x + y * e[5] return x @@ -250,7 +251,7 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, operation_settings={}): + def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}): super().__init__() self.proj = torch.nn.Sequential( @@ -258,7 +259,15 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + if flf_pos_embed_token_number is not None: + self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + else: + self.emb_pos = None + def forward(self, image_embeds): + if self.emb_pos is not None: + image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device) + clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -284,6 +293,7 @@ class WanModel(torch.nn.Module): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf_pos_embed_token_number=None, image_model=None, device=None, dtype=None, @@ -373,7 +383,7 @@ class WanModel(torch.nn.Module): self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) else: self.img_emb = None @@ -420,9 +430,12 @@ class WanModel(torch.nn.Module): # context context = self.text_embedding(context) - if clip_fea is not None and self.img_emb is not None: - context_clip = 
self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -430,12 +443,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) # head x = self.head(x, e) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6bc627ae3..8dab1740b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model +import comfy.ldm.hidream.model import comfy.model_management import comfy.patcher_extension @@ -1056,3 +1057,20 @@ class Hunyuan3Dv2(BaseModel): if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HiDream(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + conditioning_llama3 = kwargs.get("conditioning_llama3", None) + if conditioning_llama3 is not None: + out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4217f5831..6499bf238 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" + flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) + if flf_weight is not None: + dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D @@ -338,6 +341,25 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream + dit_config = {} + dit_config["image_model"] = "hidream" + dit_config["attention_head_dim"] = 128 + dit_config["axes_dims_rope"] = [64, 32, 32] + dit_config["caption_channels"] = [4096, 4096] + dit_config["max_resolution"] = [128, 128] + dit_config["in_channels"] = 16 + dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31] + dit_config["num_attention_heads"] = 20 + dit_config["num_routed_experts"] = 4 + dit_config["num_activated_experts"] = 2 + dit_config["num_layers"] = 16 + dit_config["num_single_layers"] = 32 + dit_config["out_channels"] = 16 + dit_config["patch_size"] = 2 + dit_config["text_emb_dim"] = 2048 + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/ops.py b/comfy/ops.py index 6b0e29307..aae6cafac 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -263,6 +263,9 @@ class manual_cast(disable_weight_init): class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True + class RMSNorm(disable_weight_init.RMSNorm): + comfy_cast_weights = True + class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 81b3e9062..9d82bee1a 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -27,17 +27,6 @@ def rms_norm(x, weight=None, eps=1e-6): if RMSNorm is None: class RMSNorm(torch.nn.Module): - def __init__( - self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs - ): - super().__init__() - self.eps = eps - self.learnable_scale = elementwise_affine - if self.learnable_scale: - self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) - else: - self.register_parameter("weight", None) - def __init__( self, normalized_shape, @@ -60,6 +49,7 @@ if RMSNorm is None: ) else: self.register_parameter("weight", None) + self.bias = None def forward(self, x): return rms_norm(x, self.weight, self.eps) diff --git a/comfy/sd.py b/comfy/sd.py index 4d3aef3e1..8aba5d655 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.hidream import comfy.model_patcher import comfy.lora @@ -702,6 +703,7 @@ class CLIPType(Enum): COSMOS = 11 LUMINA2 = 12 WAN = 13 + HIDREAM = 14 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -790,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -810,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + 
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer @@ -826,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.LLAMA3_8: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: + # clip_l if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer @@ -847,12 +864,33 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO: clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer + elif clip_type == CLIPType.HIDREAM: + # Detect + hidream_dualclip_classes = [] + for hidream_te in clip_data: + te_model = detect_te_model(hidream_te) + hidream_dualclip_classes.append(te_model) + + clip_l = TEModel.CLIP_L in hidream_dualclip_classes + clip_g = TEModel.CLIP_G in hidream_dualclip_classes + t5 = TEModel.T5_XXL in hidream_dualclip_classes + llama = TEModel.LLAMA3_8 in hidream_dualclip_classes + + # Initialize t5xxl_detect and llama_detect kwargs if needed + t5_kwargs = t5xxl_detect(clip_data) if t5 else {} + llama_kwargs = llama_detect(clip_data) if llama else {} + + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif len(clip_data) == 3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif len(clip_data) == 4: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer parameters = 0 for c in clip_data: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index be21ec18d..2ca5ed9ba 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): LAYERS = [ "last", "pooled", - "hidden" + "hidden", + "all" ] def __init__(self, device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, @@ -93,6 
+94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + if "model_name" not in model_options: + model_options = {**model_options, "model_name": "clip_l"} if isinstance(textmodel_json_config, dict): config = textmodel_json_config @@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): with open(textmodel_json_config) as f: config = json.load(f) + te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {}) + for k, v in te_model_options.items(): + config[k] = v + operations = model_options.get("custom_operations", None) scaled_fp8 = None @@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: + if self.layer == "all": + pass + elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" @@ -244,7 +253,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + if self.layer == "all": + intermediate_output = "all" + else: + intermediate_output = self.layer_idx + + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() @@ -447,7 +461,7 @@ class SDTokenizer: if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) - self.max_length = max_length + self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) self.min_length = min_length self.end_token = None @@ -645,6 +659,7 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) clip_model = model_options.get("{}_class".format(self.clip), clip_model) + model_options = {**model_options, "model_name": self.clip} setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) self.dtypes = set() diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 5b7c8a412..ea7f5d10f 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options) @@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, 
embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class SDXLTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -41,8 +41,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__() - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = set([dtype]) @@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): class StableCascadeClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2a6a61560..81c47ac68 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1025,6 +1025,36 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, 
PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +class HiDream(supported_models_base.BASE): + unet_config = { + "image_model": "hidream", + } + + sampling_settings = { + "shift": 3.0, + } + + # memory_usage_factor = 1.2 # TODO + + unet_extra_config = {} + latent_format = latent_formats.Flux + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HiDream(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] models += [SVD_img2vid] diff --git a/comfy/text_encoders/aura_t5.py index e9ad45a7f..cf4252eea 100644 --- a/comfy/text_encoders/aura_t5.py +++ b/comfy/text_encoders/aura_t5.py @@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel): class PT5XlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data) class AuraT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/cosmos.py index 5441c8952..a1adb5242 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data) class
CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index a12995ec0..0666dde7f 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -9,14 +9,13 @@ import os class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class FluxTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) self.dtypes = set([dtype, dtype_t5]) diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 45987a480..9dcf190a2 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py new file mode 100644 index 000000000..8e1abcfc1 --- /dev/null +++ b/comfy/text_encoders/hidream.py @@ -0,0 +1,155 @@ +from . 
import hunyuan_video +from . import sd3_clip +from comfy import sd1_clip +from comfy import sdxl_clip +import comfy.model_management +import torch +import logging + + +class HiDreamTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data) + self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens + out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return self.clip_g.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class HiDreamTEModel(torch.nn.Module): + def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + if clip_l: + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_l = None + + if clip_g: + self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_g = None + + if t5: + dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) + self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True) + self.dtypes.add(dtype_t5) + else: + self.t5xxl = None + + if llama: + dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) + if "vocab_size" not in model_options: + model_options["vocab_size"] = 128256 + self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009}) + self.dtypes.add(dtype_llama) + else: + self.llama = None + + logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama)) + + def set_clip_options(self, options): + if self.clip_l is not None: + self.clip_l.set_clip_options(options) + if self.clip_g is not None: + self.clip_g.set_clip_options(options) + if self.t5xxl is not None: + self.t5xxl.set_clip_options(options) + if self.llama is not None: + self.llama.set_clip_options(options) + + def reset_clip_options(self): + if self.clip_l is not None: + self.clip_l.reset_clip_options() + if self.clip_g is not None: + self.clip_g.reset_clip_options() + if self.t5xxl is not None: + self.t5xxl.reset_clip_options() + if self.llama is not None: + self.llama.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_l = token_weight_pairs["l"] + token_weight_pairs_g = token_weight_pairs["g"] + 
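        # Note: the "t5xxl" stream unpacked next was already cut down to its first 128-token chunk by HiDreamTokenizer.tokenize_with_weights above. +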
token_weight_pairs_t5 = token_weight_pairs["t5xxl"] + token_weight_pairs_llama = token_weight_pairs["llama"] + lg_out = None + pooled = None + extra = {} + + if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: + if self.clip_l is not None: + lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + else: + l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device()) + + if self.clip_g is not None: + g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) + else: + g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) + + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + + if self.t5xxl is not None: + t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) + t5_out, t5_pooled = t5_output[:2] + else: + t5_out = None + + if self.llama is not None: + ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) + ll_out, ll_pooled = ll_output[:2] + ll_out = ll_out[:, 1:] + else: + ll_out = None + + if t5_out is None: + t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device()) + + if ll_out is None: + ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) + + if pooled is None: + pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) + + extra["conditioning_llama3"] = ll_out + return t5_out, pooled, extra + + def load_sd(self, sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return self.clip_g.load_sd(sd) + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: + return self.t5xxl.load_sd(sd) + else: + return self.llama.load_sd(sd) + + +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): + class HiDreamTEModel_(HiDreamTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index dbb259e54..33ac22497 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -21,26 +21,31 @@ def llama_detect(state_dict, prefix=""): class LLAMA3Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) + super().__init__(tokenizer_path, 
embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data) class LLAMAModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) if llama_scaled_fp8 is not None: model_options = model_options.copy() model_options["scaled_fp8"] = llama_scaled_fp8 - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + textmodel_json_config = {} + vocab_size = model_options.get("vocab_size", None) + if vocab_size is not None: + textmodel_json_config["vocab_size"] = vocab_size + + model_options = {**model_options, "model_name": "llama"} + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class HunyuanVideoTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. 
camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens - self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) + self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} @@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) self.dtypes = set([dtype, dtype_llama]) diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index 7da3e9fc5..e7273f425 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -9,24 +9,26 @@ import torch class HyditBertModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") + model_options = {**model_options, "model_name": "hydit_clip"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class HyditBertTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data) class MT5XLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json") + model_options = {**model_options, "model_name": "mt5xl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class MT5XLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model") tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, 
embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -35,7 +37,7 @@ class HyditTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None) self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory) - self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) + self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 58710b2bf..34eb870e3 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -268,11 +268,17 @@ class Llama2_(nn.Module): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output for i, layer in enumerate(self.layers): + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -283,6 +289,12 @@ class Llama2_(nn.Module): intermediate = x.clone() x = self.norm(x) + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + if intermediate is not None and final_layer_norm_intermediate: intermediate = self.norm(intermediate) diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index b81912cb3..8d4c7619d 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,30 +1,27 @@ -from comfy import sd1_clip -import os -class LongClipTokenizer_(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - -class LongClipModel_(sd1_clip.SDClipModel): - def __init__(self, *args, **kwargs): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") - super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs) - -class LongClipTokenizer(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_) - -class LongClipModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): - super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) def model_options_long_clip(sd, tokenizer_data, model_options): w = 
sd.get("clip_l.text_model.embeddings.position_embedding.weight", None) + if w is None: + w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None) + else: + model_name = "clip_g" + if w is None: w = sd.get("text_model.embeddings.position_embedding.weight", None) - if w is not None and w.shape[0] == 248: + if w is not None: + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + model_name = "clip_g" + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + model_name = "clip_l" + else: + model_name = "clip_l" + + if w is not None: tokenizer_data = tokenizer_data.copy() model_options = model_options.copy() - tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_ - model_options["clip_l_class"] = LongClipModel_ + model_config = model_options.get("model_config", {}) + model_config["max_position_embeddings"] = w.shape[0] + model_options["{}_model_config".format(model_name)] = model_config + tokenizer_data["{}_max_length".format(model_name)] = w.shape[0] return tokenizer_data, model_options diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5c2ce583f..48ea67e67 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,7 +6,7 @@ import comfy.text_encoders.genmo class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128? + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128? 
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index a7b1d702b..674461b75 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -6,7 +6,7 @@ import comfy.text_encoders.llama class Gemma2BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index d56d57f1b..b8de6bc4e 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py index 7778ce47a..2803926ac 100644 --- a/comfy/text_encoders/sa_t5.py +++ b/comfy/text_encoders/sa_t5.py @@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel): class T5BaseTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) class SAT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py index 31fc89869..700a23bf0 100644 --- a/comfy/text_encoders/sd2_clip.py +++ b/comfy/text_encoders/sd2_clip.py @@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, 
tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data) class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 3ad2ed93a..6c2fbeca4 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel): model_options = model_options.copy() model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""): return out class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data) class SD3Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module): super().__init__() self.dtypes = set() if clip_l: - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) self.dtypes.add(dtype) else: self.clip_l = None diff --git 
a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index 21df4f863..caccb3ca2 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -1,4 +1,5 @@ import torch +import os class SPieceTokenizer: @staticmethod @@ -15,6 +16,8 @@ class SPieceTokenizer: if isinstance(tokenizer_path, bytes): self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) else: + if not os.path.isfile(tokenizer_path): + raise ValueError("invalid tokenizer") self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) def get_vocab(self): diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index 971ac8fa8..d50fa4b28 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel): class UMT5XXlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0) + super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 59b42b746..a2799b52e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,6 +1,9 @@ -import nodes +from __future__ import annotations +from typing import Type, Literal +import nodes from comfy_execution.graph_utils import is_link +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions class DependencyCycleError(Exception): pass @@ -54,7 +57,22 @@ class DynamicPrompt: def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name, valid_inputs=None): +def get_input_info( + class_def: Type[ComfyNodeABC], + input_name: str, + valid_inputs: InputTypeDict | None = None +) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]: + """Get the input type, category, and extra info for a given input name. + + Arguments: + class_def: The class definition of the node. + input_name: The name of the input to get info for. + valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES(). + + Returns: + tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name. 
+ """ + valid_inputs = valid_inputs or class_def.INPUT_TYPES() input_info = None input_category = None @@ -126,7 +144,7 @@ class TopologicalSort: from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py new file mode 100644 index 000000000..ee310c874 --- /dev/null +++ b/comfy_extras/nodes_fresca.py @@ -0,0 +1,100 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. + """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # Initialize mask with high-frequency scaling factor + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) + + # Apply low-frequency scaling factor to center region + m[:] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for low-frequency components"}), + "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for high-frequency components"}), + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, + "tooltip": "Number of frequency indices around center to consider as low-frequency"}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Applies frequency-dependent scaling to the guidance" + def patch(self, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + + m = model.clone() + 
m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return (m,) + + +NODE_CLASS_MAPPINGS = { + "FreSca": FreSca, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "FreSca": "FreSca", +} diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py new file mode 100644 index 000000000..5a160c2ba --- /dev/null +++ b/comfy_extras/nodes_hidream.py @@ -0,0 +1,32 @@ +import folder_paths +import comfy.sd +import comfy.model_management + + +class QuadrupleCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" + + def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + + +NODE_CLASS_MAPPINGS = { + "QuadrupleCLIPLoader": QuadrupleCLIPLoader, +} diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index db30030fb..53d892bc4 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -21,8 +21,8 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -41,7 +41,7 @@ class Load3D(): normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - return output_image, output_mask, model_file, normal_image, lineart_image + return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'] class Load3DAnimation(): @classmethod @@ -59,8 +59,8 @@ class Load3DAnimation(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -77,13 +77,16 @@ class Load3DAnimation(): ignore_image, output_mask = load_image_node.load_image(image=mask_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - return output_image, output_mask, model_file, normal_image + return output_image, output_mask, model_file, normal_image, image['camera_info'] class Preview3D(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + 
}, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -95,13 +98,22 @@ class Preview3D(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } class Preview3DAnimation(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -113,7 +125,13 @@ class Preview3DAnimation(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 2d0f31ac8..8ad358ce8 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -4,6 +4,7 @@ import torch import comfy.model_management import comfy.utils import comfy.latent_formats +import comfy.clip_vision class WanImageToVideo: @@ -99,6 +100,72 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class WanFirstLastFrameToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), + "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = 
node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + clip_vision_output = None + if clip_vision_start_image is not None: + clip_vision_output = clip_vision_start_image + + if clip_vision_end_image is not None: + if clip_vision_output is not None: + states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2) + clip_vision_output = comfy.clip_vision.Output() + clip_vision_output.penultimate_hidden_states = states + else: + clip_vision_output = clip_vision_end_image + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + + class WanFunInpaintToVideo: @classmethod def INPUT_TYPES(s): @@ -122,38 +189,13 @@ class WanFunInpaintToVideo: CATEGORY = "conditioning/video_models" def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - if start_image is not None: - start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - if end_image is not None: - end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + flfv = WanFirstLastFrameToVideo() + return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) - image = torch.ones((length, height, width, 3)) * 0.5 - mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) - - if start_image is not None: - image[:start_image.shape[0]] = start_image - mask[:, :, :start_image.shape[0] + 3] = 0.0 - - if end_image is not None: - image[-end_image.shape[0]:] = end_image - mask[:, :, -end_image.shape[0]:] = 0.0 - - concat_latent_image = vae.encode(image[:, :, :, :3]) - mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - out_latent = {} - out_latent["samples"] = latent - return (positive, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, + "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, } diff --git a/comfyui_version.py index a44538d1a..f9161b37e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml.
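Aside on the WanFirstLastFrameToVideo hunk above — a minimal standalone sketch (not ComfyUI code) of its temporal-mask bookkeeping. Shapes follow the diff: the latent has ((length - 1) // 4) + 1 frames, the mask is built at 4x that resolution, and the `+ 3` in the start-frame slice pads a single given frame out to a full group of four, i.e. one whole latent frame:

```python
import torch

length, height, width = 81, 480, 832          # node defaults from the diff
t_latent = (length - 1) // 4 + 1              # 21 latent frames
mask = torch.ones((1, 1, t_latent * 4, height // 8, width // 8))

mask[:, :, :1 + 3] = 0.0                      # one start frame given (+3 fills its 4-frame group)
mask[:, :, -1:] = 0.0                         # one end frame given

# Regroup per-pixel-frame flags to (batch, 4, latent_frames, h, w), as the node does:
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
print(mask.shape)                             # torch.Size([1, 4, 21, 60, 104])
```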
-__version__ = "0.3.28" +__version__ = "0.3.29" diff --git a/execution.py b/execution.py index 9a5e27771..d09102f55 100644 --- a/execution.py +++ b/execution.py @@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) + _, input_category, input_info = get_input_info(class_def, x, valid_inputs) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated): continue val = inputs[x] - info = (type_input, extra_info) + info = (input_type, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): - details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", @@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated): val = val["__value__"] inputs[x] = val - if type_input == "INT": + if input_type == "INT": val = int(val) inputs[x] = val - if type_input == "FLOAT": + if input_type == "FLOAT": val = float(val) inputs[x] = val - if type_input == "STRING": + if input_type == "STRING": val = str(val) inputs[x] = val - if type_input == "BOOLEAN": + if input_type == "BOOLEAN": val = bool(val) inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", + "message": f"Failed to convert an input value to a {input_type} value", "details": f"{x}, {val}, {ex}", "extra_info": { "input_name": x, @@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if isinstance(type_input, list): - if val not in type_input: + if isinstance(input_type, list): + combo_options = input_type + if val not in combo_options: input_config = info list_info = "" # Don't send back gigantic lists like if they're lots of # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" + if len(combo_options) > 20: + list_info = f"(list of length {len(combo_options)})" input_config = None else: - list_info = str(type_input) + list_info = str(combo_options) error = { "type": "value_not_in_list", diff --git a/nodes.py b/nodes.py index 68505b952..562fbd5df 100644 --- a/nodes.py +++ b/nodes.py @@ -918,7 +918,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", 
"cosmos", "lumina2", "wan"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -928,29 +928,10 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5" def load_clip(self, clip_name, type="stable_diffusion", device="default"): - if type == "stable_cascade": - clip_type = comfy.sd.CLIPType.STABLE_CASCADE - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "stable_audio": - clip_type = comfy.sd.CLIPType.STABLE_AUDIO - elif type == "mochi": - clip_type = comfy.sd.CLIPType.MOCHI - elif type == "ltxv": - clip_type = comfy.sd.CLIPType.LTXV - elif type == "pixart": - clip_type = comfy.sd.CLIPType.PIXART - elif type == "cosmos": - clip_type = comfy.sd.CLIPType.COSMOS - elif type == "lumina2": - clip_type = comfy.sd.CLIPType.LUMINA2 - elif type == "wan": - clip_type = comfy.sd.CLIPType.WAN - else: - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} if device == "cpu": @@ -965,7 +946,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -975,19 +956,13 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" def load_clip(self, clip_name1, clip_name2, type, device="default"): + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) - if type == "sdxl": - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "flux": - clip_type = comfy.sd.CLIPType.FLUX - elif type == "hunyuan_video": - clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO model_options = {} if device == "cpu": @@ -2285,7 +2260,9 @@ def init_builtin_extra_nodes(): "nodes_hunyuan3d.py", "nodes_primitive.py", "nodes_cfg.py", - "nodes_optimalsteps.py" + "nodes_optimalsteps.py", + "nodes_hidream.py", + "nodes_fresca.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 6eb1704db..e8fc9555d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.28" +version = "0.3.29" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git 
a/requirements.txt b/requirements.txt index 310ab57ba..6459d3353 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -comfyui-frontend-package==1.15.13 +comfyui-frontend-package==1.16.9 +comfyui-workflow-templates==0.1.3 comfyui_manager torch torchsde diff --git a/server.py b/server.py index 62667ce18..0cc97b248 100644 --- a/server.py +++ b/server.py @@ -736,6 +736,12 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + workflow_templates_path = FrontendManager.templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ])
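Finally, a note on the CLIPLoader/DualCLIPLoader refactor above: the if/elif ladders collapse into a single enum lookup, with unmatched names falling back to STABLE_DIFFUSION — which, judging by the removed branches, is exactly what "sdxl" relies on (assuming comfy.sd.CLIPType has no SDXL member). A self-contained sketch of the pattern with a stand-in enum:

```python
from enum import Enum

class CLIPType(Enum):  # stand-in for comfy.sd.CLIPType; the real enum has more members
    STABLE_DIFFUSION = 1
    SD3 = 2
    HIDREAM = 3

def resolve(type_name: str) -> CLIPType:
    # Mirrors the refactor: getattr with a default replaces the whole if/elif chain.
    return getattr(CLIPType, type_name.upper(), CLIPType.STABLE_DIFFUSION)

assert resolve("hidream") is CLIPType.HIDREAM
assert resolve("sdxl") is CLIPType.STABLE_DIFFUSION  # no SDXL member: falls back
```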