Mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 20:30:25 +08:00
Added async loading and offloading of MoE layers, keeping memory usage consistent and avoiding OOM errors. Previously this gave an OOM error after the third layer on a 24 GB GPU; now it runs to the end with consistent memory usage and minimal latency.
1102 lines
41 KiB
Python
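# Single-GPU streaming of the HunyuanImage-3 MoE layers: when only one CUDA
# device is visible (INIT_MOE is False), expert weights are read lazily from
# the safetensors checkpoint, staged in pinned CPU memory, and copied to the
# GPU on dedicated CUDA streams, then offloaded again once their layer has run.
# Per the commit description, this keeps VRAM usage consistent instead of
# hitting OOM after the third layer on a 24 GB GPU. MOE_LAYER_SIZE below is an
# approximate per-layer footprint, and LAYERS_IN_CPU estimates how many layers
# fit in roughly half of the currently free system RAM minus a 2 GiB margin.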
import os
import math
import time
import torch
import psutil
import asyncio
import threading
import torch.nn as nn
from pathlib import Path
import concurrent.futures
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict
from safetensors import safe_open
from transformers.cache_utils import StaticCache
from typing import Optional, Tuple, Any, List, Dict
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock


INIT_MOE = torch.cuda.device_count() != 1

if not INIT_MOE:
    MOE_LAYER_SIZE = (1024**3) * 2.65  # approx

    torch.cuda.set_device(0)
    props = torch.cuda.get_device_properties(0)

    LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
                                     - psutil.Process(os.getpid()).memory_info().rss
                                     - (2 * 1024**3)) * 0.50) / MOE_LAYER_SIZE)


class HunyuanStaticCache(StaticCache):

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        cache_position = cache_kwargs.get("cache_position")
        if hasattr(self, "key_cache") and hasattr(self, "value_cache"):
            if self.key_cache[layer_idx].device != key_states.device:
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            k_out = self.key_cache[layer_idx]
            v_out = self.value_cache[layer_idx]
            key_states = key_states.to(k_out.dtype)
            value_states = value_states.to(v_out.dtype)
        else:
            if self.layers[layer_idx].keys is None:
                self.layers[layer_idx].lazy_initialization(key_states)
            k_out = self.layers[layer_idx].keys
            v_out = self.layers[layer_idx].values

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            if cache_position.dim() == 1:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            else:
                assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}"
                batch_size, _ = cache_position.shape
                for i in range(batch_size):
                    unbatched_dim = 1
                    k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i])
                    v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i])

        return k_out, v_out


def real_batched_index_select(t, dim, idx):
    return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))])


def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding


class TimestepEmbedder(nn.Module):
    def __init__(self,
                 hidden_size,
                 act_layer=nn.GELU,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None
                 ):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    return x


def get_meshgrid_nd(start, *args, dim=2):
    start = _to_tuple(start, dim=dim)
    stop = _to_tuple(args[0], dim=dim)
    num = [int(stop[i] - start[i]) for i in range(dim)]

    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")
    grid = torch.stack(grid, dim=0)

    return grid


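# build_2d_rope constructs 2D rotary position embeddings for interleaved
# text/image sequences: text tokens get identical y/x positions equal to their
# running 1D index, while each h x w image gets a 2D (y, x) grid whose origin
# is shifted by (beta_y, beta_x) so the grid stays centered around the image's
# 1D position in the sequence.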
def build_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # theta
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    theta = theta.reshape(1, n_elem // 4, 2)  # [1, half_d, 2]

    # position indices
    if image_infos is None:
        image_infos = []

    image_infos_list = [image_infos]
    sample_seq_lens = [seq_len]

    # Prepare position indices for each sample
    x_sections = []
    y_sections = []
    for sample_id, sample_image_infos in enumerate(image_infos_list):
        last_pos = 0
        for sec_slice, (h, w) in sample_image_infos:
            L = sec_slice.start  # start from 0, so image_slice.start is just L
            # previous text
            if last_pos < L:
                y_sections.append(torch.arange(last_pos, L))
                x_sections.append(torch.arange(last_pos, L))
            elif h is None:
                # Interleave data has overlapped positions for <boi> <size> <ratio> <timestep> <eoi> tokens.
                y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                continue
            else:
                # Interleave data has overlapped positions for noised image and the successive clean image,
                # leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                pass
            # current image
            beta_y = L + (w * h - h) / 2
            beta_x = L + (w * h - w) / 2
            grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w))  # [2, h, w]
            grid = grid.reshape(2, -1)  # (y, x)
            y_sections.append(grid[0])
            x_sections.append(grid[1])
            # step
            last_pos = L + w * h
        # final text
        y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))
        x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))

    x_pos = torch.cat(x_sections).long()
    y_pos = torch.cat(y_sections).long()
    # If there are overlap positions, we need to remove them.
    x_pos = x_pos[:seq_len]
    y_pos = y_pos[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)  # [seq_len, 1, 2]

    # calc rope
    idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)

    cos = torch.cos(idx_theta)
    sin = torch.sin(idx_theta)

    if return_all_pos:
        return cos, sin, all_pos

    return cos, sin


def build_batch_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    cos_list, sin_list, all_pos_list = [], [], []
    if image_infos is None:
        image_infos = [None]
    for i, image_info in enumerate(image_infos):
        res = build_2d_rope(
            seq_len, n_elem, image_infos=image_info, device=device,
            base=base, base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        if return_all_pos:
            cos, sin, all_pos = res
        else:
            cos, sin = res
            all_pos = None
        cos_list.append(cos)
        sin_list.append(sin)
        all_pos_list.append(all_pos)

    stacked_cos = torch.stack(cos_list, dim=0)
    stacked_sin = torch.stack(sin_list, dim=0)

    if return_all_pos:
        return stacked_cos, stacked_sin, all_pos_list

    return stacked_cos, stacked_sin


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    if position_ids is not None:
        cos = cos[position_ids]
        sin = sin[position_ids]

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def default(val, d):
    return val if val is not None else d


def conv_nd(dims, *args, **kwargs):
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def normalization(channels, **kwargs):
    return nn.GroupNorm(32, channels, **kwargs)


def topkgating(
    logits: torch.Tensor,
    topk: int,
    norm_topk_prob: bool = True,
):
    logits = logits.float()
    gates = F.softmax(logits, dim=1)

    values_all, indices_all = torch.topk(gates, topk, dim=1)
    expert_weight = values_all[:, :topk]
    expert_index = indices_all[:, :topk]

    if norm_topk_prob and topk > 1:
        denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
        expert_weight = expert_weight / denom

    return expert_weight, expert_index


class HunyuanRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class UNetDown(nn.Module):
    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList(
            [conv_nd(
                2,
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            )]
        )

        if self.patch_size == 1:
            self.model.append(ResBlock(
                channels=hidden_channels,
                emb_channels=emb_channels,
                out_channels=out_channels,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                    dropout=dropout,
                    down=True,
                    **factory_kwargs
                ))

    def forward(self, x, t):
        assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        _, _, token_h, token_w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x, token_h, token_w


class UNetUp(nn.Module):

    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations=None, out_norm=False):
        operations = operations or nn
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList()

        if self.patch_size == 1:
            self.model.append(ResBlock(
                channels=in_channels,
                emb_channels=emb_channels,
                out_channels=hidden_channels,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=in_channels if i == 0 else hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels,
                    dropout=dropout,
                    up=True,
                    **factory_kwargs
                ))

        if out_norm:
            self.model.append(nn.Sequential(
                normalization(hidden_channels, **factory_kwargs),
                nn.SiLU(),
                nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    **factory_kwargs
                ),
            ))
        else:
            self.model.append(nn.Conv2d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            ))

    # batch_size, seq_len, model_dim
    def forward(self, x, t, token_h, token_w):
        x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w)
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        return x


class HunyuanTopKGate(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.min_capacity = 8
        num_experts = 64
        self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)

        self.norm_topk_prob = True

    def forward(self, hidden_states):
        bsz, seq_len, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_size)
        if self.wg.weight.dtype == torch.float32:
            hidden_states = hidden_states.float()
        logits = self.wg(hidden_states)
        gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob)

        return gate_output


class HunyuanMLP(nn.Module):
    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config["hidden_size"]

        self.intermediate_size = 3072

        self.act_fn = torch.nn.functional.silu
        self.intermediate_size *= 2  # SwiGLU
        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)

    def forward(self, x):
        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        gate_and_up_proj = self.gate_and_up_proj(x)
        x1, x2 = gate_and_up_proj.chunk(2, dim=2)
        down_proj = self.down_proj(x1 * self.act_fn(x2))
        return down_proj


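# MoELRUCache keeps two OrderedDicts of expert MLPs keyed by
# layer_idx * num_experts + expert_idx: one for modules resident on the GPU
# and one for modules staged in pinned CPU memory. Copies run on dedicated
# CUDA streams (offload_stream / load_stream) and are driven by a background
# asyncio event loop so that host/device transfers overlap with compute on
# the default stream.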
class MoELRUCache(nn.Module):

    def __init__(self):
        super().__init__()
        self.gpu_cache = OrderedDict()
        self.cpu_cache = OrderedDict()
        self.offload_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.last_offload_event = None
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    async def _async_offload_to_cpu(self, layer_idx):
        # Asynchronously copy this layer's cached experts from GPU to pinned CPU memory.
        num_experts = 64
        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
                     for i in range(num_experts)
                     if (layer_idx * num_experts + i) in self.gpu_cache]
        event = torch.cuda.Event()

        with torch.cuda.stream(self.offload_stream):
            for index, moe in moe_group:
                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                    if p_gpu.device.type == "meta":
                        continue
                    with torch.no_grad():
                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
                        p_cpu.copy_(p_gpu, non_blocking=True)

                self.cpu_cache[index] = moe_cpu

            self.offload_stream.record_event(event)

        self.last_offload_event = event

        def finalize_offload_layer():
            # Once the device-to-host copies have finished, drop the GPU copies and free the memory.
            event.synchronize()
            for index, moe in moe_group:
                moe.to("meta")
                self.gpu_cache.pop(index, None)
                del moe
            torch.cuda.empty_cache()

        threading.Thread(target=finalize_offload_layer, daemon=True).start()

    async def _async_load_to_gpu(self, index, moe):
        # If there is enough free VRAM, load immediately; otherwise wait for the pending offload.
        while True:
            free_bytes, _ = torch.cuda.mem_get_info()
            if free_bytes > 2 * MOE_LAYER_SIZE:
                break

            if self.last_offload_event is not None:
                self.last_offload_event.synchronize()
            torch.cuda.empty_cache()
            await asyncio.sleep(0.01)

        # async loading from cpu -> gpu
        with torch.cuda.stream(self.load_stream):
            moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
            for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                with torch.no_grad():
                    p_gpu.data = torch.empty_like(p_cpu, device="cuda")
                    p_gpu.copy_(p_cpu, non_blocking=True)

        def finalize_load():
            self.gpu_cache[index] = moe_gpu
            self.cpu_cache.pop(index, None)

        threading.Thread(target=finalize_load, daemon=True).start()

    def add_cpu(self, moe, index):
        moe_cpu = moe.to("cpu")

        for _, p in moe_cpu.named_parameters():
            if not p.is_pinned():
                p.data = p.data.pin_memory()

        self.cpu_cache[index] = moe_cpu
        self.cpu_cache.move_to_end(index)


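# LazyMoELoader materializes a single expert on demand: its weights are read
# from the single-file safetensors checkpoint (~/ComfyUI/models/checkpoint/
# hunyuan_image_3.safetensors) in a background executor, handed to
# MoELRUCache.add_cpu for pinned-memory staging, and scheduled for an
# asynchronous copy to the GPU ahead of the layer that needs it.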
class LazyMoELoader(nn.Module):
    def __init__(self, cache, config):
        super().__init__()
        self.cache = cache
        self.config = config
        self._loop = cache._loop

    def get_checkpoint(self):
        comfyui_dir = Path.home() / "ComfyUI"
        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
        checkpoint = checkpoint.resolve()
        if not os.path.exists(checkpoint):
            raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
        return checkpoint

    def lazy_init(self, layer_idx, expert_idx):
        checkpoint = self.get_checkpoint()
        prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
        additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
        sd = {}

        with safe_open(checkpoint, framework="pt", device="cpu") as f:
            for k in f.keys():
                if k.startswith(prefix) or k.startswith(additional_prefix):
                    new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                    sd[new_k] = f.get_tensor(k)

        moe = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True)
        moe.load_state_dict(sd)
        return moe

    async def lazy_load_from_disk(self, layer_idx, expert_idx):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)

    def _schedule_disk_load(self, layer_idx, expert_idx):
        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)

        def _on_disk_loaded(fut):
            moe_cpu = fut.result()

            def _add_cpu_in_main_thread():
                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)

            asyncio.run_coroutine_threadsafe(
                self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
                self.cache._loop
            )
            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()

        future.add_done_callback(_on_disk_loaded)
        return future


def enough_vram(required_bytes):
    free, total = torch.cuda.mem_get_info()
    return free > required_bytes


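# HunyuanMoE combines a shared MLP with 64 routed experts selected by a top-8
# gate. With multiple GPUs (INIT_MOE) the experts are ordinary submodules; on
# a single GPU they are populated lazily with futures from LazyMoELoader and
# streamed through MoELRUCache. The routed computation is batched: tokens are
# grouped per expert, padded to a common length, and all experts are evaluated
# with two batched matmuls over stacked expert weights.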
class HunyuanMoE(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.num_experts = 64
        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
        if INIT_MOE:
            self.experts = nn.ModuleList(
                [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
            )
        else:
            self.experts = []
            self.moe_lru = moe_lru

    def forward(self, hidden_states):
        if not INIT_MOE:
            torch.cuda.set_device(0)
        else:
            torch.cuda.set_device(hidden_states.device.index)
        bsz, seq_len, hidden_size = hidden_states.shape

        hidden_states_mlp = self.shared_mlp(hidden_states)

        reshaped_input = hidden_states.reshape(-1, hidden_size)

        with torch.cuda.nvtx.range("MoE"):
            expert_weight, expert_index = self.gate(hidden_states)

        combined_output = torch.zeros_like(reshaped_input)
        experts_list = [(i, expert) for i, expert in enumerate(self.experts)]

        # Group the tokens routed to each expert together with their gate weights.
        per_pos, per_tokens, per_weights = [], [], []
        for e, _ in experts_list:
            token_mask = (expert_index == e)

            token_ids = token_mask.nonzero(as_tuple=False)
            token_positions = token_ids[:, 0]
            topk_slot = token_ids[:, 1]

            tokens = reshaped_input[token_positions]
            weights = expert_weight[token_positions, topk_slot]

            per_pos.append(token_positions)
            per_tokens.append(tokens)
            per_weights.append(weights)

        # Pad every expert's token group to the same length so they can be batched.
        lengths = [t.shape[0] for t in per_tokens]
        E = len(per_tokens)
        L = max(lengths)
        tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
        weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
        for i, t in enumerate(per_tokens):
            tokens_padded[i, : t.shape[0]] = t
            weights_padded[i, : t.shape[0]] = per_weights[i]

        # Collect the expert weights, blocking on any experts that are still being loaded.
        l1, l2 = [], []
        for _, expert in experts_list:
            if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
                expert = expert.result()
            l1.append(expert.gate_and_up_proj)
            l2.append(expert.down_proj)

        W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
        W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)

        W1_T = W1.transpose(1, 2)
        W2_T = W2.transpose(1, 2)

        # wait for enough vram for the computations
        while not enough_vram(5 * (1024 ** 3)):
            event = self.moe_lru.last_offload_event
            if event is not None and not event.query():
                time.sleep(0.001)

        # SwiGLU over the padded batch, matching HunyuanMLP.forward: down(x1 * silu(x2))
        x = torch.bmm(tokens_padded, W1_T)
        x1, x2 = x.chunk(2, dim=2)
        out_padded = torch.bmm(x1 * F.silu(x2), W2_T)

        out_padded = out_padded * weights_padded.unsqueeze(-1)

        # Scatter each expert's weighted outputs back to the token positions it served.
        for i, token_positions in enumerate(per_pos):
            Ni = lengths[i]
            out_i = out_padded[i, :Ni]
            combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i.to(combined_output.dtype))

        #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
        #chunks = dispatched_input.chunk(self.num_experts, dim=0)
        #expert_outputs = []
        #for chunk, expert in zip(chunks, self.experts):
        #    expert_outputs.append(expert(chunk))

        #expert_output = torch.cat(expert_outputs, dim=0)
        #combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)

        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)

        output = hidden_states_mlp + combined_output

        return output


class HunyuanImage3Attention(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = 'self'

        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = 8
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = 10000.0
        self.is_causal = True
        self.hidden_size_q = self.head_dim * self.num_heads
        self.hidden_size_kv = self.head_dim * self.num_key_value_heads

        # define layers
        self.qkv_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size_q + 2 * self.hidden_size_kv,
            bias=False
        )
        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)

        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2,
                                        self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3)

        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = custom_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        query_states = self.query_layernorm(query_states)
        key_states = self.key_layernorm(key_states)

        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            query_states = query_states.to(key_states.dtype)

        key_states = torch.repeat_interleave(key_states, dim=1, repeats=self.num_key_value_groups)
        value_states = torch.repeat_interleave(value_states, dim=1, repeats=self.num_key_value_groups)

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask=attention_mask, skip_reshape=True)

        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value


class HunyuanImage3DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int, moe_lru=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.layer_idx = layer_idx
        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)

        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)

        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor | Any]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, past_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            custom_pos_emb=custom_pos_emb,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states, past_key_value)

        return outputs


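# HunyuanImage3Model runs the decoder layers sequentially. On the single-GPU
# path its forward also drives the expert streaming schedule: before running
# layer i it kicks off disk loads for layer i + 1, after running layer i it
# asynchronously offloads that layer's experts back to pinned CPU memory, and
# at sparse intervals it re-queues earlier layers so they are already being
# reloaded before the next forward pass.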
class HunyuanImage3Model(nn.Module):
    def __init__(self, config, moe_lru=None):
        super().__init__()
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.config = config
        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
        self.layers = nn.ModuleList(
            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru) for layer_idx in range(config["num_hidden_layers"])]
        )

        self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])

        self.shared_tensor = None
        self.moe_lru = moe_lru

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache=True,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        mode: str = "gen_image",
        first_step: Optional[bool] = None,
        gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        hidden_states = inputs_embeds

        next_decoder_cache = None
        next_layers = 0
        sparse_interval = max(1, len(self.layers) // 3)

        if len(self.layers[0].mlp.experts) == 0:
            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]

        for layer_idx, decoder_layer in enumerate(self.layers):

            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0:  # not loaded
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[layer_idx + 1].mlp.experts = [expert._schedule_disk_load(layer_idx + 1, i) for i, expert in enumerate(experts)]

            # Re-queue earlier layers at sparse intervals (single-GPU streaming path only).
            if self.moe_lru is not None and (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
                if len(self.layers[next_layers].mlp.experts) > 0:  # for testing
                    raise ValueError("Problem with offloading")
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
                next_layers += 1

            with torch.no_grad():
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    custom_pos_emb=custom_pos_emb,
                    mode=mode,
                    first_step=first_step,
                    gen_timestep_scatter_index=gen_timestep_scatter_index,
                )

            # Offload the experts of the layer that just ran (single-GPU streaming path only).
            if self.moe_lru is not None:
                asyncio.run_coroutine_threadsafe(
                    self.moe_lru._async_offload_to_cpu(layer_idx),
                    self.moe_lru._loop
                )
                self.layers[layer_idx].mlp.experts = []

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[1]

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache

        return tuple(v for v in [hidden_states, next_cache] if v is not None)


class HunyuanImage3ForCausalMM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
        self.patch_embed = UNetDown(
            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=32,
            hidden_channels=1024,
            out_channels=config["hidden_size"],
        )
        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])

        self.final_layer = UNetUp(
            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=config["hidden_size"],
            hidden_channels=1024,
            out_channels=32,
            out_norm=True,
        )
        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])

        self.moe_lru = None
        if not INIT_MOE:
            self.moe_lru = MoELRUCache()

        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)

        self.pad_id = 128009
        self.vocab_size = 133120

        self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
        self.first_step = True

        self.kv_cache = None
        self.token_dims = ()

    @staticmethod
    def get_pos_emb(custom_pos_emb, position_ids):
        cos, sin = custom_pos_emb
        cos = real_batched_index_select(cos, dim=1, idx=position_ids)
        sin = real_batched_index_select(sin, dim=1, idx=position_ids)
        return cos, sin

    def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step):
        bsz, seq_len, n_embd = x.shape
        if first_step:
            image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd)
        else:
            image_output = x[:, 1:, :]
        timestep_emb = self.time_embed_2(timestep)
        pred = self.final_layer(image_output, timestep_emb, token_h, token_w)
        return pred

    def forward(self, x, condition, timestep, **kwargs):
        joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

        gen_timestep_scatter_index = 4

        with torch.no_grad():
            joint_image[:, 2:3, :] = x[:, 2:3, :]  # updates image ratio

        if self.first_step:
            token_height, token_width = x[:, -2:, 0].tolist()[0]
            self.token_dims = (int(token_height), int(token_width))
            x = x[:, :-2, :]
        else:
            token_height, token_width = self.token_dims

        img_slices = []

        for i in range(x.size(0)):
            vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
            vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1

            vit_start = vae_end + 1
            vit_end = joint_image.size(1) - 1

            joint_slices_i = [
                slice(vae_start, vae_end),
                slice(vit_start, vit_end),
            ]
            gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)]
            img_slices.append(joint_slices_i + gen_slices_i)

        img_s = img_slices[0]
        rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]]

        cond_timestep = torch.zeros(inputs_embeds.size(0), device=inputs_embeds.device)
        t_emb = self.time_embed(cond_timestep)

        bsz, seq_len, n_embd = inputs_embeds.shape

        if self.first_step:
            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index + 1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        else:
            t_emb = self.time_embed(timestep)
            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            x = torch.cat([timestep_emb, x], dim=1)

        # /////////////
        # cond_vae_images

        # cond_timestep_scatter_index
        with torch.no_grad():
            joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)

        inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim=1)

        attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool, device=x.device).tril(diagonal=0).repeat(bsz, 1, 1)
        for i in range(bsz):
            for _, image_slice in enumerate(img_slices[i]):
                attention_mask[i, image_slice, image_slice] = True
        attention_mask = attention_mask.unsqueeze(1)

        # pos embed
        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
        cos, sin = build_batch_2d_rope(
            image_infos=rope_image_info,
            seq_len=inputs_embeds.shape[1],
            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
            base=10000.0,
        )
        # build_batch_2d_rope returns (cos, sin); keep that order so get_pos_emb
        # and apply_rotary_pos_emb receive cos and sin in the expected positions.
        custom_pos_emb = (cos.to(position_ids.device), sin.to(position_ids.device))
        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

        if self.kv_cache is None:
            # TODO: should change when higgsv2 gets merged
            self.kv_cache = HunyuanStaticCache(
                config=self.config,
                batch_size=x.size(0) * 2,
                max_cache_len=inputs_embeds.shape[1],
                dtype=x.dtype,
            )

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.kv_cache,
            inputs_embeds=inputs_embeds,
            custom_pos_emb=custom_pos_emb,
            first_step=self.first_step,
            gen_timestep_scatter_index=gen_timestep_scatter_index,
        )
        hidden_states = outputs[0]

        # safety no-op
        past_key_value = outputs[1]
        if past_key_value is not None:
            self.kv_cache = past_key_value

        hidden_states = hidden_states.to(inputs_embeds.device)
        img_mask = torch.zeros(hidden_states.size(1), device=hidden_states.device)
        img_mask[-x.size(1) + 4:] = 1
        img_mask[-1] = 0

        diffusion_prediction = self.ragged_final_layer(
            hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)

        if self.first_step:
            self.first_step = False

        return diffusion_prediction