Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data 2025-09-21 10:45:28 +09:00
commit 4ea946778b
7 changed files with 60 additions and 22 deletions

View File

@ -365,6 +365,7 @@ class fp8_ops(manual_cast):
return None return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
if not self.training:
try: try:
out = fp8_linear(self, input) out = fp8_linear(self, input)
if out is not None: if out is not None:

View File

@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0): def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0] out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel() in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat1, 0.1) torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0) torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01) torch.nn.init.normal_(mat4, 0.01)
return LohaDiff( return LohaDiff(

View File

@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase):
in_dim = weight.shape[1:].numel() in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank) out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank) in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0) torch.nn.init.constant_(mat1, 0.0)
return LokrDiff( return LokrDiff(

View File

@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0): def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0] out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel() in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0) torch.nn.init.constant_(mat2, 0.0)
return LoraDiff( return LoraDiff(

View File

@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0): def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0] out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank) block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
return OFTDiff( return OFTDiff(
(block, None, alpha, None) (block, None, alpha, None)
) )

View File

@ -5,19 +5,30 @@ import torch
class DifferentialDiffusion(): class DifferentialDiffusion():
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ), return {
}} "required": {
"model": ("MODEL", ),
},
"optional": {
"strength": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
}),
}
}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "apply" FUNCTION = "apply"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
INIT = False INIT = False
def apply(self, model): def apply(self, model, strength=1.0):
model = model.clone() model = model.clone()
model.set_model_denoise_mask_function(self.forward) model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
return (model,) return (model, )
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"] model = extra_options["model"]
step_sigmas = extra_options["sigmas"] step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min sigma_to = model.inner_model.model_sampling.sigma_min
@ -31,7 +42,15 @@ class DifferentialDiffusion():
threshold = (current_ts - ts_to) / (ts_from - ts_to) threshold = (current_ts - ts_to) / (ts_from - ts_to)
return (denoise_mask >= threshold).to(denoise_mask.dtype) # Generate the binary mask based on the threshold
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
# Blend binary mask with the original denoise_mask using strength
if strength and strength < 1:
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
return blended_mask
else:
return binary_mask
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

View File

@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
return new_dict return new_dict
def process_cond_list(d, prefix=""):
if hasattr(d, "__iter__") and not hasattr(d, "items"):
for index, item in enumerate(d):
process_cond_list(item, f"{prefix}.{index}")
return d
elif hasattr(d, "items"):
for k, v in list(d.items()):
if isinstance(v, dict):
process_cond_list(v, f"{prefix}.{k}")
elif isinstance(v, torch.Tensor):
d[k] = v.clone()
elif isinstance(v, (list, tuple)):
for index, item in enumerate(v):
process_cond_list(item, f"{prefix}.{k}.{index}")
return d
class TrainSampler(comfy.samplers.Sampler): class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn self.loss_fn = loss_fn
@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
self.training_dtype = training_dtype self.training_dtype = training_dtype
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
model_wrap.conds = process_cond_list(model_wrap.conds)
cond = model_wrap.conds["positive"] cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0) dataset_size = sigmas.size(0)
torch.cuda.empty_cache() torch.cuda.empty_cache()