update upstream for flux fixes

doctorpangloss 2025-07-09 06:28:17 -07:00
parent c358ff88a5
commit b268296504
8 changed files with 57 additions and 40 deletions

View File

@@ -6,11 +6,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from comfy.ldm.lightricks.model import Timesteps
-from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.modules.attention import optimized_attention_masked
-import comfy.model_management
-import comfy.ldm.common_dit
+from ..lightricks.model import Timesteps
+from ..flux.layers import EmbedND
+from ..modules.attention import optimized_attention_masked
+from ...model_management import cast_to
+from ..common_dit import pad_to_patch_size
 def apply_rotary_emb(x, freqs_cis):
@@ -363,7 +363,7 @@ class OmniGen2Transformer2DModel(nn.Module):
         l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
         if ref_image_hidden_states is not None:
-            ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
+            ref_image_hidden_states = list(map(lambda ref: pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
             ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
             l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
         else:
@@ -396,7 +396,7 @@ class OmniGen2Transformer2DModel(nn.Module):
         hidden_states = self.x_embedder(hidden_states)
         if ref_image_hidden_states is not None:
             ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
-            image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
+            image_index_embedding = cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
             for i in range(batch_size):
                 shift = 0
@@ -417,7 +417,7 @@ class OmniGen2Transformer2DModel(nn.Module):
     def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
         B, C, H, W = x.shape
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+        hidden_states = pad_to_patch_size(x, (self.patch_size, self.patch_size))
         _, _, H_padded, W_padded = hidden_states.shape
         timestep = 1.0 - timesteps
         text_hidden_states = context
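Note on the helper used throughout this file: pad_to_patch_size (now imported from ..common_dit) pads a latent's spatial dimensions so they divide evenly by the patch size before patchification. A minimal sketch of that idea, for illustration only and not the actual comfy.ldm.common_dit implementation (which may use a different padding mode):

import torch
import torch.nn.functional as F

def pad_to_patch_size_sketch(img: torch.Tensor, patch_size=(2, 2)) -> torch.Tensor:
    # Pad the last two dims (H, W) up to the next multiple of the patch size
    # by adding rows/columns on the bottom and right.
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    return F.pad(img, (0, pad_w, 0, pad_h))

x = torch.randn(1, 16, 107, 191)
print(pad_to_patch_size_sketch(x).shape)  # torch.Size([1, 16, 108, 192])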

View File

@@ -514,6 +514,7 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("black-forest-labs/FLUX.1-Fill-dev", "flux1-fill-dev.safetensors"),
     HuggingFile("black-forest-labs/FLUX.1-Canny-dev", "flux1-canny-dev.safetensors"),
     HuggingFile("black-forest-labs/FLUX.1-Depth-dev", "flux1-depth-dev.safetensors"),
+    HuggingFile("black-forest-labs/FLUX.1-Kontext-dev", "flux1-kontext-dev.safetensors"),
     HuggingFile("Kijai/flux-fp8", "flux1-dev-fp8.safetensors"),
     HuggingFile("Kijai/flux-fp8", "flux1-schnell-fp8.safetensors"),
     HuggingFile("Comfy-Org/mochi_preview_repackaged", "split_files/diffusion_models/mochi_preview_bf16.safetensors"),

View File

@@ -7,6 +7,7 @@ from . import sd1_clip
 from . import sdxl_clip
 from . import supported_models_base
 from . import utils
+from . import model_management
 from .text_encoders import ace
 from .text_encoders import aura_t5
 from .text_encoders import cosmos

View File

@@ -1,17 +1,23 @@
 from transformers import Qwen2Tokenizer
-from comfy import sd1_clip
-import comfy.text_encoders.llama
+from .. import sd1_clip
+from .llama import Qwen25_3B
 import os
+from ..component_model import files
 class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
+        tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer")
         super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
 class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
         self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
@@ -23,18 +29,25 @@ class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
         return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
 class Qwen25_3BModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None):
+        if model_options is None:
+            model_options = {}
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 class Omnigen2Model(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, model_options=None):
+        if model_options is None:
+            model_options = {}
         super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
 def te(dtype_llama=None, llama_scaled_fp8=None):
     class Omnigen2TEModel_(Omnigen2Model):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["scaled_fp8"] = llama_scaled_fp8
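The constructor changes above all follow the same pattern: the mutable default argument ({}) is replaced by a None sentinel so each call gets its own dict. A small standalone illustration of the hazard being avoided (hypothetical functions, not from this file):

def bad(model_options={}):
    # Python evaluates the default once, so every call shares this one dict.
    model_options["calls"] = model_options.get("calls", 0) + 1
    return model_options

def good(model_options=None):
    # A fresh dict per call unless the caller passes one in.
    if model_options is None:
        model_options = {}
    model_options["calls"] = model_options.get("calls", 0) + 1
    return model_options

print(bad())   # {'calls': 1}
print(bad())   # {'calls': 2}  <- state leaked between calls
print(good())  # {'calls': 1}
print(good())  # {'calls': 1}  <- independent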

View File

@@ -94,6 +94,26 @@ HIDREAM_1_EDIT_RESOLUTIONS = [
     (768, 768),
 ]
+KONTEXT_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
 RESOLUTION_MAP = {
     "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS,
     "SD1.5": SD_RESOLUTIONS,
@@ -105,7 +125,9 @@ RESOLUTION_MAP = {
     "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS,
     "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS,
     "HiDream 1 Edit": HIDREAM_1_EDIT_RESOLUTIONS,
+    "Kontext": KONTEXT_RESOLUTIONS,
     "Unknown": []
 }
 RESOLUTION_NAMES = list(RESOLUTION_MAP.keys())
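A quick sanity check on the new table (not part of the change): every Kontext entry keeps the pixel count within roughly 1.5% of 1024x1024, and both dimensions are multiples of 16.

KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]

for w, h in KONTEXT_RESOLUTIONS:
    assert w % 16 == 0 and h % 16 == 0
    # Pixel count stays within ~1.5% of 1024 * 1024 = 1,048,576.
    assert abs(w * h - 1024 * 1024) / (1024 * 1024) < 0.015
print(f"{len(KONTEXT_RESOLUTIONS)} resolutions, all 16-aligned and ~1 megapixel")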

View File

@@ -1,4 +1,4 @@
-import node_helpers
+from comfy import node_helpers
 class ReferenceLatent:

View File

@@ -1,5 +1,6 @@
 from comfy import node_helpers
 import comfy.utils
+from comfy_extras.constants.resolutions import KONTEXT_RESOLUTIONS
 class CLIPTextEncodeFlux:
@@ -60,27 +61,6 @@ class FluxDisableGuidance:
         return (c,)
-PREFERED_KONTEXT_RESOLUTIONS = [
-    (672, 1568),
-    (688, 1504),
-    (720, 1456),
-    (752, 1392),
-    (800, 1328),
-    (832, 1248),
-    (880, 1184),
-    (944, 1104),
-    (1024, 1024),
-    (1104, 944),
-    (1184, 880),
-    (1248, 832),
-    (1328, 800),
-    (1392, 752),
-    (1456, 720),
-    (1504, 688),
-    (1568, 672),
-]
 class FluxKontextImageScale:
     @classmethod
     def INPUT_TYPES(s):
@@ -98,7 +78,7 @@ class FluxKontextImageScale:
         width = image.shape[2]
         height = image.shape[1]
         aspect_ratio = width / height
-        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
+        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in KONTEXT_RESOLUTIONS)
         image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
         return (image,)
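For reference, the min() one-liner above selects the Kontext resolution whose aspect ratio is closest to the input image, breaking ties toward the smaller width. An expanded, equivalent restatement (illustrative only, not part of the change):

from comfy_extras.constants.resolutions import KONTEXT_RESOLUTIONS

def closest_kontext_resolution(width: int, height: int) -> tuple:
    # Compare each candidate's aspect ratio against the input image and keep the
    # closest match; the (w, h) tuple in the key reproduces min()'s tie-breaking.
    aspect_ratio = width / height
    return min(KONTEXT_RESOLUTIONS, key=lambda wh: (abs(aspect_ratio - wh[0] / wh[1]), wh))

print(closest_kontext_resolution(1920, 1080))  # (1392, 752): ratio 1.851 vs 16:9 = 1.778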