mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-09 15:57:24 +08:00
Fix ideogram if model dtype gets set to fp8. (#14291)
This commit is contained in:
parent
4e1f7cb1db
commit
514bb8ba21
@@ -174,7 +174,7 @@ class Ideogram4Transformer(nn.Module):
|
|||||||
llm = self.llm_cond_proj(llm) * text_mask
|
llm = self.llm_cond_proj(llm) * text_mask
|
||||||
h[:, :L_text] = h[:, :L_text] + llm
|
h[:, :L_text] = h[:, :L_text] + llm
|
||||||
|
|
||||||
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long))
|
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
|
||||||
|
|
||||||
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
|
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
|
||||||
freqs_cis = precompute_freqs_cis(
|
freqs_cis = precompute_freqs_cis(
|
||||||
@@ -235,7 +235,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
|||||||
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
|
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
|
||||||
B = x_chunk.shape[0]
|
B = x_chunk.shape[0]
|
||||||
device = x_chunk.device
|
device = x_chunk.device
|
||||||
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
|
img_tokens = self._img_to_tokens(x_chunk)
|
||||||
L_img = img_tokens.shape[1]
|
L_img = img_tokens.shape[1]
|
||||||
L_text = context_chunk.shape[1]
|
L_text = context_chunk.shape[1]
|
||||||
L = L_text + L_img
|
L = L_text + L_img
|
||||||
@@ -268,7 +268,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
|||||||
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
|
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
|
||||||
B = x_chunk.shape[0]
|
B = x_chunk.shape[0]
|
||||||
device = x_chunk.device
|
device = x_chunk.device
|
||||||
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
|
img_tokens = self._img_to_tokens(x_chunk)
|
||||||
L_img = img_tokens.shape[1]
|
L_img = img_tokens.shape[1]
|
||||||
|
|
||||||
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
|
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user