From 858d51f91a6039387d749c107d15ee9690cb7f1b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Aug 2024 04:08:54 -0400
Subject: [PATCH 1/4] Fix VAEDecode -> Preview not being executed first.

---
 comfy_execution/graph.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 303ccae31..2c7e8f8ae 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -176,19 +176,35 @@ class ExecutionList(TopologicalSort):
                 "current_inputs": []
             }
             return None, error_details, ex
-        next_node = available[0]
+
+        self.staged_node_id = self.ux_friendly_pick_node(available)
+        return self.staged_node_id, None, None
+
+    def ux_friendly_pick_node(self, node_list):
         # If an output node is available, do that first.
         # Technically this has no effect on the overall length of execution, but it feels better as a user
         # for a PreviewImage to display a result as soon as it can
         # Some other heuristics could probably be used here to improve the UX further.
-        for node_id in available:
+        def is_output(node_id):
             class_type = self.dynprompt.get_node(node_id)["class_type"]
             class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
             if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
-                next_node = node_id
-                break
-        self.staged_node_id = next_node
-        return self.staged_node_id, None, None
+                return True
+            return False
+
+        for node_id in node_list:
+            if is_output(node_id):
+                return node_id
+
+        # This should handle the VAEDecode -> preview case
+        for node_id in node_list:
+            for blocked_node_id in self.blocking[node_id]:
+                if is_output(blocked_node_id):
+                    return node_id
+
+        # Do we want to look deeper?
+
+        return node_list[0]
 
     def unstage_node_execution(self):
         assert self.staged_node_id is not None

From fca42836f26d06b14a3149c0c94fc1c7f264f633 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Aug 2024 10:15:13 -0400
Subject: [PATCH 2/4] Add model_options for text encoder.
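
This threads a model_options dict from the CLIP wrapper down into every
text encoder model. Two keys are read so far: "dtype", which overrides
the dtype picked by model_management.text_encoder_dtype(), and
"custom_operations", which replaces the default comfy.ops.manual_cast
operations. load_checkpoint_guess_config() and
load_state_dict_guess_config() take a separate te_model_options argument
so the text encoder can be configured independently of the unet.

A minimal sketch of the intended usage (the checkpoint filename is
hypothetical):

    import torch
    import comfy.sd

    # Force the text encoder to fp16 instead of the autodetected dtype.
    clip = comfy.sd.load_clip(["clip_l.safetensors"],
                              clip_type=comfy.sd.CLIPType.STABLE_DIFFUSION,
                              model_options={"dtype": torch.float16})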
---
 comfy/sd.py                     | 21 +++++++++++++--------
 comfy/sd1_clip.py               | 12 ++++++++----
 comfy/sdxl_clip.py              | 22 +++++++++++-----------
 comfy/text_encoders/aura_t5.py  |  8 ++++----
 comfy/text_encoders/flux.py     | 14 +++++++-------
 comfy/text_encoders/hydit.py    | 14 +++++++-------
 comfy/text_encoders/sa_t5.py    |  8 ++++----
 comfy/text_encoders/sd2_clip.py |  8 ++++----
 comfy/text_encoders/sd3_clip.py | 16 ++++++++--------
 9 files changed, 66 insertions(+), 57 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index edd0b51d8..cae7812e3 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
         if no_init:
             return
         params = target.params.copy()
@@ -71,9 +71,14 @@ class CLIP:
 
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        dtype = model_management.text_encoder_dtype(load_device)
+        dtype = model_options.get("dtype", None)
+        if dtype is None:
+            dtype = model_management.text_encoder_dtype(load_device)
+
         params['dtype'] = dtype
         params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
+        params['model_options'] = model_options
+
         self.cond_stage_model = clip(**(params))
 
         for dt in self.cond_stage_model.dtypes:
@@ -394,7 +399,7 @@ class CLIPType(Enum):
     HUNYUAN_DIT = 5
     FLUX = 6
 
-def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@@ -464,7 +469,7 @@
     for c in clip_data:
         parameters += comfy.utils.calculate_parameters(c)
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
    for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -506,14 +511,14 @@
     return (model, clip, vae)
 
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
     sd = comfy.utils.load_torch_file(ckpt_path)
-    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
     if out is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
     return out
 
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
     clip = None
     clipvision = None
     vae = None
@@ -563,7 +568,7 @@
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
                 parameters = comfy.utils.calculate_parameters(clip_sd)
-                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
+                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
                 m, u = clip.load_sd(clip_sd, full_model=True)
                 if len(m) > 0:
                     m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index e65cab285..dc8413b7b 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -84,7 +84,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
 
@@ -94,7 +94,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)
 
-        self.operations = comfy.ops.manual_cast
+        operations = model_options.get("custom_operations", None)
+        if operations is None:
+            operations = comfy.ops.manual_cast
+
+        self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
         self.num_layers = self.transformer.num_layers
 
@@ -553,7 +557,7 @@ class SD1Tokenizer:
         return {}
 
 class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
         super().__init__()
 
         if name is not None:
@@ -563,7 +567,7 @@ class SD1ClipModel(torch.nn.Module):
             self.clip_name = clip_name
             self.clip = "clip_{}".format(self.clip_name)
 
-        setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
+        setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
 
         self.dtypes = set()
         if dtype is not None:
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 6e6b87d62..860900ccd 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -3,14 +3,14 @@ import torch
 import os
 
 class SDXLClipG(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
-                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, model_options=model_options)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
@@ -38,10 +38,10 @@ class SDXLTokenizer:
         return {}
 
 class SDXLClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
-        self.clip_g = SDXLClipG(device=device, dtype=dtype)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set([dtype])
 
     def set_clip_options(self, options):
@@ -66,8 +66,8 @@ class SDXLClipModel(torch.nn.Module):
             return self.clip_l.load_sd(sd)
 
 class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None):
-        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
 
 
 class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
@@ -79,14 +79,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
 
 class StableCascadeClipG(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
-                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
+                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, model_options=model_options)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
 
 class StableCascadeClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None):
-        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)
diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py
index 8500c7b70..e9ad45a7f 100644
--- a/comfy/text_encoders/aura_t5.py
+++ b/comfy/text_encoders/aura_t5.py
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
 import os
 
 class PT5XlModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
 
 class PT5XlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
 
 class AuraT5Model(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
+    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+        super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index ee26f560d..91d1f249d 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -6,9 +6,9 @@ import torch
 import os
 
 class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -35,11 +35,11 @@ class FluxTokenizer:
 
 
 class FluxClipModel(torch.nn.Module):
-    def __init__(self, dtype_t5=None, device="cpu", dtype=None):
+    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])
 
     def set_clip_options(self, options):
@@ -66,6 +66,6 @@ class FluxClipModel(torch.nn.Module):
 
 def flux_clip(dtype_t5=None):
     class FluxClipModel_(FluxClipModel):
-        def __init__(self, device="cpu", dtype=None):
-            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_
diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py
index 9dfa288b1..7cb790f45 100644
--- a/comfy/text_encoders/hydit.py
+++ b/comfy/text_encoders/hydit.py
@@ -7,9 +7,9 @@ import os
 import torch
 
 class HyditBertModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
 
 class HyditBertTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
 
 
 class MT5XLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
 
 class MT5XLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -50,10 +50,10 @@ class HyditTokenizer:
         return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
 
 class HyditModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        self.hydit_clip = HyditBertModel(dtype=dtype)
-        self.mt5xl = MT5XLModel(dtype=dtype)
+        self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
+        self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
 
         self.dtypes = set()
         if dtype is not None:
diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py
index 189f8c181..7778ce47a 100644
--- a/comfy/text_encoders/sa_t5.py
+++ b/comfy/text_encoders/sa_t5.py
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
 import os
 
 class T5BaseModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
 
 class T5BaseTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
 
 class SAT5Model(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
+    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+        super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)
diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py
index 8ea8e1cd0..0c98cd853 100644
--- a/comfy/text_encoders/sd2_clip.py
+++ b/comfy/text_encoders/sd2_clip.py
@@ -2,13 +2,13 @@ from comfy import sd1_clip
 import os
 
 class SD2ClipHModel(sd1_clip.SDClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, model_options=model_options)
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
@@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
 
 class SD2ClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
+    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+        super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index 549c068e9..83e8fa1f3 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -8,9 +8,9 @@ import comfy.model_management
 import logging
 
 class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -38,24 +38,24 @@ class SD3Tokenizer:
         return {}
 
 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         self.dtypes = set()
         if clip_l:
-            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
+            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None
 
         if clip_g:
-            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
+            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_g = None
 
         if t5:
             dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
             self.dtypes.add(dtype_t5)
         else:
             self.t5xxl = None
@@ -132,6 +132,6 @@ class SD3ClipModel(torch.nn.Module):
 
 def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
     class SD3ClipModel_(SD3ClipModel):
-        def __init__(self, device="cpu", dtype=None):
-            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return SD3ClipModel_

From 14af129c5509d10504113a1520c45b0ebcf81f14 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Aug 2024 11:36:10 -0400
Subject: [PATCH 3/4] Improve execution UX.

Some branches with VAELoader -> VAEDecode -> Preview were being executed
last. With this change they will be executed earlier.
---
 comfy_execution/graph.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 2c7e8f8ae..b53e10f3f 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -202,8 +202,14 @@ class ExecutionList(TopologicalSort):
                 if is_output(blocked_node_id):
                     return node_id
 
-        # Do we want to look deeper?
+        # This should handle the VAELoader -> VAEDecode -> preview case
+        for node_id in node_list:
+            for blocked_node_id in self.blocking[node_id]:
+                for blocked_node_id1 in self.blocking[blocked_node_id]:
+                    if is_output(blocked_node_id1):
+                        return node_id
 
+        # TODO: this function should be improved
         return node_list[0]
 
     def unstage_node_execution(self):

From bb222ceddb232aafafa99cd4dec38b3719c29d7d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Aug 2024 14:07:19 -0400
Subject: [PATCH 4/4] Fix loras having a weak effect when applied on fp8.
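
When the model weights are stored in fp8, patching computes the lora
delta in fp32 and then casts back with a plain round-to-nearest .to(),
which can erase small deltas entirely: in torch.float8_e4m3fn the
spacing between representable values in [1.0, 2.0) is 2**-3 = 0.125, so
a weight of 1.0 plus a lora delta of 0.01 rounds straight back to 1.0.
The stochastic rounding added in comfy/float.py instead rounds up with
probability proportional to the remainder (0.01 / 0.125 = 0.08 in this
example), so the lora keeps its effect in expectation. This commit also
reorders weight patching and lowvram loading to process the largest
tensors first.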
---
 comfy/float.py         | 51 ++++++++++++++++++++++++++++++++++++++++++
 comfy/model_patcher.py | 46 ++++++++++++++++++++++++------------------
 2 files changed, 82 insertions(+), 15 deletions(-)
 create mode 100644 comfy/float.py

diff --git a/comfy/float.py b/comfy/float.py
new file mode 100644
index 000000000..9822ae482
--- /dev/null
+++ b/comfy/float.py
@@ -0,0 +1,51 @@
+import torch
+
+# Not 100% sure about this
+def manual_stochastic_round_to_float8(x, dtype):
+    if dtype == torch.float8_e4m3fn:
+        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
+    elif dtype == torch.float8_e5m2:
+        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
+    else:
+        raise ValueError("Unsupported dtype")
+
+    sign = torch.sign(x)
+    abs_x = x.abs()
+
+    # Combine exponent calculation and clamping
+    exponent = torch.clamp(
+        torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
+        0, 2**EXPONENT_BITS - 1
+    )
+
+    # Combine mantissa calculation and rounding
+    mantissa = abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0
+    mantissa_scaled = mantissa * (2**MANTISSA_BITS)
+    mantissa_floor = mantissa_scaled.floor()
+    mantissa = torch.where(
+        torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
+        (mantissa_floor + 1) / (2**MANTISSA_BITS),
+        mantissa_floor / (2**MANTISSA_BITS)
+    )
+
+    # Combine final result calculation
+    result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa)
+
+    # Handle zero case
+    result = torch.where(abs_x == 0, torch.zeros_like(result), result)
+
+    return result.to(dtype=dtype)
+
+
+
+def stochastic_rounding(value, dtype):
+    if dtype == torch.float32:
+        return value.to(dtype=torch.float32)
+    if dtype == torch.float16:
+        return value.to(dtype=torch.float16)
+    if dtype == torch.bfloat16:
+        return value.to(dtype=torch.bfloat16)
+    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        return manual_stochastic_round_to_float8(value, dtype)
+
+    return value.to(dtype=dtype)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 837c64b05..c6fb0eff1 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -22,8 +22,10 @@ import inspect
 import logging
 import uuid
 import collections
+import math
 
 import comfy.utils
+import comfy.float
 import comfy.model_management
 from comfy.types import UnetWrapperFunction
 
@@ -327,7 +329,8 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
-        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key)
+        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
         if inplace_update:
             comfy.utils.copy_to_param(self.model, key, out_weight)
         else:
@@ -341,12 +344,16 @@ class ModelPatcher:
 
         if patch_weights:
             model_sd = self.model_state_dict()
+            keys_sort = []
             for key in self.patches:
                 if key not in model_sd:
                     logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                     continue
+                keys_sort.append((math.prod(model_sd[key].shape), key))
 
-                self.patch_weight_to_device(key, device_to)
+            keys_sort.sort(reverse=True)
+            for ks in keys_sort:
+                self.patch_weight_to_device(ks[1], device_to)
 
         if device_to is not None:
             self.model.to(device_to)
@@ -359,6 +366,7 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        load_completely = []
        for n, m in self.model.named_modules():
             lowvram_weight = False
 
@@ -395,20 +403,28 @@ class ModelPatcher:
                         wipe_lowvram_weight(m)
 
                 if hasattr(m, "weight"):
-                    mem_counter += comfy.model_management.module_size(m)
-                    param = list(m.parameters())
-                    if len(param) > 0:
-                        weight = param[0]
-                        if weight.device == device_to:
-                            continue
+                    mem_used = comfy.model_management.module_size(m)
+                    mem_counter += mem_used
+                    load_completely.append((mem_used, n, m))
 
-                        weight_to = None
-                        if full_load:#TODO
-                            weight_to = device_to
-                        self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
-                        self.patch_weight_to_device(bias_key, device_to=weight_to)
-                        m.to(device_to)
-                        logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+        load_completely.sort(reverse=True)
+        for x in load_completely:
+            n = x[1]
+            m = x[2]
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+            param = list(m.parameters())
+            if len(param) > 0:
+                weight = param[0]
+                if weight.device == device_to:
+                    continue
+
+                self.patch_weight_to_device(weight_key, device_to=device_to)
+                self.patch_weight_to_device(bias_key, device_to=device_to)
+                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+
+        for x in load_completely:
+            x[2].to(device_to)
 
         if lowvram_counter > 0:
             logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
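
A quick way to sanity check the stochastic rounding added in PATCH 4/4
(an illustrative sketch that only uses the helper introduced above; it
is not part of the patch): quantize a constant tensor both ways and
compare the means. 1.3 sits between the float8_e4m3fn neighbors 1.25
and 1.375, so round-to-nearest always picks 1.25, while the stochastic
mean should land close to 1.3.

    import torch
    import comfy.float

    x = torch.full((100000,), 1.3)
    det = x.to(torch.float8_e4m3fn).to(torch.float32).mean()
    sto = comfy.float.stochastic_rounding(x, torch.float8_e4m3fn).to(torch.float32).mean()
    print(det.item(), sto.item())  # expect ~1.25 vs ~1.3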