diff --git a/comfy/ops.py b/comfy/ops.py index 55e958adb..9d7dedd37 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -365,12 +365,13 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - try: - out = fp8_linear(self, input) - if out is not None: - return out - except Exception as e: - logging.info("Exception during fp8 op: {}".format(e)) + if not self.training: + try: + out = fp8_linear(self, input) + if out is not None: + return out + except Exception as e: + logging.info("Exception during fp8 op: {}".format(e)) weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 55c97a3af..0abb2d403 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat1, 0.1) torch.nn.init.constant_(mat2, 0.0) - mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) return LohaDiff( diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 563c835f5..9b2aff2d7 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase): in_dim = weight.shape[1:].numel() out1, out2 = factorization(out_dim, rank) in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) return LokrDiff( diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 47aa17d13..4db004e50 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) return LoraDiff( diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 9d4982083..c0aab9635 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) + block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) return OFTDiff( (block, None, alpha, None) ) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 98dbbf102..255ac420d 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -5,19 +5,30 @@ import torch class DifferentialDiffusion(): @classmethod def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - }} + return { + "required": { + "model": ("MODEL", ), + }, + "optional": { + "strength": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + }), + } + } RETURN_TYPES = ("MODEL",) FUNCTION = "apply" CATEGORY = "_for_testing" INIT = False - def apply(self, model): + def apply(self, model, strength=1.0): model = model.clone() - model.set_model_denoise_mask_function(self.forward) - return (model,) + model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) + return (model, ) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -31,7 +42,15 @@ class DifferentialDiffusion(): threshold = (current_ts - ts_to) / (ts_from - ts_to) - return (denoise_mask >= threshold).to(denoise_mask.dtype) + # Generate the binary mask based on the threshold + binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype) + + # Blend binary mask with the original denoise_mask using strength + if strength and strength < 1: + blended_mask = strength * binary_mask + (1 - strength) * denoise_mask + return blended_mask + else: + return binary_mask NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index c3aaaee9b..9e6ec6780 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None): return new_dict +def process_cond_list(d, prefix=""): + if hasattr(d, "__iter__") and not hasattr(d, "items"): + for index, item in enumerate(d): + process_cond_list(item, f"{prefix}.{index}") + return d + elif hasattr(d, "items"): + for k, v in list(d.items()): + if isinstance(v, dict): + process_cond_list(v, f"{prefix}.{k}") + elif isinstance(v, torch.Tensor): + d[k] = v.clone() + elif isinstance(v, (list, tuple)): + for index, item in enumerate(v): + process_cond_list(item, f"{prefix}.{k}.{index}") + return d + + class TrainSampler(comfy.samplers.Sampler): def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): self.loss_fn = loss_fn @@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler): self.training_dtype = training_dtype def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache()