Compare commits

...

3 Commits

Author SHA1 Message Date
Dave Lage
3ac985d5c5
Merge 11a182ef10 into c4a14df9a3 2026-01-21 09:10:06 +08:00
rockerBOO
11a182ef10
Update half_eps 2025-03-17 12:09:01 -04:00
rockerBOO
8a50599bd8
Fix RenormCFG node for batches 2025-03-17 01:03:13 -04:00

View File

@ -40,12 +40,13 @@ class RenormCFG(io.ComfyNode):
ori_pos_norm = torch.linalg.vector_norm(cond_eps
, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
)
max_new_norm = ori_pos_norm * float(renorm_cfg)
new_pos_norm = torch.linalg.vector_norm(
max_new_norms = ori_pos_norm * float(renorm_cfg)
new_pos_norms = torch.linalg.vector_norm(
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
if new_pos_norm >= max_new_norm:
half_eps = half_eps * (max_new_norm / new_pos_norm)
for i, (max_new_norm, new_pos_norm) in enumerate(zip(max_new_norms, new_pos_norms)):
if new_pos_norm >= max_new_norm:
half_eps[i] = half_eps[i] * (max_new_norm / new_pos_norm)
else:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]