operations + device + dtype | checkpoint skip

Yousef Rafat 2025-12-01 18:53:43 +02:00
parent 88c350bfed
commit 76e14d69b2
2 changed files with 62 additions and 63 deletions

@@ -90,7 +90,8 @@ class TimestepEmbedder(nn.Module):
max_period=10000,
out_size=None,
dtype=None,
device=None
device=None,
operations=None
):
factory_kwargs = {'dtype': dtype, 'device': device}
super().__init__()
@@ -100,9 +101,9 @@ class TimestepEmbedder(nn.Module):
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
operations.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
def forward(self, t):
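
A note on the pattern: the new `operations` argument is assumed to be a namespace exposing drop-in replacements for the torch.nn layers used above, so layer construction can be customized (e.g. to skip random weight init before a checkpoint load). A minimal sketch of such a namespace; the name and exact behavior are illustrative, not taken from this commit:

import torch.nn as nn

class disable_weight_init:
    # Hypothetical ops namespace: each inner class is a drop-in nn layer whose
    # reset_parameters is a no-op, since real weights arrive from the checkpoint.
    class Linear(nn.Linear):
        def reset_parameters(self):
            return None
    class Conv2d(nn.Conv2d):
        def reset_parameters(self):
            return None
    class GroupNorm(nn.GroupNorm):
        def reset_parameters(self):
            return None
    class Embedding(nn.Embedding):
        def reset_parameters(self):
            return None

# usage sketch: TimestepEmbedder(hidden_size=4096, operations=disable_weight_init)
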
@@ -156,7 +157,6 @@ def build_2d_rope(
image_infos_list = [image_infos]
sample_seq_lens = [seq_len]
# Prepare position indices for each sample
x_sections = []
y_sections = []
for sample_id, sample_image_infos in enumerate(image_infos_list):
@@ -168,13 +168,10 @@ def build_2d_rope(
y_sections.append(torch.arange(last_pos, L))
x_sections.append(torch.arange(last_pos, L))
elif h is None:
# Interleave data has overlapped positions for <boi> <size> <ratio> <timestep> <eoi> tokens.
y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
continue
else:
# Interleave data has overlapped positions for noised image and the successive clean image,
# leading to last_pos (= last text end L + noise w * h) > L (last text end L).
pass
# current image
beta_y = L + (w * h - h) / 2
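
To make the comments above concrete: text tokens take identical 1-D positions on both axes, while an h x w image block is laid out on a 2-D grid starting at the last text position, which is how a noised image and the successive clean image can share a position range. A hedged illustration that ignores the beta_y/beta_x centering offsets computed above (all names below are illustrative):

import torch

last_pos, L, h, w = 0, 10, 4, 4                     # text span [0, 10), then a 4x4 image
y_text = torch.arange(last_pos, L)                  # text: y == x along the sequence
x_text = torch.arange(last_pos, L)
y_image = L + torch.arange(h).repeat_interleave(w)  # rows: L,L,L,L, L+1,L+1,...
x_image = L + torch.arange(w).repeat(h)             # cols: L,L+1,L+2,L+3, L,L+1,...
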
@@ -209,7 +206,7 @@ def build_2d_rope(
def build_batch_2d_rope(
seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
seq_len: int, n_elem: int, image_infos=None,
device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
return_all_pos: bool = False,
):
@@ -261,17 +258,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def default(val, d):
return val if val is not None else d
def conv_nd(dims, *args, **kwargs):
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
def normalization(channels, **kwargs):
return nn.GroupNorm(32, channels, **kwargs)
def topkgating(logits: torch.Tensor, topk: int):
logits = logits.float()
gates = F.softmax(logits, dim=1)
@@ -311,9 +297,9 @@ def topkgating(logits: torch.Tensor, topk: int):
return combine_weights, dispatch_mask
class HunyuanRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
self.variance_epsilon = eps
def forward(self, hidden_states):
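
For reference, the forward that these new device/dtype arguments feed is assumed to be the standard RMSNorm formulation (a sketch, not copied from the file):

import torch

def rms_norm(hidden_states, weight, eps=1e-6):
    # root-mean-square normalization over the last dim, computed in float32
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
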
@@ -326,7 +312,7 @@ class HunyuanRMSNorm(nn.Module):
class UNetDown(nn.Module):
def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
dropout=0.0, device=None, dtype=None):
dropout=0.0, device=None, dtype=None, operations=None):
factory_kwargs = {'dtype': dtype, 'device': device}
super().__init__()
@@ -334,8 +320,7 @@ class UNetDown(nn.Module):
assert self.patch_size in [1, 2, 4, 8]
self.model = nn.ModuleList(
[conv_nd(
2,
[operations.Conv2d(
in_channels=in_channels,
out_channels=hidden_channels,
kernel_size=3,
@@ -351,7 +336,8 @@ class UNetDown(nn.Module):
out_channels=out_channels,
use_scale_shift_norm = True,
dropout=dropout,
**factory_kwargs
**factory_kwargs,
operations=operations
))
else:
for i in range(self.patch_size // 2):
@@ -362,7 +348,8 @@ class UNetDown(nn.Module):
out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
dropout=dropout,
down=True,
**factory_kwargs
**factory_kwargs,
operations=operations
))
def forward(self, x, t):
@@ -397,7 +384,8 @@ class UNetUp(nn.Module):
out_channels=hidden_channels,
use_scale_shift_norm = True,
dropout=dropout,
**factory_kwargs
**factory_kwargs,
operations=operations
))
else:
for i in range(self.patch_size // 2):
@@ -408,14 +396,15 @@ class UNetUp(nn.Module):
use_scale_shift_norm = True,
dropout=dropout,
up=True,
**factory_kwargs
**factory_kwargs,
operations=operations
))
if out_norm:
self.model.append(nn.Sequential(
normalization(hidden_channels, **factory_kwargs),
operations.GroupNorm(32, hidden_channels, **factory_kwargs),
nn.SiLU(),
nn.Conv2d(
operations.Conv2d(
in_channels=hidden_channels,
out_channels=out_channels,
kernel_size=3,
@@ -424,7 +413,7 @@ class UNetUp(nn.Module):
),
))
else:
self.model.append(nn.Conv2d(
self.model.append(operations.Conv2d(
in_channels=hidden_channels,
out_channels=out_channels,
kernel_size=3,
@@ -443,14 +432,14 @@ class UNetUp(nn.Module):
return x
class HunyuanTopKGate(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
def __init__(self, config, layer_idx: Optional[int] = None, device=None, dtype=None, operations=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.moe_topk = 8
self.min_capacity = 8
num_experts = 64
self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
self.wg = operations.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32, device=device)
def forward(self, hidden_states):
bsz, seq_len, hidden_size = hidden_states.shape
@@ -463,7 +452,7 @@ class HunyuanTopKGate(nn.Module):
return gate_output
class HunyuanMLP(nn.Module):
def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
def __init__(self, config, layer_idx=None, device=None, dtype=None, operations=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -473,8 +462,8 @@ class HunyuanMLP(nn.Module):
self.act_fn = torch.nn.functional.silu
self.intermediate_size *= 2 # SwiGLU
self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
self.gate_and_up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
if x.ndim == 2:
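
The doubled intermediate_size marks this MLP as SwiGLU: gate_and_up_proj computes both halves in a single matmul. A minimal sketch of the activation path, assuming the standard split (the rest of the forward is elided from this hunk):

import torch
import torch.nn.functional as F

def swiglu(x, gate_and_up_proj, down_proj):
    gate, up = gate_and_up_proj(x).chunk(2, dim=-1)  # split fused [gate | up] output
    return down_proj(F.silu(gate) * up)              # SiLU-gate, then project back down
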
@@ -494,7 +483,9 @@ class MoELRUCache(nn.Module):
self.last_offload_event = None
self._loop = asyncio.new_event_loop()
self._gpu_sem = asyncio.Semaphore(1) # maybe 2
self._gpu_sem = asyncio.Semaphore(2)
self.operations = None
self.dtype = None
threading.Thread(target=self._loop.run_forever, daemon=True).start()
async def _async_offload_to_cpu(self, layer_idx):
@@ -509,12 +500,12 @@ class MoELRUCache(nn.Module):
with torch.cuda.stream(self.offload_stream):
for index, moe in moe_group:
moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
moe_cpu = HunyuanMLP(moe.config, device="cpu", dtype=self.dtype, operations=self.operations)
for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
if p_gpu.device.type == "meta":
continue
with torch.no_grad():
p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
p_cpu.data = torch.empty_like(p_gpu, device="cpu", dtype=self.dtype, pin_memory=True)
p_cpu.copy_(p_gpu, non_blocking=True)
self.cpu_cache[index] = moe_cpu
@@ -547,7 +538,7 @@ class MoELRUCache(nn.Module):
# async loading from cpu -> gpu
with torch.no_grad():
with torch.cuda.stream(self.load_stream):
moe_gpu = HunyuanMLP(moe.config, device="meta")
moe_gpu = HunyuanMLP(moe.config, device="meta", dtype=self.dtype, operations=self.operations)
for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
p_gpu.data.copy_(p_cpu, non_blocking=True)
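
Both the offload and load paths lean on the standard recipe for overlapping host/device transfers with compute: pinned CPU buffers plus a dedicated CUDA stream, so copy_(non_blocking=True) does not serialize with the default stream. A standalone sketch of that recipe, assuming a CUDA device (names illustrative):

import torch

def offload_param(p_gpu, offload_stream):
    # pinned memory is required for the copy to be truly asynchronous
    p_cpu = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
    with torch.cuda.stream(offload_stream):
        p_cpu.copy_(p_gpu, non_blocking=True)
    return p_cpu
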
@@ -661,17 +652,17 @@ def enough_vram(required_bytes):
return free > required_bytes
class HunyuanMoE(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None, device=None, dtype=None, operations=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.moe_topk = 8
self.num_experts = 64
self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
self.gate = HunyuanTopKGate(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
if INIT_MOE:
self.experts = nn.ModuleList(
[HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
[HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations) for _ in range(self.num_experts)]
)
else:
self.experts = []
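
topkgating above returns (combine_weights, dispatch_mask) in the capacity-based formulation, so a forward pass would typically dispatch tokens to experts, weight their outputs, and add the always-on shared branch. A hedged sketch (the real forward is not shown in this hunk; dims: s=tokens, e=experts, c=capacity, m=hidden):

import torch

def moe_forward(x, experts, combine_weights, dispatch_mask, shared_mlp):
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.to(x.dtype), x)
    expert_out = torch.stack([expert(dispatched[e]) for e, expert in enumerate(experts)])
    routed = torch.einsum("sec,ecm->sm", combine_weights.to(x.dtype), expert_out)
    return routed + shared_mlp(x)  # routed expert mix plus the shared MLP
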
@@ -757,7 +748,7 @@ class HunyuanMoE(nn.Module):
return output
class HunyuanImage3Attention(nn.Module):
def __init__(self, config, layer_idx: int):
def __init__(self, config, layer_idx: int, device=None, dtype=None, operations=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -775,15 +766,15 @@ class HunyuanImage3Attention(nn.Module):
self.hidden_size_kv = self.head_dim * self.num_key_value_heads
# define layers
self.qkv_proj = nn.Linear(
self.qkv_proj = operations.Linear(
self.hidden_size,
self.hidden_size_q + 2 * self.hidden_size_kv,
bias=False
bias=False, device=device, dtype=dtype
)
self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)
self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
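
Sizing qkv_proj as hidden_size_q + 2 * hidden_size_kv is the usual fused layout for grouped-query attention, where num_key_value_heads is smaller than num_heads. A sketch of the split assumed in the unchanged forward:

def split_qkv(qkv, hidden_size_q, hidden_size_kv):
    # one fused matmul output -> separate query / key / value tensors
    q, k, v = qkv.split([hidden_size_q, hidden_size_kv, hidden_size_kv], dim=-1)
    return q, k, v
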
@@ -838,16 +829,16 @@ class HunyuanImage3Attention(nn.Module):
return attn_output, past_key_value
class HunyuanImage3DecoderLayer(nn.Module):
def __init__(self, config, layer_idx: int, moe_lru=None):
def __init__(self, config, layer_idx: int, moe_lru=None, device=None, dtype=None, operations=None):
super().__init__()
self.hidden_size = config["hidden_size"]
self.layer_idx = layer_idx
self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)
self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)
self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)
self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])
self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'], device=device, dtype=dtype)
def forward(
self,
@@ -887,14 +878,14 @@ class HunyuanImage3DecoderLayer(nn.Module):
return outputs
class HunyuanImage3Model(nn.Module):
def __init__(self, config, moe_lru=None):
def __init__(self, config, moe_lru=None, device=None, dtype=None, operations=None):
super().__init__()
self.padding_idx = 128009
self.vocab_size = 133120
self.config = config
self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
self.layers = nn.ModuleList(
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
)
self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
@@ -979,9 +970,11 @@ class HunyuanImage3Model(nn.Module):
class HunyuanImage3ForCausalMM(nn.Module):
def __init__(self, config):
def __init__(self, operations=None, dtype=None, device=None, **kwargs):
super().__init__()
config = kwargs
self.config = config
factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
self.patch_embed = UNetDown(
@@ -990,8 +983,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
in_channels=32,
hidden_channels=1024,
out_channels=config["hidden_size"],
**factory_kwargs
)
self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
self.final_layer = UNetUp(
patch_size=1,
@@ -1000,19 +994,22 @@ class HunyuanImage3ForCausalMM(nn.Module):
hidden_channels=1024,
out_channels=32,
out_norm=True,
**factory_kwargs
)
self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])
self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
self.moe_lru = None
if not INIT_MOE:
self.moe_lru = MoELRUCache()
self.moe_lru.operations = operations
self.moe_lru.dtype = dtype
self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)
self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru, **factory_kwargs)
self.pad_id = 128009
self.vocab_size = 133120
self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
self.lm_head = operations.Linear(config["hidden_size"], 133120, bias=False, device=device, dtype=dtype)
self.first_step = True
self.kv_cache = None
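
With the new signature the model config arrives via **kwargs, so construction is assumed to look roughly like this (config values are illustrative, not from the commit):

import torch

model = HunyuanImage3ForCausalMM(
    operations=disable_weight_init,  # an ops namespace like the one sketched earlier
    dtype=torch.bfloat16,
    device="cpu",
    hidden_size=4096,                # config entries, now passed as keyword args
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
)
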

@@ -63,6 +63,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
if k.startswith("__SKIP__"):
continue
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
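
The new guard drops any tensor whose key starts with "__SKIP__" instead of loading it. A small example of a checkpoint that would exercise the skip path (keys are illustrative):

import torch
from safetensors.torch import save_file

save_file({
    "model.weight": torch.zeros(4, 4),         # loaded normally
    "__SKIP__scratch.buffer": torch.zeros(1),  # ignored by the patched loader
}, "ckpt.safetensors")
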