diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 50ca11a66..3082e6730 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -157,7 +157,11 @@ class PromptServer(ExecutorToClientProgress):
 
         @routes.get("/")
         async def get_root(request):
-            return web.FileResponse(os.path.join(self.web_root, "index.html"))
+            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
+            response.headers['Cache-Control'] = 'no-cache'
+            response.headers["Pragma"] = "no-cache"
+            response.headers["Expires"] = "0"
+            return response
 
         @routes.get("/embeddings")
         def get_embeddings(self):
diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py
index be9c25b24..1a901d71e 100644
--- a/comfy/ldm/aura/mmdit.py
+++ b/comfy/ldm/aura/mmdit.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 
 from ..modules.attention import optimized_attention
 from ... import ops
+from .. import common_dit
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -407,10 +408,7 @@ class MMDiT(nn.Module):
 
     def patchify(self, x):
        B, C, H, W = x.size()
-        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
-        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
-
-        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
+        x = common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         x = x.view(
             B,
             C,
diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py
new file mode 100644
index 000000000..990025521
--- /dev/null
+++ b/comfy/ldm/common_dit.py
@@ -0,0 +1,8 @@
+import torch
+
+def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
+    if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
+        padding_mode = "reflect"
+    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
+    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
+    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 468a4412b..86a564ec3 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -15,7 +15,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
 
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
-    if model_management.is_device_mps(pos.device):
+    if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu():
         device = torch.device("cpu")
     else:
         device = pos.device
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index e7931c16d..db6cf3d22 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -15,6 +15,7 @@ from .layers import (
 )
 
 from einops import rearrange, repeat
+import comfy.ldm.common_dit
 
 @dataclass
 class FluxParams:
@@ -42,7 +43,7 @@ class Flux(nn.Module):
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
-        self.in_channels = params.in_channels
+        self.in_channels = params.in_channels * 2 * 2
         self.out_channels = self.in_channels
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
@@ -125,10 +126,7 @@ class Flux(nn.Module):
     def forward(self, x, timestep, context, y, guidance, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
-        pad_h = (patch_size - h % 2) % patch_size
-        pad_w = (patch_size - w % 2) % patch_size
-
-        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index f52dfac11..459892463 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -10,6 +10,7 @@ from .. import attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 from .... import ops
+from ... import common_dit
 
 def default(x, y):
     if x is not None:
@@ -112,9 +113,7 @@ class PatchEmbed(nn.Module):
         #     f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
         # )
         if self.dynamic_img_pad:
-            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
-            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
-            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
+            x = common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
diff --git a/comfy/lora.py b/comfy/lora.py
index 63f9e4d8b..d07fc47fe 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -286,4 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
 
+    if isinstance(model, model_base.Flux): #Diffusers lora Flux
+        diffusers_keys = utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
+                key_map[key_lora] = to
+
     return key_map
diff --git a/comfy/model_base.py b/comfy/model_base.py
index fcc82b180..72ea07f39 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -87,6 +87,7 @@ class BaseModel(torch.nn.Module):
                 # todo: ???
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
+            logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
 
@@ -95,6 +96,9 @@ class BaseModel(torch.nn.Module):
 
         self.adm_channels = 0
         self.concat_keys = ()
+        logging.info("model_type {}".format(model_type.name))
+        logging.debug("adm {}".format(self.adm_channels))
+        self.memory_usage_factor = model_config.memory_usage_factor
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
@@ -256,11 +260,11 @@ class BaseModel(torch.nn.Module):
                 dtype = self.manual_cast_dtype
             # TODO: this needs to be tweaked
             area = input_shape[0] * math.prod(input_shape[2:])
-            return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
+            return (area * model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
         else:
             # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
             area = input_shape[0] * math.prod(input_shape[2:])
-            return (area * 0.3) * (1024 * 1024)
+            return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
@@ -614,17 +618,6 @@ class SD3(BaseModel):
         out['c_crossattn'] = conds.CONDRegular(cross_attn)
         return out
 
-    def memory_required(self, input_shape):
-        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype()
-            if self.manual_cast_dtype is not None:
-                dtype = self.manual_cast_dtype
-            # TODO: this probably needs to be tweaked
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
-        else:
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * 0.3) * (1024 * 1024)
 
 class AuraFlow(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -722,15 +715,3 @@ class Flux(BaseModel):
         out['c_crossattn'] = conds.CONDRegular(cross_attn)
         out['guidance'] = conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
         return out
-
-    def memory_required(self, input_shape):
-        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype()
-            if self.manual_cast_dtype is not None:
-                dtype = self.manual_cast_dtype
-            #TODO: this probably needs to be tweaked
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * model_management.dtype_size(dtype) * 0.020) * (1024 * 1024)
-        else:
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * 0.3) * (1024 * 1024)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 1676255c6..4b0ceb6e8 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -130,7 +130,7 @@ def detect_unet_config(state_dict, key_prefix):
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
-        dit_config["in_channels"] = 64
+        dit_config["in_channels"] = 16
         dit_config["vec_in_dim"] = 768
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7226b0425..8093d495d 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -511,7 +511,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
             if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
                 lowvram_model_memory = 0
 
@@ -602,6 +602,9 @@ def unet_initial_load_device(parameters, dtype):
     return cpu_dev
 
+def maximum_vram_for_weights(device=None):
+    return (get_total_memory(device) * 0.88 - minimum_inference_memory())
+
 def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if args.bf16_unet:
         return torch.bfloat16
@@ -611,6 +614,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+
+    fp8_dtype = None
+    try:
+        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            if dtype in supported_dtypes:
+                fp8_dtype = dtype
+                break
+    except:
+        pass
+
+    if fp8_dtype is not None:
+        free_model_memory = maximum_vram_for_weights(device)
+        if model_params * 2 > free_model_memory:
+            return fp8_dtype
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
@@ -973,7 +991,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
             fp16_works = True
 
     if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -1016,21 +1034,14 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True
 
-    if device is None:
-        device = torch.device("cuda")
-
-    try:
-        props = torch.cuda.get_device_properties(device)
-        if props.major >= 8:
-            return True
-    except AssertionError:
-        logging.warning("Torch was not compiled with CUDA support")
-        return False
+    props = torch.cuda.get_device_properties("cuda")
+    if props.major >= 8:
+        return True
 
     bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
diff --git a/comfy/samplers.py b/comfy/samplers.py
index ce122553c..6d718e50e 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -174,7 +174,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
         for i in range(1, len(to_batch_temp) + 1):
             batch_amount = to_batch_temp[:len(to_batch_temp)//i]
             input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) < free_memory:
+            if model.memory_required(input_shape) * 1.5 < free_memory:
                 to_batch = batch_amount
                 break
 
diff --git a/comfy/sd.py b/comfy/sd.py
index c37349bf1..c973dd0c1 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -535,13 +535,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
+    weight_dtype = utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
 
-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None:
+        unet_weight_dtype.append(weight_dtype)
+
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
 
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index a3b39e5ee..8b7e7b7af 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
     }
 
     latent_format = latent_formats.SD15
+    memory_usage_factor = 1.0
 
     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())
@@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
     }
 
     latent_format = latent_formats.SD15
+    memory_usage_factor = 1.0
 
     def model_type(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
     }
 
     latent_format = latent_formats.SDXL
+    memory_usage_factor = 1.0
 
     def get_model(self, state_dict, prefix="", device=None):
         return model_base.SDXLRefiner(self, device=device)
@@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):
 
     latent_format = latent_formats.SDXL
 
+    memory_usage_factor = 0.7
+
     def model_type(self, state_dict, prefix=""):
         if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
             self.latent_format = latent_formats.SDXL_Playground_2_5()
@@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):
 
     unet_extra_config = {}
     latent_format = latent_formats.SD3
+
+    memory_usage_factor = 1.2
+
     text_encoder_key_prefix = ["text_encoders."]
 
     def get_model(self, state_dict, prefix="", device=None):
@@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):
 
     unet_extra_config = {}
     latent_format = latent_formats.Flux
+
+    memory_usage_factor = 2.6
+
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
 
     vae_key_prefix = ["vae."]
@@ -641,7 +652,13 @@ class Flux(supported_models_base.BASE):
         return out
 
     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.FluxClipModel)
+        pref = self.text_encoder_key_prefix[0]
+        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        if t5_key in state_dict:
+            dtype_t5 = state_dict[t5_key].dtype
+        else:
+            dtype_t5 = None
+        return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.flux_clip(dtype_t5=dtype_t5))
 
 class FluxSchnell(Flux):
     unet_config = {
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index cf7cdff34..bc0a7e311 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -27,6 +27,8 @@ class BASE:
     text_encoder_key_prefix = ["cond_stage_model."]
     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
 
+    memory_usage_factor = 2.0
+
     manual_cast_dtype = None
 
     @classmethod
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index 5a3506df3..8e3f32321 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -56,9 +56,9 @@ class FluxClipModel(torch.nn.Module):
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
-        token_weight_pars_t5 = token_weight_pairs["t5xxl"]
+        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
 
-        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return t5_out, l_pooled
 
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index f88b68ae1..8de556d4e 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -87,7 +87,7 @@ class SD3ClipModel(torch.nn.Module):
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
         token_weight_pairs_g = token_weight_pairs["g"]
-        token_weight_pars_t5 = token_weight_pairs["t5xxl"]
+        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
         lg_out = None
         pooled = None
         out = None
@@ -114,7 +114,7 @@ class SD3ClipModel(torch.nn.Module):
                 pooled = torch.cat((l_pooled, g_pooled), dim=-1)
 
         if self.t5xxl is not None:
-            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
             if lg_out is not None:
                 out = torch.cat([lg_out, t5_out], dim=-2)
             else:
diff --git a/comfy/utils.py b/comfy/utils.py
index 010700ad2..2fed7fd3d 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -74,9 +74,21 @@ def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
         if k.startswith(prefix):
-            params += sd[k].nelement()
+            w = sd[k]
+            params += w.nelement()
     return params
 
+def weight_dtype(sd, prefix=""):
+    dtypes = {}
+    for k in sd.keys():
+        if k.startswith(prefix):
+            w = sd[k]
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+
+    if len(dtypes) == 0:
+        return None
+
+    return max(dtypes, key=dtypes.get)
 
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
@@ -443,6 +455,59 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
 
     return key_map
 
+def flux_to_diffusers(mmdit_config, output_prefix=""):
+    n_double_layers = mmdit_config.get("depth", 0)
+    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
+    hidden_size = mmdit_config.get("hidden_size", 0)
+
+    key_map = {}
+    for index in range(n_double_layers):
+        prefix_from = "transformer_blocks.{}".format(index)
+        prefix_to = "{}double_blocks.{}".format(output_prefix, index)
+
+        for end in ("weight", "bias"):
+            k = "{}.attn.".format(prefix_from)
+            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
+                     "attn.to_out.0.bias": "img_attn.proj.bias",
+                     }
+
+        for k in block_map:
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+
+    for index in range(n_single_layers):
+        prefix_from = "single_transformer_blocks.{}".format(index)
+        prefix_to = "{}single_blocks.{}".format(output_prefix, index)
+
+        for end in ("weight", "bias"):
+            k = "{}.attn.".format(prefix_from)
+            qkv = "{}.linear1.{}".format(prefix_to, end)
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+            key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
+
+        block_map = {#TODO
+                     }
+
+        for k in block_map:
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+
+    MAP_BASIC = { #TODO
+                 }
+
+    for k in MAP_BASIC:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py
index d2cb94047..dee9bf1b1 100644
--- a/comfy_extras/nodes/nodes_model_advanced.py
+++ b/comfy_extras/nodes/nodes_model_advanced.py
@@ -3,6 +3,9 @@ import comfy.model_sampling
 import comfy.latent_formats
 import torch
 
+from comfy.nodes.common import MAX_RESOLUTION
+
+
 class LCM(comfy.model_sampling.EPS):
     def timestep(self, *args, **kwargs) -> torch.Tensor:
         pass
@@ -173,6 +176,42 @@ class ModelSamplingAuraFlow(ModelSamplingSD3):
     def patch_aura(self, model, shift):
         return self.patch(model, shift, multiplier=1.0)
 
+class ModelSamplingFlux:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
+                              "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
+                              "width": ("INT", {"default": 1024, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 1024, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, max_shift, base_shift, width, height):
+        m = model.clone()
+
+        x1 = 256
+        x2 = 4096
+        mm = (max_shift - base_shift) / (x2 - x1)
+        b = base_shift - mm * x1
+        shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
+
+        sampling_base = comfy.model_sampling.ModelSamplingFlux
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
+
 class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):
@@ -289,5 +328,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingStableCascade": ModelSamplingStableCascade,
     "ModelSamplingSD3": ModelSamplingSD3,
     "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
+    "ModelSamplingFlux": ModelSamplingFlux,
     "RescaleCFG": RescaleCFG,
 }
diff --git a/comfy_extras/nodes/nodes_model_merging_model_specific.py b/comfy_extras/nodes/nodes_model_merging_model_specific.py
index af636cf72..ef5cb187f 100644
--- a/comfy_extras/nodes/nodes_model_merging_model_specific.py
+++ b/comfy_extras/nodes/nodes_model_merging_model_specific.py
@@ -1,4 +1,6 @@
 from . import nodes_model_merging
+from .nodes_model_merging import ModelMergeBlocks
+
 
 class ModelMergeSD1(nodes_model_merging.ModelMergeBlocks):
     CATEGORY = "advanced/model_merging/model_specific"
@@ -75,9 +77,36 @@ class ModelMergeSD3_2B(nodes_model_merging.ModelMergeBlocks):
 
         return {"required": arg_dict}
 
+class ModelMergeFlux1(ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                              "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["img_in."] = argument
+        arg_dict["time_in."] = argument
+        arg_dict["guidance_in"] = argument
+        arg_dict["vector_in."] = argument
+        arg_dict["txt_in."] = argument
+
+        for i in range(19):
+            arg_dict["double_blocks.{}.".format(i)] = argument
+
+        for i in range(38):
+            arg_dict["single_blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
     "ModelMergeSDXL": ModelMergeSDXL,
     "ModelMergeSD3_2B": ModelMergeSD3_2B,
+    "ModelMergeFlux1": ModelMergeFlux1,
 }
diff --git a/comfy_extras/nodes/nodes_sd3.py b/comfy_extras/nodes/nodes_sd3.py
index 4d944a959..67db93759 100644
--- a/comfy_extras/nodes/nodes_sd3.py
+++ b/comfy_extras/nodes/nodes_sd3.py
@@ -33,10 +33,9 @@ class EmptySD3LatentImage:
 
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                             "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+        return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"