Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-15 16:50:57 +08:00)

operations + device + dtype | checkpoint skip

This commit is contained in:
parent 88c350bfed
commit 76e14d69b2

@@ -90,7 +90,8 @@ class TimestepEmbedder(nn.Module):
         max_period=10000,
         out_size=None,
         dtype=None,
-        device=None
+        device=None,
+        operations = None
     ):
         factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()
@@ -100,9 +101,9 @@ class TimestepEmbedder(nn.Module):
             out_size = hidden_size

         self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
+            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
             act_layer(),
-            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
+            operations.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
         )

     def forward(self, t):
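
The pattern repeated throughout this diff is the same: layer constructors grow device, dtype and operations parameters, and direct torch.nn layer construction (nn.Linear, nn.Conv2d, nn.GroupNorm, nn.Embedding) is replaced by the equivalent classes on the injected operations object, so the caller decides how and where weights get materialized. As a rough illustration of what such an operations namespace can provide (a sketch only, not ComfyUI's actual comfy.ops code), a namespace that simply skips default weight initialization could look like this:

    import torch.nn as nn

    class skip_init_ops:
        """Illustrative drop-in namespace: same constructor signatures as the
        torch.nn layers, but reset_parameters() is a no-op, so no time is spent
        initializing weights that a checkpoint will overwrite anyway."""

        class Linear(nn.Linear):
            def reset_parameters(self):
                return None

        class Conv2d(nn.Conv2d):
            def reset_parameters(self):
                return None

        class GroupNorm(nn.GroupNorm):
            def reset_parameters(self):
                return None

        class Embedding(nn.Embedding):
            def reset_parameters(self):
                return None

With an object like that passed in, TimestepEmbedder builds its MLP through operations.Linear and forwards factory_kwargs to every layer it creates.
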
@@ -156,7 +157,6 @@ def build_2d_rope(
         image_infos_list = [image_infos]
         sample_seq_lens = [seq_len]

     # Prepare position indices for each sample
     x_sections = []
     y_sections = []
     for sample_id, sample_image_infos in enumerate(image_infos_list):
@@ -168,13 +168,10 @@ def build_2d_rope(
                 y_sections.append(torch.arange(last_pos, L))
                 x_sections.append(torch.arange(last_pos, L))
             elif h is None:
                 # Interleave data has overlapped positions for <boi> <size> <ratio> <timestep> <eoi> tokens.
                 y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                 x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                 continue
             else:
                 # Interleave data has overlapped positions for noised image and the successive clean image,
                 # leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                 pass
             # current image
             beta_y = L + (w * h - h) / 2
@@ -209,7 +206,7 @@ def build_2d_rope(


 def build_batch_2d_rope(
-    seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
+    seq_len: int, n_elem: int, image_infos = None,
     device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
     return_all_pos: bool = False,
 ):
@@ -261,17 +258,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 def default(val, d):
     return val if val is not None else d

-def conv_nd(dims, *args, **kwargs):
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-
-def normalization(channels, **kwargs):
-    return nn.GroupNorm(32, channels, **kwargs)
-
 def topkgating(logits: torch.Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
@@ -311,9 +297,9 @@ def topkgating(logits: torch.Tensor, topk: int):
     return combine_weights, dispatch_mask

 class HunyuanRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
         self.variance_epsilon = eps

     def forward(self, hidden_states):
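
The HunyuanRMSNorm change only threads device and dtype into its single weight parameter; the forward pass is untouched and outside this hunk. For context, a typical Llama-style RMSNorm forward (a sketch, not necessarily the exact body in this file) computes the variance in float32 and rescales back to the input dtype:

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
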
@@ -326,7 +312,7 @@ class HunyuanRMSNorm(nn.Module):

 class UNetDown(nn.Module):
     def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
-                 dropout=0.0, device=None, dtype=None):
+                 dropout=0.0, device=None, dtype=None, operations=None):
         factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()

@@ -334,8 +320,7 @@ class UNetDown(nn.Module):
         assert self.patch_size in [1, 2, 4, 8]

         self.model = nn.ModuleList(
-            [conv_nd(
-                2,
+            [operations.Conv2d(
                 in_channels=in_channels,
                 out_channels=hidden_channels,
                 kernel_size=3,
@@ -351,7 +336,8 @@ class UNetDown(nn.Module):
                 out_channels=out_channels,
                 use_scale_shift_norm = True,
                 dropout=dropout,
-                **factory_kwargs
+                **factory_kwargs,
+                operations = operations
             ))
         else:
             for i in range(self.patch_size // 2):
@@ -362,7 +348,8 @@ class UNetDown(nn.Module):
                     out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                     dropout=dropout,
                     down=True,
-                    **factory_kwargs
+                    **factory_kwargs,
+                    operations = operations
                 ))

     def forward(self, x, t):
@@ -397,7 +384,8 @@ class UNetUp(nn.Module):
                 out_channels=hidden_channels,
                 use_scale_shift_norm = True,
                 dropout=dropout,
-                **factory_kwargs
+                **factory_kwargs,
+                operations = operations
             ))
         else:
             for i in range(self.patch_size // 2):
@@ -408,14 +396,15 @@ class UNetUp(nn.Module):
                     use_scale_shift_norm = True,
                     dropout=dropout,
                     up=True,
-                    **factory_kwargs
+                    **factory_kwargs,
+                    operations = operations
                 ))

         if out_norm:
             self.model.append(nn.Sequential(
-                normalization(hidden_channels, **factory_kwargs),
+                operations.GroupNorm(32, hidden_channels, **factory_kwargs),
                 nn.SiLU(),
-                nn.Conv2d(
+                operations.Conv2d(
                     in_channels=hidden_channels,
                     out_channels=out_channels,
                     kernel_size=3,
@@ -424,7 +413,7 @@ class UNetUp(nn.Module):
                 ),
             ))
         else:
-            self.model.append(nn.Conv2d(
+            self.model.append(operations.Conv2d(
                 in_channels=hidden_channels,
                 out_channels=out_channels,
                 kernel_size=3,
@@ -443,14 +432,14 @@ class UNetUp(nn.Module):
         return x

 class HunyuanTopKGate(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None):
+    def __init__(self, config, layer_idx: Optional[int] = None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = 8
         self.min_capacity = 8
         num_experts = 64
-        self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
+        self.wg = operations.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32, device=device)

     def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
@@ -463,7 +452,7 @@ class HunyuanTopKGate(nn.Module):
         return gate_output

 class HunyuanMLP(nn.Module):
-    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
+    def __init__(self, config, layer_idx=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -473,8 +462,8 @@

         self.act_fn = torch.nn.functional.silu
         self.intermediate_size *= 2 # SwiGLU
-        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
-        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
+        self.gate_and_up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device, dtype=dtype)
+        self.down_proj = operations.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device, dtype=dtype)
     def forward(self, x):
         self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
         if x.ndim == 2:
@@ -494,7 +483,9 @@ class MoELRUCache(nn.Module):

         self.last_offload_event = None
         self._loop = asyncio.new_event_loop()
-        self._gpu_sem = asyncio.Semaphore(1) # maybe 2
+        self._gpu_sem = asyncio.Semaphore(2)
+        self.operations = None
+        self.dtype = None
         threading.Thread(target=self._loop.run_forever, daemon=True).start()

     async def _async_offload_to_cpu(self, layer_idx):
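
MoELRUCache now also carries the operations namespace and dtype so the CPU and GPU expert copies it builds match the rest of the model, and the transfer semaphore is widened to 2. The surrounding machinery (an asyncio loop on a daemon thread, a semaphore bounding concurrent GPU transfers) is a common prefetch pattern; a self-contained sketch of it, with illustrative names, looks like this:

    import asyncio
    import threading

    class AsyncPrefetcher:
        def __init__(self, max_inflight=2):
            self._loop = asyncio.new_event_loop()
            self._gpu_sem = asyncio.Semaphore(max_inflight)
            threading.Thread(target=self._loop.run_forever, daemon=True).start()

        async def _load(self, layer_idx):
            async with self._gpu_sem:  # at most max_inflight transfers at a time
                await asyncio.to_thread(self._blocking_load, layer_idx)

        def _blocking_load(self, layer_idx):
            pass  # e.g. copy pinned CPU weights onto the GPU on a side stream

        def schedule(self, layer_idx):
            # Called from the main thread; returns a concurrent.futures.Future.
            return asyncio.run_coroutine_threadsafe(self._load(layer_idx), self._loop)
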
@@ -509,12 +500,12 @@ class MoELRUCache(nn.Module):

         with torch.cuda.stream(self.offload_stream):
             for index, moe in moe_group:
-                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                moe_cpu = HunyuanMLP(moe.config, device="cpu", dtype=self.dtype, operations=self.operations)
                 for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                     if p_gpu.device.type == "meta":
                         continue
                     with torch.no_grad():
-                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", dtype = self.dtype, pin_memory=True)
                         p_cpu.copy_(p_gpu, non_blocking=True)

                 self.cpu_cache[index] = moe_cpu
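
The offload path now builds the CPU-side HunyuanMLP with device="cpu", dtype=self.dtype and the shared operations, and allocates destination buffers in that dtype. The important detail is that non_blocking device-to-host copies only overlap with compute when the CPU buffer is pinned; a condensed sketch of the same idea (the helper name is illustrative, not the repo's API) is:

    import torch

    def offload_weights(module, offload_stream, dtype=None):
        cpu_params = {}
        with torch.no_grad(), torch.cuda.stream(offload_stream):
            for name, p_gpu in module.named_parameters():
                # Pinned (page-locked) memory is required for truly async D2H copies.
                p_cpu = torch.empty_like(p_gpu, device="cpu", dtype=dtype or p_gpu.dtype, pin_memory=True)
                p_cpu.copy_(p_gpu, non_blocking=True)
                cpu_params[name] = p_cpu
        # Record an event so callers can wait until the copies have finished.
        event = torch.cuda.Event()
        event.record(offload_stream)
        return cpu_params, event
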
@@ -547,7 +538,7 @@ class MoELRUCache(nn.Module):
         # async loading from cpu -> gpu
         with torch.no_grad():
             with torch.cuda.stream(self.load_stream):
-                moe_gpu = HunyuanMLP(moe.config, device="meta")
+                moe_gpu = HunyuanMLP(moe.config, device="meta", dtype=self.dtype, operations=self.operations)
                 for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                     p_gpu.data.copy_(p_cpu, non_blocking=True)

@@ -661,17 +652,17 @@ def enough_vram(required_bytes):
     return free > required_bytes

 class HunyuanMoE(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
+    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = 8
         self.num_experts = 64
-        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
-        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
+        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
+        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
         if INIT_MOE:
             self.experts = nn.ModuleList(
-                [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
+                [HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations) for _ in range(self.num_experts)]
             )
         else:
             self.experts = []
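
HunyuanMoE wires its shared MLP, router and 64 experts through the same device/dtype/operations plumbing. Note that the router weight (HunyuanTopKGate.wg) stays in float32 even when the rest of the model runs in half precision, which is the usual choice for numerically stable expert selection. A simplified sketch of top-k routing of this kind (not the repo's topkgating implementation) is:

    import torch
    import torch.nn.functional as F

    def route_tokens(hidden_states, gate_wg, experts, topk=8):
        # hidden_states: (num_tokens, hidden_size); gate_wg: float32 Linear; experts: list of MLPs
        logits = gate_wg(hidden_states.float())                 # route in float32
        probs = F.softmax(logits, dim=-1)
        weights, indices = torch.topk(probs, topk, dim=-1)      # (num_tokens, topk)
        weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize over chosen experts

        out = torch.zeros_like(hidden_states)
        for slot in range(topk):
            for expert_id in indices[:, slot].unique().tolist():
                mask = indices[:, slot] == expert_id
                expert_out = experts[expert_id](hidden_states[mask])
                out[mask] += weights[mask, slot].unsqueeze(-1).to(expert_out.dtype) * expert_out
        return out
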
@@ -757,7 +748,7 @@ class HunyuanMoE(nn.Module):
         return output

 class HunyuanImage3Attention(nn.Module):
-    def __init__(self, config, layer_idx: int):
+    def __init__(self, config, layer_idx: int, device=None, dtype=None, operations=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -775,15 +766,15 @@ class HunyuanImage3Attention(nn.Module):
         self.hidden_size_kv = self.head_dim * self.num_key_value_heads

         # define layers
-        self.qkv_proj = nn.Linear(
+        self.qkv_proj = operations.Linear(
             self.hidden_size,
             self.hidden_size_q + 2 * self.hidden_size_kv,
-            bias=False
+            bias=False, device=device, dtype=dtype
         )
-        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)
+        self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)

-        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
-        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
+        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -838,16 +829,16 @@ class HunyuanImage3Attention(nn.Module):
         return attn_output, past_key_value

 class HunyuanImage3DecoderLayer(nn.Module):
-    def __init__(self, config, layer_idx: int, moe_lru=None):
+    def __init__(self, config, layer_idx: int, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.hidden_size = config["hidden_size"]
         self.layer_idx = layer_idx
-        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)
+        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)

-        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)
+        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)

-        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
-        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])
+        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)
+        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'], device=device, dtype=dtype)

     def forward(
         self,
@@ -887,14 +878,14 @@ class HunyuanImage3DecoderLayer(nn.Module):
         return outputs

 class HunyuanImage3Model(nn.Module):
-    def __init__(self, config, moe_lru=None):
+    def __init__(self, config, moe_lru=None, device=None, dtype=None, operations=None):
         super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.config = config
-        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
+        self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
         self.layers = nn.ModuleList(
-            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
+            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
         )

         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
@@ -979,9 +970,11 @@ class HunyuanImage3Model(nn.Module):


 class HunyuanImage3ForCausalMM(nn.Module):
-    def __init__(self, config):
+    def __init__(self, operations = None, dtype = None, device = None, **kwargs):
         super().__init__()
+        config = kwargs
         self.config = config
+        factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}

         self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
         self.patch_embed = UNetDown(
@@ -990,8 +983,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
             in_channels=32,
             hidden_channels=1024,
             out_channels=config["hidden_size"],
+            **factory_kwargs
         )
-        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)

         self.final_layer = UNetUp(
             patch_size=1,
@@ -1000,19 +994,22 @@ class HunyuanImage3ForCausalMM(nn.Module):
             hidden_channels=1024,
             out_channels=32,
             out_norm=True,
+            **factory_kwargs
         )
-        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)

         self.moe_lru = None
         if not INIT_MOE:
             self.moe_lru = MoELRUCache()
+            self.moe_lru.operations = operations
+            self.moe_lru.dtype = dtype

-        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)
+        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru, **factory_kwargs)

         self.pad_id = 128009
         self.vocab_size = 133120

-        self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
+        self.lm_head = operations.Linear(config["hidden_size"], 133120, bias=False, device=device, dtype=dtype)
         self.first_step = True

         self.kv_cache = None
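
With this constructor the model no longer takes a config object positionally: the checkpoint's config dict is splatted in as **kwargs, while operations, dtype and device are bundled into factory_kwargs and pushed down through every submodule. A hypothetical construction call (placeholder config values, reusing the skip_init_ops sketch from earlier on this page) would look like:

    import torch

    config = {                      # placeholder values, not the real HunyuanImage 3 config
        "hidden_size": 4096,
        "num_hidden_layers": 32,
        "rms_norm_eps": 1e-5,
    }

    model = HunyuanImage3ForCausalMM(
        operations=skip_init_ops,   # the illustrative namespace sketched above
        dtype=torch.bfloat16,
        device=torch.device("cpu"),
        **config,
    )
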
@@ -63,6 +63,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
             sd = {}
             for k in f.keys():
+                if k.startswith("__SKIP__"):
+                    continue
                 tensor = f.get_tensor(k)
                 if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
                     tensor = tensor.to(device=device, copy=True)
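
This last hunk is the "checkpoint skip" half of the commit: when reading a safetensors file, any key prefixed with "__SKIP__" is dropped from the returned state dict instead of being materialized. A small standalone illustration of the convention (the file name and keys below are made up for the example):

    import torch
    import safetensors
    from safetensors.torch import save_file

    save_file({
        "model.weight": torch.ones(4, 4),
        "__SKIP__model.scratch": torch.zeros(16),
    }, "example.safetensors")

    sd = {}
    with safetensors.safe_open("example.safetensors", framework="pt", device="cpu") as f:
        for k in f.keys():
            if k.startswith("__SKIP__"):
                continue
            sd[k] = f.get_tensor(k)

    print(list(sd.keys()))   # ['model.weight']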