diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 099732a6c..5a717bf38 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -90,7 +90,8 @@ class TimestepEmbedder(nn.Module):
         max_period=10000,
         out_size=None,
         dtype=None,
-        device=None
+        device=None,
+        operations=None
     ):
         factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()
@@ -100,9 +101,9 @@ class TimestepEmbedder(nn.Module):
             out_size = hidden_size
 
         self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
+            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
             act_layer(),
-            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
+            operations.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
         )
 
     def forward(self, t):
@@ -156,7 +157,6 @@ def build_2d_rope(
         image_infos_list = [image_infos]
         sample_seq_lens = [seq_len]
 
-    # Prepare position indices for each sample
     x_sections = []
     y_sections = []
     for sample_id, sample_image_infos in enumerate(image_infos_list):
@@ -168,13 +168,10 @@ def build_2d_rope(
                 y_sections.append(torch.arange(last_pos, L))
                 x_sections.append(torch.arange(last_pos, L))
             elif h is None:
-                # Interleave data has overlapped positions for tokens.
                 y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                 x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                 continue
             else:
-                # Interleave data has overlapped positions for noised image and the successive clean image,
-                # leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                 pass
 
             # current image
             beta_y = L + (w * h - h) / 2
@@ -209,7 +206,7 @@
 
 
 def build_batch_2d_rope(
-    seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
+    seq_len: int, n_elem: int, image_infos=None,
     device: Optional[torch.device] = None, base: int = 10000,
     base_rescale_factor: float = 1.0, return_all_pos: bool = False,
 ):
@@ -261,17 +258,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 def default(val, d):
     return val if val is not None else d
 
-def conv_nd(dims, *args, **kwargs):
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-
-def normalization(channels, **kwargs):
-    return nn.GroupNorm(32, channels, **kwargs)
-
 def topkgating(logits: torch.Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
@@ -311,9 +297,9 @@ def topkgating(logits: torch.Tensor, topk: int):
 
     return combine_weights, dispatch_mask
 
 class HunyuanRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
@@ -326,7 +312,7 @@ class HunyuanRMSNorm(nn.Module):
 
 class UNetDown(nn.Module):
     def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
-                 dropout=0.0, device=None, dtype=None):
+                 dropout=0.0, device=None, dtype=None, operations=None):
         factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()
@@ -334,8 +320,7 @@ class UNetDown(nn.Module):
         assert self.patch_size in [1, 2, 4, 8]
 
         self.model = nn.ModuleList(
-            [conv_nd(
-                2,
+            [operations.Conv2d(
                 in_channels=in_channels,
                 out_channels=hidden_channels,
                 kernel_size=3,
@@ -351,7 +336,8 @@ class UNetDown(nn.Module):
                 out_channels=out_channels,
                 use_scale_shift_norm = True,
                 dropout=dropout,
-                **factory_kwargs
+                **factory_kwargs,
+                operations=operations
             ))
         else:
             for i in range(self.patch_size // 2):
@@ -362,7 +348,8 @@ class UNetDown(nn.Module):
                     out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                     dropout=dropout,
                     down=True,
-                    **factory_kwargs
+                    **factory_kwargs,
+                    operations=operations
                 ))
 
     def forward(self, x, t):
@@ -397,7 +384,8 @@ class UNetUp(nn.Module):
                 out_channels=hidden_channels,
                 use_scale_shift_norm = True,
                 dropout=dropout,
-                **factory_kwargs
+                **factory_kwargs,
+                operations=operations
             ))
         else:
             for i in range(self.patch_size // 2):
@@ -408,14 +396,15 @@ class UNetUp(nn.Module):
                     use_scale_shift_norm = True,
                     dropout=dropout,
                     up=True,
-                    **factory_kwargs
+                    **factory_kwargs,
+                    operations=operations
                 ))
 
         if out_norm:
             self.model.append(nn.Sequential(
-                normalization(hidden_channels, **factory_kwargs),
+                operations.GroupNorm(32, hidden_channels, **factory_kwargs),
                 nn.SiLU(),
-                nn.Conv2d(
+                operations.Conv2d(
                     in_channels=hidden_channels,
                     out_channels=out_channels,
                     kernel_size=3,
@@ -424,7 +413,7 @@ class UNetUp(nn.Module):
                 ),
             ))
         else:
-            self.model.append(nn.Conv2d(
+            self.model.append(operations.Conv2d(
                 in_channels=hidden_channels,
                 out_channels=out_channels,
                 kernel_size=3,
@@ -443,14 +432,14 @@ class UNetUp(nn.Module):
         return x
 
 class HunyuanTopKGate(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None):
+    def __init__(self, config, layer_idx: Optional[int] = None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = 8
         self.min_capacity = 8
         num_experts = 64
-        self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
+        self.wg = operations.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32, device=device)
 
     def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
@@ -463,7 +452,7 @@ class HunyuanTopKGate(nn.Module):
         return gate_output
 
 class HunyuanMLP(nn.Module):
-    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
+    def __init__(self, config, layer_idx=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -473,8 +462,8 @@ class HunyuanMLP(nn.Module):
         self.act_fn = torch.nn.functional.silu
         self.intermediate_size *= 2 # SwiGLU
 
-        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
-        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
+        self.gate_and_up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device, dtype=dtype)
+        self.down_proj = operations.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device, dtype=dtype)
 
     def forward(self, x):
         self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
         if x.ndim == 2:
@@ -494,7 +483,9 @@ class MoELRUCache(nn.Module):
         self.last_offload_event = None
 
         self._loop = asyncio.new_event_loop()
-        self._gpu_sem = asyncio.Semaphore(1) # maybe 2
+        self._gpu_sem = asyncio.Semaphore(2)
+        self.operations = None
+        self.dtype = None
         threading.Thread(target=self._loop.run_forever, daemon=True).start()
 
     async def _async_offload_to_cpu(self, layer_idx):
@@ -509,12 +500,12 @@ class MoELRUCache(nn.Module):
 
         with torch.cuda.stream(self.offload_stream):
             for index, moe in moe_group:
-                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                moe_cpu = HunyuanMLP(moe.config, device="cpu", dtype=self.dtype, operations=self.operations)
                 for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                     if p_gpu.device.type == "meta":
                         continue
                     with torch.no_grad():
-                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", dtype=self.dtype, pin_memory=True)
                         p_cpu.copy_(p_gpu, non_blocking=True)
 
                 self.cpu_cache[index] = moe_cpu
@@ -547,7 +538,7 @@ class MoELRUCache(nn.Module):
         # async loading from cpu -> gpu
         with torch.no_grad():
             with torch.cuda.stream(self.load_stream):
-                moe_gpu = HunyuanMLP(moe.config, device="meta")
+                moe_gpu = HunyuanMLP(moe.config, device="meta", dtype=self.dtype, operations=self.operations)
                 for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                     p_gpu.data.copy_(p_cpu, non_blocking=True)
@@ -661,17 +652,17 @@ def enough_vram(required_bytes):
     return free > required_bytes
 
 class HunyuanMoE(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
+    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = 8
         self.num_experts = 64
-        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
-        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
+        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
+        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
         if INIT_MOE:
             self.experts = nn.ModuleList(
-                [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
+                [HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations) for _ in range(self.num_experts)]
             )
         else:
             self.experts = []
@@ -757,7 +748,7 @@ class HunyuanMoE(nn.Module):
         return output
 
 class HunyuanImage3Attention(nn.Module):
-    def __init__(self, config, layer_idx: int):
+    def __init__(self, config, layer_idx: int, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -775,15 +766,15 @@ class HunyuanImage3Attention(nn.Module):
         self.hidden_size_kv = self.head_dim * self.num_key_value_heads
 
         # define layers
-        self.qkv_proj = nn.Linear(
+        self.qkv_proj = operations.Linear(
             self.hidden_size,
             self.hidden_size_q + 2 * self.hidden_size_kv,
-            bias=False
+            bias=False, device=device, dtype=dtype
         )
-        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)
+        self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
 
-        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
-        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
+        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -838,16 +829,16 @@ class HunyuanImage3Attention(nn.Module):
         return attn_output, past_key_value
 
 class HunyuanImage3DecoderLayer(nn.Module):
-    def __init__(self, config, layer_idx: int, moe_lru=None):
+    def __init__(self, config, layer_idx: int, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.hidden_size = config["hidden_size"]
         self.layer_idx = layer_idx
 
-        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)
+        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
 
-        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)
+        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)
 
-        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
-        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])
+        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)
+        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'], device=device, dtype=dtype)
 
     def forward(
         self,
@@ -887,14 +878,14 @@ class HunyuanImage3DecoderLayer(nn.Module):
 
         return outputs
 
 class HunyuanImage3Model(nn.Module):
-    def __init__(self, config, moe_lru=None):
+    def __init__(self, config, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.config = config
-        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
+        self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
         self.layers = nn.ModuleList(
-            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
+            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
         )
         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
@@ -979,9 +970,11 @@ class HunyuanImage3Model(nn.Module):
 
 class HunyuanImage3ForCausalMM(nn.Module):
-    def __init__(self, config):
+    def __init__(self, operations=None, dtype=None, device=None, **kwargs):
         super().__init__()
+        config = kwargs
         self.config = config
+        factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
 
-        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
         self.patch_embed = UNetDown(
@@ -990,8 +983,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
             in_channels=32,
             hidden_channels=1024,
             out_channels=config["hidden_size"],
+            **factory_kwargs
         )
-        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
 
         self.final_layer = UNetUp(
             patch_size=1,
@@ -1000,19 +994,22 @@ class HunyuanImage3ForCausalMM(nn.Module):
             hidden_channels=1024,
             out_channels=32,
             out_norm=True,
+            **factory_kwargs
         )
-        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
 
         self.moe_lru = None
         if not INIT_MOE:
             self.moe_lru = MoELRUCache()
+            self.moe_lru.operations = operations
+            self.moe_lru.dtype = dtype
 
-        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)
+        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru, **factory_kwargs)
 
         self.pad_id = 128009
         self.vocab_size = 133120
-        self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
+        self.lm_head = operations.Linear(config["hidden_size"], 133120, bias=False, device=device, dtype=dtype)
 
         self.first_step = True
         self.kv_cache = None
diff --git a/comfy/utils.py b/comfy/utils.py
index 4bd281057..0ddfa3c0e 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -63,6 +63,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
             sd = {}
             for k in f.keys():
+                if k.startswith("__SKIP__"):
+                    continue
                 tensor = f.get_tensor(k)
                 if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
                     tensor = tensor.to(device=device, copy=True)
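
Reviewer note (not part of the patch): the diff assumes `operations` is a namespace whose `Linear`, `Conv2d`, `GroupNorm`, and `Embedding` are drop-in replacements for the corresponding `torch.nn` classes, as ComfyUI's `comfy.ops.disable_weight_init` and `comfy.ops.manual_cast` provide. A minimal, self-contained sketch of that injection pattern; `DefaultOps` and `TinyBlock` are hypothetical names used only for illustration:

```python
import torch
import torch.nn as nn

class DefaultOps:
    # Plain torch.nn classes; stands in here for comfy.ops namespaces.
    Linear = nn.Linear
    Conv2d = nn.Conv2d
    GroupNorm = nn.GroupNorm
    Embedding = nn.Embedding

class TinyBlock(nn.Module):
    def __init__(self, hidden_size, device=None, dtype=None, operations=DefaultOps):
        super().__init__()
        # Every layer is built through `operations`, never torch.nn directly,
        # so the caller decides how (and whether) weights are initialized.
        self.proj = operations.Linear(hidden_size, hidden_size, bias=False,
                                      device=device, dtype=dtype)

    def forward(self, x):
        return self.proj(x)

block = TinyBlock(64, device="cpu", dtype=torch.float32)
print(block(torch.randn(2, 64)).shape)  # torch.Size([2, 64])
```

Routing every layer through `operations` is what lets the loader skip redundant weight initialization and control device/dtype placement per layer without touching the model code itself.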