diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index a653e02fc..1509de2f8 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,6 +387,9 @@ class Kandinsky5(nn.Module): return self.out_layer(visual_embed, time_embed) def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + original_dims = x.ndim + if original_dims == 4: + x = x.unsqueeze(2) bs, c, t_len, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) @@ -397,7 +400,10 @@ class Kandinsky5(nn.Module): freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) - return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + if original_dims == 4: + out = out.squeeze(2) + return out def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( diff --git a/comfy/model_management.py b/comfy/model_management.py index eb63312d4..58ae13f28 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1493,6 +1493,20 @@ def extended_fp16_support(): return True +LORA_COMPUTE_DTYPES = {} +def lora_compute_dtype(device): + dtype = LORA_COMPUTE_DTYPES.get(device, None) + if dtype is not None: + return dtype + + if should_use_fp16(device): + dtype = torch.float16 + else: + dtype = torch.float32 + + LORA_COMPUTE_DTYPES[device] = dtype + return dtype + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 215784874..5b1ccb824 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -614,10 +614,11 @@ class ModelPatcher: if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: - temp_weight = weight.to(torch.float32, copy=True) + temp_weight = weight.to(temp_dtype, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) @@ -761,6 +762,8 @@ class ModelPatcher: key = "{}.{}".format(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) + if comfy.model_management.is_device_cuda(device_to): + torch.cuda.synchronize() logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py index 22f991c36..be086458c 100644 --- a/comfy/text_encoders/kandinsky5.py +++ b/comfy/text_encoders/kandinsky5.py @@ -24,10 +24,10 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -56,12 +56,12 @@ class Kandinsky5TEModel(QwenImageTEModel): else: return super().load_sd(sd) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Kandinsky5TEModel_(Kandinsky5TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options)