This commit is contained in:
Silver 2026-07-02 17:47:09 +08:00 committed by GitHub
commit fa5a442195
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 160 additions and 16 deletions

View File

@ -226,27 +226,85 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw)
return x.reshape(B, C, gh, gw)
def _image_position_ids(self, gh, gw, device):
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1)
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1)
t_idx = torch.zeros_like(h_idx)
def _image_position_ids(self, gh, gw, device, index=0, h_offset=0, w_offset=0):
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1) + h_offset
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1) + w_offset
t_idx = torch.full_like(h_idx, index)
return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3)
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options, ref_latents=None, ref_latents_method="index"):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
L_text = context_chunk.shape[1]
L = L_text + L_img
latent_dim = img_tokens.shape[-1]
ref_tokens_list = []
ref_pos_ids_list = []
ref_num_tokens = []
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = (ref_latents_method == "index") or (ref_latents_method == "index_timestep_zero")
negative_ref_method = ref_latents_method == "negative_index"
for ref in ref_latents:
ref_b, ref_c, ref_h, ref_w = ref.shape
ref_gh = ref_h
ref_gw = ref_w
if index_ref_method:
index += 1
gh_offset = 0
gw_offset = 0
elif negative_ref_method:
index -= 1
gh_offset = 0
gw_offset = 0
else: # offset/default
index = 1
gh_offset = 0
gw_offset = 0
if ref_gh + h > ref_gw + w:
gw_offset = w
else:
gh_offset = h
h = max(h, ref_gh + gh_offset)
w = max(w, ref_gw + gw_offset)
ref_tokens = self._img_to_tokens(ref)
ref_tokens_list.append(ref_tokens)
ref_num_tokens.append(ref_tokens.shape[1])
ref_pos = self._image_position_ids(ref_gh, ref_gw, device, index=index, h_offset=gh_offset, w_offset=gw_offset)
ref_pos_ids_list.append(ref_pos)
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
L_ref = sum(t.shape[1] for t in ref_tokens_list) if ref_tokens_list else 0
L = L_text + L_img + L_ref
x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device)
x_full[:, L_text:] = img_tokens
x_full[:, L_text:L_text+L_img] = img_tokens
curr_idx = L_text + L_img
for ref_tokens in ref_tokens_list:
ref_len = ref_tokens.shape[1]
x_full[:, curr_idx:curr_idx+ref_len] = ref_tokens
curr_idx += ref_len
text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3)
img_pos = self._image_position_ids(gh, gw, device)
position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3)
pos_ids_all = [text_pos, img_pos]
for ref_pos in ref_pos_ids_list:
pos_ids_all.append(ref_pos)
position_ids = torch.cat(pos_ids_all, dim=0).unsqueeze(0).expand(B, L, 3)
indicator = torch.empty(B, L, dtype=torch.long, device=device)
indicator[:, :L_text] = LLM_TOKEN_INDICATOR
@ -263,20 +321,84 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
transformer_options=transformer_options)
return self._tokens_to_img(out[:, L_text:], gh, gw)
return self._tokens_to_img(out[:, L_text:L_text+L_img], gh, gw)
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options, ref_latents=None, ref_latents_method="index"):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
latent_dim = img_tokens.shape[-1]
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)
ref_tokens_list = []
ref_pos_ids_list = []
ref_num_tokens = []
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = (ref_latents_method == "index") or (ref_latents_method == "index_timestep_zero")
negative_ref_method = ref_latents_method == "negative_index"
for ref in ref_latents:
ref_b, ref_c, ref_h, ref_w = ref.shape
ref_gh = ref_h
ref_gw = ref_w
if index_ref_method:
index += 1
gh_offset = 0
gw_offset = 0
elif negative_ref_method:
index -= 1
gh_offset = 0
gw_offset = 0
else: # offset/default
index = 1
gh_offset = 0
gw_offset = 0
if ref_gh + h > ref_gw + w:
gw_offset = w
else:
gh_offset = h
h = max(h, ref_gh + gh_offset)
w = max(w, ref_gw + gw_offset)
ref_tokens = self._img_to_tokens(ref)
ref_tokens_list.append(ref_tokens)
ref_num_tokens.append(ref_tokens.shape[1])
ref_pos = self._image_position_ids(ref_gh, ref_gw, device, index=index, h_offset=gh_offset, w_offset=gw_offset)
ref_pos_ids_list.append(ref_pos)
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
L_ref = sum(t.shape[1] for t in ref_tokens_list) if ref_tokens_list else 0
L_img_total = L_img + L_ref
x_full = torch.zeros(B, L_img_total, latent_dim, dtype=img_tokens.dtype, device=device)
x_full[:, :L_img] = img_tokens
curr_idx = L_img
for ref_tokens in ref_tokens_list:
ref_len = ref_tokens.shape[1]
x_full[:, curr_idx:curr_idx+ref_len] = ref_tokens
curr_idx += ref_len
img_pos = self._image_position_ids(gh, gw, device)
pos_ids_all = [img_pos]
for ref_pos in ref_pos_ids_list:
pos_ids_all.append(ref_pos)
position_ids = torch.cat(pos_ids_all, dim=0).unsqueeze(0).expand(B, L_img_total, 3)
indicator = torch.full((B, L_img_total), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)
# Image-only sequence is a single segment -> no mask, full attention, no LLM context.
out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
return self._tokens_to_img(out, gh, gw)
out = self._backbone(None, x_full, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
return self._tokens_to_img(out[:, :L_img], gh, gw)
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@ -290,8 +412,11 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
timesteps = 1.0 - timesteps
ref_latents = kwargs.get("ref_latents", None)
ref_latents_method = kwargs.get("ref_latents_method", "index")
# unconditional pass
if context is None:
return -self._run_image_only(x, timesteps, gh, gw, transformer_options)
return -self._run_image_only(x, timesteps, gh, gw, transformer_options, ref_latents=ref_latents, ref_latents_method=ref_latents_method)
return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)
return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options, ref_latents=ref_latents, ref_latents_method=ref_latents_method)

View File

@ -2267,6 +2267,7 @@ class QwenImage(BaseModel):
class Ideogram4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -2277,6 +2278,24 @@ class Ideogram4(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 128, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class Krea2(BaseModel):