From 47da42d9283815a58636bd6b42c0434f70b24c9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Aug 2024 17:02:35 -0400 Subject: [PATCH 01/23] Better Flux vram estimation. --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 994b414cc..9989a4c2b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -710,7 +710,7 @@ class Flux(BaseModel): dtype = self.manual_cast_dtype #TODO: this probably needs to be tweaked area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(dtype) * 0.020) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) * 0.026) * (1024 * 1024) else: area = input_shape[0] * input_shape[2] * input_shape[3] return (area * 0.3) * (1024 * 1024) From 3a9ee995cfbb9425227df9aff534dea12c1af532 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Aug 2024 17:34:30 -0400 Subject: [PATCH 02/23] Tweak regular SD memory formula. --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9989a4c2b..eb2b935df 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -252,7 +252,7 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * math.prod(input_shape[2:]) - return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) * 0.01) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * math.prod(input_shape[2:]) From ea03c9dcd2e2b223c0eb25f24be6b3e1995e2c44 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Aug 2024 18:08:21 -0400 Subject: [PATCH 03/23] Better per model memory usage estimations. --- comfy/model_base.py | 29 ++++------------------------- comfy/supported_models.py | 11 +++++++++++ comfy/supported_models_base.py | 2 ++ 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index eb2b935df..ec15e9fcf 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module): self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) + self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t @@ -252,11 +253,11 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * math.prod(input_shape[2:]) - return (area * comfy.model_management.dtype_size(dtype) * 0.01) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. 
area = input_shape[0] * math.prod(input_shape[2:]) - return (area * 0.3) * (1024 * 1024) + return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): @@ -354,6 +355,7 @@ class SDXL(BaseModel): flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + class SVD_img2vid(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): super().__init__(model_config, model_type, device=device) @@ -594,17 +596,6 @@ class SD3(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def memory_required(self, input_shape): - if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype - #TODO: this probably needs to be tweaked - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024) - else: - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * 0.3) * (1024 * 1024) class AuraFlow(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): @@ -702,15 +693,3 @@ class Flux(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)])) return out - - def memory_required(self, input_shape): - if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype - #TODO: this probably needs to be tweaked - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(dtype) * 0.026) * (1024 * 1024) - else: - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * 0.3) * (1024 * 1024) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 43e8f5d1b..681ef95c9 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE): } latent_format = latent_formats.SD15 + memory_usage_factor = 1.0 def process_clip_state_dict(self, state_dict): k = list(state_dict.keys()) @@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE): } latent_format = latent_formats.SD15 + memory_usage_factor = 1.0 def model_type(self, state_dict, prefix=""): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction @@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE): } latent_format = latent_formats.SDXL + memory_usage_factor = 1.0 def get_model(self, state_dict, prefix="", device=None): return model_base.SDXLRefiner(self, device=device) @@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL + memory_usage_factor = 0.7 + def model_type(self, state_dict, prefix=""): if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 self.latent_format = latent_formats.SDXL_Playground_2_5() @@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 + + memory_usage_factor = 1.2 + text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): @@ -631,6 +639,9 @@ 
class Flux(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux + + memory_usage_factor = 2.6 + supported_inference_dtypes = [torch.bfloat16, torch.float32] vae_key_prefix = ["vae."] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index cf7cdff34..bc0a7e311 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -27,6 +27,8 @@ class BASE: text_encoder_key_prefix = ["cond_stage_model."] supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + memory_usage_factor = 2.0 + manual_cast_dtype = None @classmethod From 7cd0cdfce601a52c52252ace517b9f52f6237fdb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Aug 2024 23:20:30 -0400 Subject: [PATCH 04/23] Add advanced model merge node for Flux model. --- .../nodes_model_merging_model_specific.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index df111bd60..9557d9b1f 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -75,9 +75,36 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["img_in."] = argument + arg_dict["time_in."] = argument + arg_dict["guidance_in"] = argument + arg_dict["vector_in."] = argument + arg_dict["txt_in."] = argument + + for i in range(19): + arg_dict["double_blocks.{}.".format(i)] = argument + + for i in range(38): + arg_dict["single_blocks.{}.".format(i)] = argument + + arg_dict["final_layer."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks "ModelMergeSDXL": ModelMergeSDXL, "ModelMergeSD3_2B": ModelMergeSD3_2B, + "ModelMergeFlux1": ModelMergeFlux1, } From 0eea47d58086d31695f3e8e9d7ef36c6a6986faa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 03:54:38 -0400 Subject: [PATCH 05/23] Add ModelSamplingFlux to experiment with the shift value. 
Default shift on Flux Schnell is 0.0 --- comfy_extras/nodes_model_advanced.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 22ba9547b..fef8a4873 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -170,6 +170,33 @@ class ModelSamplingAuraFlow(ModelSamplingSD3): def patch_aura(self, model, shift): return self.patch(model, shift, multiplier=1.0) +class ModelSamplingFlux: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.ModelSamplingFlux + sampling_type = comfy.model_sampling.CONST + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift=shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -284,5 +311,6 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingSD3": ModelSamplingSD3, "ModelSamplingAuraFlow": ModelSamplingAuraFlow, + "ModelSamplingFlux": ModelSamplingFlux, "RescaleCFG": RescaleCFG, } From 63a7e8edba76b30e3c01190345126ae75c94777d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 11:53:30 -0400 Subject: [PATCH 06/23] More aggressive batch splitting. --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 3f7633814..ce4371d50 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -171,7 +171,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options): for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: + if model.memory_required(input_shape) * 1.5 < free_memory: to_batch = batch_amount break From f123328b826dcd122d307b75288f89ea301fa25b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 12:39:33 -0400 Subject: [PATCH 07/23] Load T5 in fp8 if it's in fp8 in the Flux checkpoint. 
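For reference, the detection in the diff below keys off the dtype of a single T5 weight that is expected to be present whenever the text encoder is bundled with the checkpoint. A minimal standalone sketch of the same probe (the "text_encoders." prefix and the fp8 example dtype are assumptions here, and the fp8 dtypes require a recent PyTorch build):

import torch

def detect_t5_dtype(state_dict, prefix="text_encoders."):
    # Mirror of the check in the diff: look at one known T5 weight and report its dtype.
    t5_key = prefix + "t5xxl.transformer.encoder.final_layer_norm.weight"
    if t5_key in state_dict:
        return state_dict[t5_key].dtype
    return None  # no bundled T5 weights in this checkpoint

sd = {"text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight":
      torch.zeros(4096, dtype=torch.float8_e4m3fn)}
print(detect_t5_dtype(sd))  # torch.float8_e4m3fn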
--- comfy/supported_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 681ef95c9..94fdcc0d2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -652,7 +652,11 @@ class Flux(supported_models_base.BASE): return out def clip_target(self, state_dict={}): - return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.FluxClipModel) + pref = self.text_encoder_key_prefix[0] + t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) + if t5_key in state_dict: + dtype_t5 = state_dict[t5_key].dtype + return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)) class FluxSchnell(Flux): unet_config = { From ba9095e5bd7914c2456b2dfe939c06180e97b1ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 13:45:19 -0400 Subject: [PATCH 08/23] Automatically use fp8 for diffusion model weights if: Checkpoint contains weights in fp8. There isn't enough memory to load the diffusion model in GPU vram. --- comfy/model_base.py | 1 + comfy/model_management.py | 22 ++++++++++++++++++++-- comfy/sd.py | 3 ++- comfy/utils.py | 12 +++++++++++- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index ec15e9fcf..94f4d333c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module): self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) + logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): diff --git a/comfy/model_management.py b/comfy/model_management.py index da0b989a8..c0fb15095 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -527,6 +527,9 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev +def maximum_vram_for_weights(device=None): + return (get_total_memory(device) * 0.8 - minimum_inference_memory()) + def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if args.bf16_unet: return torch.bfloat16 @@ -536,6 +539,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor return torch.float8_e4m3fn if args.fp8_e5m2_unet: return torch.float8_e5m2 + + fp8_dtype = None + try: + for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if dtype in supported_dtypes: + fp8_dtype = dtype + break + except: + pass + + if fp8_dtype is not None: + free_model_memory = maximum_vram_for_weights(device) + if model_params * 2 > free_model_memory: + return fp8_dtype + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): if torch.float16 in supported_dtypes: return torch.float16 @@ -871,7 +889,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma fp16_works = True if fp16_works or manual_cast: - free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + free_model_memory = maximum_vram_for_weights(device) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True @@ -920,7 +938,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma 
bf16_works = torch.cuda.is_bf16_supported() if bf16_works or manual_cast: - free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + free_model_memory = maximum_vram_for_weights(device) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True diff --git a/comfy/sd.py b/comfy/sd.py index 41ce18c80..bf336c859 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -510,13 +510,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) + weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) diff --git a/comfy/utils.py b/comfy/utils.py index 0db9fbb62..d9fe36f91 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -40,9 +40,19 @@ def calculate_parameters(sd, prefix=""): params = 0 for k in sd.keys(): if k.startswith(prefix): - params += sd[k].nelement() + w = sd[k] + params += w.nelement() return params +def weight_dtype(sd, prefix=""): + dtypes = {} + for k in sd.keys(): + if k.startswith(prefix): + w = sd[k] + dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1 + + return max(dtypes, key=dtypes.get) + def state_dict_key_replace(state_dict, keys_to_replace): for x in keys_to_replace: if x in state_dict: From 1e68002b87a3fb70afc7030c1b4dc6a31fea965e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 14:50:20 -0400 Subject: [PATCH 09/23] Cap lowvram to half of free memory. --- comfy/model_management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c0fb15095..2008229f2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -450,7 +450,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required))) + lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)) + lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory)) if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary lowvram_model_memory = 0 From 2ba5cc8b867bc1aabe59fdaf0a8489e65012d603 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 15:06:40 -0400 Subject: [PATCH 10/23] Fix some issues. 
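Among the fixes in the diff below, utils.weight_dtype() now returns None when no keys match the prefix, and the checkpoint loader only appends the detected dtype to the supported list when one was actually found. A standalone sketch of the hardened helper (the tensor contents are throwaway examples):

import torch

def weight_dtype(sd, prefix=""):
    # Count dtypes under the prefix and return the most common one, or None if nothing matches.
    dtypes = {}
    for k, w in sd.items():
        if k.startswith(prefix):
            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
    if len(dtypes) == 0:
        return None
    return max(dtypes, key=dtypes.get)

sd = {"model.diffusion_model.a": torch.zeros(1, dtype=torch.float8_e4m3fn),
      "model.diffusion_model.b": torch.zeros(1, dtype=torch.float8_e4m3fn),
      "first_stage_model.c": torch.zeros(1, dtype=torch.float16)}
print(weight_dtype(sd, "model.diffusion_model."))  # torch.float8_e4m3fn
print(weight_dtype(sd, "no_such_prefix."))         # None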
--- comfy/model_management.py | 3 +-- comfy/sd.py | 6 +++++- comfy/utils.py | 3 +++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2008229f2..bb4bcbb21 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -450,8 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)) - lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory)) + lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.5) if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary lowvram_model_memory = 0 diff --git a/comfy/sd.py b/comfy/sd.py index bf336c859..fac1a487f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -517,7 +517,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes) + unet_weight_dtype = list(model_config.supported_inference_dtypes) + if weight_dtype is not None: + unet_weight_dtype.append(weight_dtype) + + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) diff --git a/comfy/utils.py b/comfy/utils.py index d9fe36f91..06e09170a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -51,6 +51,9 @@ def weight_dtype(sd, prefix=""): w = sd[k] dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1 + if len(dtypes) == 0: + return None + return max(dtypes, key=dtypes.get) def state_dict_key_replace(state_dict, keys_to_replace): From 03c5018c98b9dd2654dc4942a0978ac53e755900 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 15:14:07 -0400 Subject: [PATCH 11/23] Lower lowvram memory to 1/3 of free memory. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index bb4bcbb21..c4402a8a7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -450,7 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.5) + lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.33) if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary lowvram_model_memory = 0 From 91be9c2867ef9ae5b255f038665649536c1e1b8b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 16:34:27 -0400 Subject: [PATCH 12/23] Tweak lowvram memory formula. 
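The revised budget in the diff below takes the largest of three terms: a 64 MB floor, whatever free VRAM remains after the caller's reservation, and up to 40% of free VRAM capped so the global inference reserve is kept. A rough standalone sketch with made-up byte figures (not measurements from a real card):

MB = 1024 * 1024

def lowvram_budget(current_free_mem, minimum_memory_required, minimum_inference_memory):
    # Same shape as the formula in the diff, with the memory amounts passed in explicitly.
    return max(64 * MB,
               current_free_mem - minimum_memory_required,
               min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory))

# e.g. 8 GiB free, 6 GiB requested as spare, 1 GiB inference reserve:
print(lowvram_budget(8192 * MB, 6144 * MB, 1024 * MB) / MB)  # ~3276.8 MB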
--- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c4402a8a7..b280b149d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -450,7 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.33) + lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory())) if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary lowvram_model_memory = 0 From f7a5107784cded39f92a4bb7553507575e78edbe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 16:55:38 -0400 Subject: [PATCH 13/23] Fix crash. --- comfy/model_management.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b280b149d..fb2747015 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -928,10 +928,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_intel_xpu(): return True - if device is None: - device = torch.device("cuda") - - props = torch.cuda.get_device_properties(device) + props = torch.cuda.get_device_properties("cuda") if props.major >= 8: return True From 56f3c660bf79769bbfa003c0e4152dfb50feadc5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Aug 2024 04:06:00 -0400 Subject: [PATCH 14/23] ModelSamplingFlux now takes a resolution and adjusts the shift with it. If you want to sample Flux dev exactly how the reference code does use the same resolution as your image in this node. 
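The shift in the diff below is a straight linear interpolation over the number of latent patches: base_shift applies around 256 patches and max_shift around 4096 patches, where the patch count is width*height/(8*8*2*2), assuming the usual 8x VAE downscale and 2x2 patchification. A standalone sketch of the same arithmetic with the node's default values:

def flux_shift(width, height, max_shift=1.15, base_shift=0.5):
    x1, x2 = 256, 4096  # patch counts the two shift values are anchored to
    mm = (max_shift - base_shift) / (x2 - x1)
    b = base_shift - mm * x1
    return (width * height / (8 * 8 * 2 * 2)) * mm + b

print(flux_shift(1024, 1024))  # ~1.15 (4096 patches -> max_shift)
print(flux_shift(512, 512))    # ~0.63 (1024 patches, closer to base_shift)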
--- comfy_extras/nodes_model_advanced.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index fef8a4873..918e6085a 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -2,6 +2,7 @@ import folder_paths import comfy.sd import comfy.model_sampling import comfy.latent_formats +import nodes import torch class LCM(comfy.model_sampling.EPS): @@ -174,7 +175,10 @@ class ModelSamplingFlux: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), }} RETURN_TYPES = ("MODEL",) @@ -182,9 +186,15 @@ class ModelSamplingFlux: CATEGORY = "advanced/model" - def patch(self, model, shift): + def patch(self, model, max_shift, base_shift, width, height): m = model.clone() + x1 = 256 + x2 = 4096 + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + shift = (width * height / (8 * 8 * 2 * 2)) * mm + b + sampling_base = comfy.model_sampling.ModelSamplingFlux sampling_type = comfy.model_sampling.CONST From 0a6b0081176c6233015ec00d004c534c088ddcb0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Aug 2024 10:03:33 -0400 Subject: [PATCH 15/23] Fix issue with some custom nodes. --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 94f4d333c..d19f5697a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -84,6 +84,7 @@ class BaseModel(torch.nn.Module): if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") + logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -94,7 +95,6 @@ class BaseModel(torch.nn.Module): self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) - logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): From 3b71f84b5051905be8f3311abeb39d725743d15b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Aug 2024 15:45:43 -0400 Subject: [PATCH 16/23] ONNX tracing fixes. 
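The main change in the diff below is a shared pad_to_patch_size() helper that rounds the last two dimensions up to the next multiple of the patch size, and swaps circular padding for reflect padding under torch.jit tracing/scripting. A trimmed copy of the padding arithmetic with a small shape check, so it runs standalone:

import torch

def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    # Pad bottom/right so height and width become multiples of the patch size.
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)

x = torch.zeros(1, 16, 127, 129)
print(pad_to_patch_size(x).shape)  # torch.Size([1, 16, 128, 130])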
--- comfy/ldm/aura/mmdit.py | 6 ++---- comfy/ldm/common_dit.py | 8 ++++++++ comfy/ldm/flux/model.py | 8 +++----- comfy/ldm/modules/diffusionmodules/mmdit.py | 5 ++--- comfy/model_detection.py | 2 +- 5 files changed, 16 insertions(+), 13 deletions(-) create mode 100644 comfy/ldm/common_dit.py diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 9956d3638..cd9a42185 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.ops +import comfy.ldm.common_dit def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -407,10 +408,7 @@ class MMDiT(nn.Module): def patchify(self, x): B, C, H, W = x.size() - pad_h = (self.patch_size - H % self.patch_size) % self.patch_size - pad_w = (self.patch_size - W % self.patch_size) % self.patch_size - - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular') + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) x = x.view( B, C, diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py new file mode 100644 index 000000000..990025521 --- /dev/null +++ b/comfy/ldm/common_dit.py @@ -0,0 +1,8 @@ +import torch + +def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): + if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): + padding_mode = "reflect" + pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] + pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] + return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index e7931c16d..db6cf3d22 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,6 +15,7 @@ from .layers import ( ) from einops import rearrange, repeat +import comfy.ldm.common_dit @dataclass class FluxParams: @@ -42,7 +43,7 @@ class Flux(nn.Module): self.dtype = dtype params = FluxParams(**kwargs) self.params = params - self.in_channels = params.in_channels + self.in_channels = params.in_channels * 2 * 2 self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: raise ValueError( @@ -125,10 +126,7 @@ class Flux(nn.Module): def forward(self, x, timestep, context, y, guidance, **kwargs): bs, c, h, w = x.shape patch_size = 2 - pad_h = (patch_size - h % 2) % patch_size - pad_w = (patch_size - w % 2) % patch_size - - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular') + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index ea1b5aa05..491a58a20 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -9,6 +9,7 @@ from .. import attention from einops import rearrange, repeat from .util import timestep_embedding import comfy.ops +import comfy.ldm.common_dit def default(x, y): if x is not None: @@ -111,9 +112,7 @@ class PatchEmbed(nn.Module): # f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." 
# ) if self.dynamic_img_pad: - pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] - pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode) + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dda9797b7..c47119686 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -131,7 +131,7 @@ def detect_unet_config(state_dict, key_prefix): if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux dit_config = {} dit_config["image_model"] = "flux" - dit_config["in_channels"] = 64 + dit_config["in_channels"] = 16 dit_config["vec_in_dim"] = 768 dit_config["context_in_dim"] = 4096 dit_config["hidden_size"] = 3072 From ddb6a9f47cd2e680aa821f320d52e909f0a03fc3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Aug 2024 15:59:02 -0400 Subject: [PATCH 17/23] Set the step in EmptySD3LatentImage to 16. These models work better when the res is a multiple of 16. --- comfy_extras/nodes_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 0aafa2426..ae9b85981 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -27,8 +27,8 @@ class EmptySD3LatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" From 7afa985fbafc15b2b603a4428917f4a600560699 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sun, 4 Aug 2024 23:10:02 +0200 Subject: [PATCH 18/23] Correct spelling 'token_weight_pars_t5' to 'token_weight_pairs_t5' (#4200) --- comfy/text_encoders/flux.py | 4 ++-- comfy/text_encoders/sd3_clip.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 0590741bb..ee26f560d 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -52,9 +52,9 @@ class FluxClipModel(torch.nn.Module): def encode_token_weights(self, token_weight_pairs): token_weight_pairs_l = token_weight_pairs["l"] - token_weight_pars_t5 = token_weight_pairs["t5xxl"] + token_weight_pairs_t5 = token_weight_pairs["t5xxl"] - t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) + t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) return t5_out, l_pooled diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 143d884cb..549c068e9 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -81,7 +81,7 @@ class SD3ClipModel(torch.nn.Module): def encode_token_weights(self, token_weight_pairs): token_weight_pairs_l = token_weight_pairs["l"] token_weight_pairs_g = 
token_weight_pairs["g"] - token_weight_pars_t5 = token_weight_pairs["t5xxl"] + token_weight_pairs_t5 = token_weight_pairs["t5xxl"] lg_out = None pooled = None out = None @@ -108,7 +108,7 @@ class SD3ClipModel(torch.nn.Module): pooled = torch.cat((l_pooled, g_pooled), dim=-1) if self.t5xxl is not None: - t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5) + t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5) if lg_out is not None: out = torch.cat([lg_out, t5_out], dim=-2) else: From 78e133d0415784924cd2674e2ee48f3eeca8a2aa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Aug 2024 21:59:42 -0400 Subject: [PATCH 19/23] Support simple diffusers Flux loras. --- comfy/lora.py | 8 ++++++++ comfy/utils.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index fdc128c09..04e8861c9 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -288,4 +288,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format + if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux + diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format + key_map[key_lora] = to + return key_map diff --git a/comfy/utils.py b/comfy/utils.py index 06e09170a..ec7d36607 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -415,6 +415,59 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""): return key_map +def flux_to_diffusers(mmdit_config, output_prefix=""): + n_double_layers = mmdit_config.get("depth", 0) + n_single_layers = mmdit_config.get("depth_single_blocks", 0) + hidden_size = mmdit_config.get("hidden_size", 0) + + key_map = {} + for index in range(n_double_layers): + prefix_from = "transformer_blocks.{}".format(index) + prefix_to = "{}double_blocks.{}".format(output_prefix, index) + + for end in ("weight", "bias"): + k = "{}.attn.".format(prefix_from) + qkv = "{}.img_attn.qkv.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = {"attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + } + + for k in block_map: + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + + for index in range(n_single_layers): + prefix_from = "single_transformer_blocks.{}".format(index) + prefix_to = "{}single_blocks.{}".format(output_prefix, index) + + for end in ("weight", "bias"): + k = "{}.attn.".format(prefix_from) + qkv = "{}.linear1.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size)) + + block_map = {#TODO + } + + for k in block_map: + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + + MAP_BASIC = { #TODO + } + + for k in 
MAP_BASIC: + if len(k) > 2: + key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) + else: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) From a178e25912b01abf436eba1cfaab316ba02d272d Mon Sep 17 00:00:00 2001 From: a-One-Fan <100067309+a-One-Fan@users.noreply.github.com> Date: Mon, 5 Aug 2024 08:26:20 +0300 Subject: [PATCH 20/23] Fix Flux FP64 math on XPU (#4210) --- comfy/ldm/flux/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 88c2b6bb4..136ce2aa8 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -14,7 +14,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 - if comfy.model_management.is_device_mps(pos.device): + if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu(): device = torch.device("cpu") else: device = pos.device From 33e5203a2a7bc90dc4c6577ed645456abc530155 Mon Sep 17 00:00:00 2001 From: bymyself Date: Mon, 5 Aug 2024 09:25:28 -0700 Subject: [PATCH 21/23] Don't cache index.html (#4211) --- server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 23ca2fd33..1c07f9781 100644 --- a/server.py +++ b/server.py @@ -127,7 +127,11 @@ class PromptServer(): @routes.get("/") async def get_root(request): - return web.FileResponse(os.path.join(self.web_root, "index.html")) + response = web.FileResponse(os.path.join(self.web_root, "index.html")) + response.headers['Cache-Control'] = 'no-cache' + response.headers["Pragma"] = "no-cache" + response.headers["Expires"] = "0" + return response @routes.get("/embeddings") def get_embeddings(self): From e545a636baae052000abb1250a69e1cac32b2bae Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 5 Aug 2024 12:31:12 -0400 Subject: [PATCH 22/23] This probably doesn't work anymore. --- README.md | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/README.md b/README.md index a542ed4d6..d5ded7297 100644 --- a/README.md +++ b/README.md @@ -165,20 +165,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` -### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? - -You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it: - -```source path_to_other_sd_gui/venv/bin/activate``` - -or on Windows: - -With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"``` - -With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"``` - -And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI. - # Running ```python main.py``` From 8edbcf520900112d4e11f510ba33949503b58f51 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 5 Aug 2024 16:24:04 -0400 Subject: [PATCH 23/23] Improve performance on some lowend GPUs. 
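The change in the diff below raises the share of total VRAM considered usable for model weights from 80% to 88%. A back-of-the-envelope illustration, assuming an 8 GiB card and a hypothetical 1 GiB minimum_inference_memory() reservation (both figures are illustrative, not taken from the code):

GiB = 1024 ** 3
total, reserve = 8 * GiB, 1 * GiB
old_budget = total * 0.8 - reserve
new_budget = total * 0.88 - reserve
print((new_budget - old_budget) / (1024 ** 2))  # ~655 MiB more headroom before the fp8 weight fallback triggers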
--- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index fb2747015..3d9ed5251 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -528,7 +528,7 @@ def unet_inital_load_device(parameters, dtype): return cpu_dev def maximum_vram_for_weights(device=None): - return (get_total_memory(device) * 0.8 - minimum_inference_memory()) + return (get_total_memory(device) * 0.88 - minimum_inference_memory()) def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if args.bf16_unet: