mirror of https://github.com/comfyanonymous/ComfyUI.git

operations + device + dtype | checkpoint skip
commit 76e14d69b2 · parent 88c350bfed
@@ -90,7 +90,8 @@ class TimestepEmbedder(nn.Module):
         max_period=10000,
         out_size=None,
         dtype=None,
-        device=None
+        device=None,
+        operations = None
     ):
         factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()
@@ -100,9 +101,9 @@ class TimestepEmbedder(nn.Module):
             out_size = hidden_size

         self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
+            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
             act_layer(),
-            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
+            operations.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
         )

     def forward(self, t):
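Note: the pattern introduced here (and repeated throughout the diff) replaces torch.nn constructors with an injected `operations` namespace so the caller decides how weights are created and cast. A minimal stand-in that satisfies the interface this file uses is sketched below; it illustrates the contract only, since ComfyUI's real comfy.ops implementations add manual casting and weight-management behaviour on top.

# Minimal stand-in for the injected "operations" namespace (hypothetical;
# ComfyUI's comfy.ops classes layer casting/offload logic on top of this).
import torch.nn as nn

class plain_ops:
    Linear = nn.Linear
    Conv2d = nn.Conv2d
    GroupNorm = nn.GroupNorm
    Embedding = nn.Embedding

# e.g. TimestepEmbedder(hidden_size=1024, operations=plain_ops) would build its
# MLP from plain_ops.Linear, i.e. ordinary nn.Linear layers.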
@@ -156,7 +157,6 @@ def build_2d_rope(
         image_infos_list = [image_infos]
         sample_seq_lens = [seq_len]

-    # Prepare position indices for each sample
     x_sections = []
     y_sections = []
     for sample_id, sample_image_infos in enumerate(image_infos_list):
@@ -168,13 +168,10 @@ def build_2d_rope(
                 y_sections.append(torch.arange(last_pos, L))
                 x_sections.append(torch.arange(last_pos, L))
             elif h is None:
-                # Interleave data has overlapped positions for <boi> <size> <ratio> <timestep> <eoi> tokens.
                 y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                 x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                 continue
             else:
-                # Interleave data has overlapped positions for noised image and the successive clean image,
-                # leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                 pass
             # current image
             beta_y = L + (w * h - h) / 2
@ -209,7 +206,7 @@ def build_2d_rope(
|
|||||||
|
|
||||||
|
|
||||||
def build_batch_2d_rope(
|
def build_batch_2d_rope(
|
||||||
seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
|
seq_len: int, n_elem: int, image_infos = None,
|
||||||
device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
|
device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
|
||||||
return_all_pos: bool = False,
|
return_all_pos: bool = False,
|
||||||
):
|
):
|
||||||
@@ -261,17 +258,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 def default(val, d):
     return val if val is not None else d

-def conv_nd(dims, *args, **kwargs):
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-
-def normalization(channels, **kwargs):
-    return nn.GroupNorm(32, channels, **kwargs)
-
 def topkgating(logits: torch.Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
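Note: the deleted conv_nd and normalization helpers were thin wrappers, and every call site in this file used the 2-D case and GroupNorm(32, ...), so operations.Conv2d / operations.GroupNorm(32, ...) cover them as long as the operations namespace forwards to the plain torch layers. Restating the removed helper for reference:

# Reference restatement of the removed helper (dims=2 is the only case used here).
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):  # removed above, shown for comparison
    return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims](*args, **kwargs)

assert isinstance(conv_nd(2, 32, 64, 3), nn.Conv2d)  # the case operations.Conv2d now replaces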
@@ -311,9 +297,9 @@ def topkgating(logits: torch.Tensor, topk: int):
     return combine_weights, dispatch_mask

 class HunyuanRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
         self.variance_epsilon = eps

     def forward(self, hidden_states):
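Note: HunyuanRMSNorm now takes device/dtype so its weight is allocated directly where the model is being materialized instead of on the default device and then moved. The forward pass is outside this hunk; for reference, the standard RMSNorm computation this weight feeds is:

# Reference RMSNorm forward (standard formulation, not copied from this file).
import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)           # compute the statistic in fp32
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)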
@ -326,7 +312,7 @@ class HunyuanRMSNorm(nn.Module):
|
|||||||
|
|
||||||
class UNetDown(nn.Module):
|
class UNetDown(nn.Module):
|
||||||
def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
|
def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
|
||||||
dropout=0.0, device=None, dtype=None):
|
dropout=0.0, device=None, dtype=None, operations=None):
|
||||||
factory_kwargs = {'dtype': dtype, 'device': device}
|
factory_kwargs = {'dtype': dtype, 'device': device}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -334,8 +320,7 @@ class UNetDown(nn.Module):
         assert self.patch_size in [1, 2, 4, 8]

         self.model = nn.ModuleList(
-            [conv_nd(
-                2,
+            [operations.Conv2d(
                 in_channels=in_channels,
                 out_channels=hidden_channels,
                 kernel_size=3,
@@ -351,7 +336,8 @@ class UNetDown(nn.Module):
                 out_channels=out_channels,
                 use_scale_shift_norm = True,
                 dropout=dropout,
-                **factory_kwargs
+                **factory_kwargs,
+                operations = operations
             ))
         else:
             for i in range(self.patch_size // 2):
@@ -362,7 +348,8 @@ class UNetDown(nn.Module):
                     out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                     dropout=dropout,
                     down=True,
-                    **factory_kwargs
+                    **factory_kwargs,
+                    operations = operations
                 ))

     def forward(self, x, t):
@@ -397,7 +384,8 @@ class UNetUp(nn.Module):
                 out_channels=hidden_channels,
                 use_scale_shift_norm = True,
                 dropout=dropout,
-                **factory_kwargs
+                **factory_kwargs,
+                operations = operations
             ))
         else:
             for i in range(self.patch_size // 2):
@@ -408,14 +396,15 @@ class UNetUp(nn.Module):
                     use_scale_shift_norm = True,
                     dropout=dropout,
                     up=True,
-                    **factory_kwargs
+                    **factory_kwargs,
+                    operations = operations
                 ))

         if out_norm:
             self.model.append(nn.Sequential(
-                normalization(hidden_channels, **factory_kwargs),
+                operations.GroupNorm(32, hidden_channels, **factory_kwargs),
                 nn.SiLU(),
-                nn.Conv2d(
+                operations.Conv2d(
                     in_channels=hidden_channels,
                     out_channels=out_channels,
                     kernel_size=3,
@@ -424,7 +413,7 @@ class UNetUp(nn.Module):
                 ),
             ))
         else:
-            self.model.append(nn.Conv2d(
+            self.model.append(operations.Conv2d(
                 in_channels=hidden_channels,
                 out_channels=out_channels,
                 kernel_size=3,
@@ -443,14 +432,14 @@ class UNetUp(nn.Module):
         return x

 class HunyuanTopKGate(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None):
+    def __init__(self, config, layer_idx: Optional[int] = None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = 8
         self.min_capacity = 8
         num_experts = 64
-        self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
+        self.wg = operations.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32, device=device)

     def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
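Note: the router weight wg is still created in float32 (only device is added), the usual choice for MoE gating where the softmax over expert logits is numerically sensitive. A generic top-k routing sketch under that assumption; the file's own topkgating() additionally builds capacity-aware combine/dispatch tensors and is not reproduced here.

# Generic top-k routing sketch (illustrative only; not the file's topkgating()).
import torch
import torch.nn.functional as F

def simple_topk_route(hidden_states: torch.Tensor, wg: torch.nn.Linear, topk: int = 8):
    logits = wg(hidden_states.float())                       # route in float32
    probs = F.softmax(logits, dim=-1)
    weights, expert_idx = torch.topk(probs, topk, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)    # renormalize over chosen experts
    return weights, expert_idx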
@@ -463,7 +452,7 @@ class HunyuanTopKGate(nn.Module):
         return gate_output

 class HunyuanMLP(nn.Module):
-    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
+    def __init__(self, config, layer_idx=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -473,8 +462,8 @@ class HunyuanMLP(nn.Module):

         self.act_fn = torch.nn.functional.silu
         self.intermediate_size *= 2 # SwiGLU
-        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
-        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
+        self.gate_and_up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device, dtype=dtype)
+        self.down_proj = operations.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device, dtype=dtype)
     def forward(self, x):
         self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
         if x.ndim == 2:
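Note: doubling intermediate_size and fusing gate and up into one gate_and_up_proj is the standard SwiGLU layout: the fused projection is split in half, SiLU of the first half gates the second, and down_proj maps back to hidden_size. A sketch of that forward under the shapes set up here (the class's real forward, with its 2-D input handling and device moves, is only partially visible in this diff):

# SwiGLU forward implied by the shapes above (sketch, not the class's method).
import torch
import torch.nn.functional as F

def swiglu_mlp(x: torch.Tensor, gate_and_up_proj, down_proj) -> torch.Tensor:
    gate, up = gate_and_up_proj(x).chunk(2, dim=-1)  # split the fused projection
    return down_proj(F.silu(gate) * up)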
@@ -494,7 +483,9 @@ class MoELRUCache(nn.Module):

         self.last_offload_event = None
         self._loop = asyncio.new_event_loop()
-        self._gpu_sem = asyncio.Semaphore(1) # maybe 2
+        self._gpu_sem = asyncio.Semaphore(2)
+        self.operations = None
+        self.dtype = None
         threading.Thread(target=self._loop.run_forever, daemon=True).start()

     async def _async_offload_to_cpu(self, layer_idx):
@@ -509,12 +500,12 @@ class MoELRUCache(nn.Module):

         with torch.cuda.stream(self.offload_stream):
             for index, moe in moe_group:
-                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                moe_cpu = HunyuanMLP(moe.config, device="cpu", dtype=self.dtype, operations=self.operations)
                 for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                     if p_gpu.device.type == "meta":
                         continue
                     with torch.no_grad():
-                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", dtype=self.dtype, pin_memory=True)
                         p_cpu.copy_(p_gpu, non_blocking=True)

                 self.cpu_cache[index] = moe_cpu
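Note: constructing the CPU copy with device="cpu", dtype=self.dtype instead of .to("cpu") lets the offload path allocate host parameters once, repack them into pinned buffers, and fill them with non-blocking copies on a dedicated stream. The same pattern in isolation, for a generic module (sketch; requires a CUDA device):

# Generic async GPU -> CPU offload pattern: pinned host buffers plus
# non_blocking copies on a side stream (illustrative, plain nn.Module).
import torch
import torch.nn as nn

def offload_to_pinned_cpu(gpu_module: nn.Module, cpu_module: nn.Module, stream: torch.cuda.Stream):
    with torch.no_grad(), torch.cuda.stream(stream):
        for p_gpu, p_cpu in zip(gpu_module.parameters(), cpu_module.parameters()):
            p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
            p_cpu.copy_(p_gpu, non_blocking=True)
    event = torch.cuda.Event()
    event.record(stream)   # callers wait on this before reading the CPU copy
    return event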
@@ -547,7 +538,7 @@ class MoELRUCache(nn.Module):
         # async loading from cpu -> gpu
         with torch.no_grad():
             with torch.cuda.stream(self.load_stream):
-                moe_gpu = HunyuanMLP(moe.config, device="meta")
+                moe_gpu = HunyuanMLP(moe.config, device="meta", dtype=self.dtype, operations=self.operations)
                 for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                     p_gpu.data.copy_(p_cpu, non_blocking=True)

@@ -661,17 +652,17 @@ def enough_vram(required_bytes):
     return free > required_bytes

 class HunyuanMoE(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
+    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = 8
         self.num_experts = 64
-        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
-        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
+        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
+        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
         if INIT_MOE:
             self.experts = nn.ModuleList(
-                [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
+                [HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations) for _ in range(self.num_experts)]
             )
         else:
             self.experts = []
@@ -757,7 +748,7 @@ class HunyuanMoE(nn.Module):
         return output

 class HunyuanImage3Attention(nn.Module):
-    def __init__(self, config, layer_idx: int):
+    def __init__(self, config, layer_idx: int, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -775,15 +766,15 @@ class HunyuanImage3Attention(nn.Module):
         self.hidden_size_kv = self.head_dim * self.num_key_value_heads

         # define layers
-        self.qkv_proj = nn.Linear(
+        self.qkv_proj = operations.Linear(
             self.hidden_size,
             self.hidden_size_q + 2 * self.hidden_size_kv,
-            bias=False
+            bias=False, device=device, dtype=dtype
         )
-        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)
+        self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)

-        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
-        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
+        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
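Note: qkv_proj is a single fused projection of width hidden_size_q + 2 * hidden_size_kv, so the attention forward (not part of this hunk) has to split its output back into query, key and value blocks before the per-head RMSNorms are applied. The split step under those widths:

# Splitting the fused QKV projection (sketch; widths follow the attributes
# defined above: hidden_size_q = head_dim * num_heads,
# hidden_size_kv = head_dim * num_key_value_heads).
import torch

def split_qkv(qkv: torch.Tensor, hidden_size_q: int, hidden_size_kv: int):
    q, k, v = torch.split(qkv, [hidden_size_q, hidden_size_kv, hidden_size_kv], dim=-1)
    return q, k, v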
@@ -838,16 +829,16 @@ class HunyuanImage3Attention(nn.Module):
         return attn_output, past_key_value

 class HunyuanImage3DecoderLayer(nn.Module):
-    def __init__(self, config, layer_idx: int, moe_lru=None):
+    def __init__(self, config, layer_idx: int, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.hidden_size = config["hidden_size"]
         self.layer_idx = layer_idx
-        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)
+        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)

-        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)
+        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)

-        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
-        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])
+        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)
+        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'], device=device, dtype=dtype)

     def forward(
         self,
@@ -887,14 +878,14 @@ class HunyuanImage3DecoderLayer(nn.Module):
         return outputs

 class HunyuanImage3Model(nn.Module):
-    def __init__(self, config, moe_lru=None):
+    def __init__(self, config, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.config = config
-        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
+        self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
         self.layers = nn.ModuleList(
-            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
+            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
         )

         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
@@ -979,9 +970,11 @@ class HunyuanImage3Model(nn.Module):


 class HunyuanImage3ForCausalMM(nn.Module):
-    def __init__(self, config):
+    def __init__(self, operations = None, dtype = None, device = None, **kwargs):
         super().__init__()
+        config = kwargs
         self.config = config
+        factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}

         self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
         self.patch_embed = UNetDown(
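Note: the top-level constructor now takes operations/dtype/device explicitly and treats the remaining keyword arguments as the model config (config = kwargs), which matches passing a config dict plus these three factory arguments at construction time. A minimal self-contained analogue of that pattern (the names below are hypothetical, not ComfyUI API):

# Minimal analogue of the "config = kwargs" constructor pattern used above.
import torch
import torch.nn as nn

class TinyCausalMM(nn.Module):
    def __init__(self, operations=None, dtype=None, device=None, **kwargs):
        super().__init__()
        config = kwargs                      # everything not split off above is the model config
        self.config = config
        factory_kwargs = {"device": device, "dtype": dtype}
        self.lm_head = (operations or nn).Linear(config["hidden_size"], config["vocab_size"],
                                                 bias=False, **factory_kwargs)

model = TinyCausalMM(operations=None, dtype=torch.float32, device="cpu",
                     hidden_size=64, vocab_size=1000)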
@@ -990,8 +983,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
             in_channels=32,
             hidden_channels=1024,
             out_channels=config["hidden_size"],
+            **factory_kwargs
         )
-        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)

         self.final_layer = UNetUp(
             patch_size=1,
@@ -1000,19 +994,22 @@ class HunyuanImage3ForCausalMM(nn.Module):
             hidden_channels=1024,
             out_channels=32,
             out_norm=True,
+            **factory_kwargs
         )
-        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)

         self.moe_lru = None
         if not INIT_MOE:
             self.moe_lru = MoELRUCache()
+            self.moe_lru.operations = operations
+            self.moe_lru.dtype = dtype

-        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)
+        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru, **factory_kwargs)

         self.pad_id = 128009
         self.vocab_size = 133120

-        self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
+        self.lm_head = operations.Linear(config["hidden_size"], 133120, bias=False, device=device, dtype=dtype)
         self.first_step = True

         self.kv_cache = None
@@ -63,6 +63,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
             sd = {}
             for k in f.keys():
+                if k.startswith("__SKIP__"):
+                    continue
                 tensor = f.get_tensor(k)
                 if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
                     tensor = tensor.to(device=device, copy=True)
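Note: with this change, load_torch_file silently drops any safetensors key that starts with "__SKIP__" and loads everything else as before; which component writes such placeholder keys is not shown in this diff. A small round-trip sketch of the resulting behaviour (the file path and tensor names are made up, and it assumes a ComfyUI environment for comfy.utils):

# Behaviour sketch for the "__SKIP__" filter (hypothetical file and keys).
import torch
from safetensors.torch import save_file
import comfy.utils

save_file({
    "model.weight": torch.zeros(4, 4),
    "__SKIP__model.some_placeholder": torch.zeros(1),
}, "/tmp/skip_demo.safetensors")

sd = comfy.utils.load_torch_file("/tmp/skip_demo.safetensors")
assert "model.weight" in sd and "__SKIP__model.some_placeholder" not in sd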