neg condition was passed as context

This commit is contained in:
Yousef R. Gamaleldin 2026-01-16 22:17:33 +02:00
parent bc4fd2cd11
commit c78bfda132
3 changed files with 49 additions and 46 deletions

View File

@ -1353,7 +1353,8 @@ class NaDiT(nn.Module):
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
txt, txt_shape = flatten([pos_cond, neg_cond])
except:
txt, txt_shape = flatten([context.squeeze(0)])
context = self.positive_conditioning
txt, txt_shape = flatten([context])
vid, vid_shape = flatten(x)
cond_latent, _ = flatten(conditions)

View File

@ -1628,20 +1628,20 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
@ -1651,27 +1651,27 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
@ -1689,28 +1689,28 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(0.95047)
# y *= 1.00000 # (no-op, skip)
z.mul_(1.08883)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
@ -1719,7 +1719,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
@ -1732,25 +1732,25 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
rgb / 12.92
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(0.95047) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(1.08883) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
@ -1760,13 +1760,13 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
@ -1775,7 +1775,7 @@ def lab_color_transfer(
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
@ -1783,45 +1783,45 @@ def lab_color_transfer(
mode='bilinear',
align_corners=False
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = 6.0 / 29.0
kappa = (29.0 / 3.0) ** 3
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
@ -1832,23 +1832,23 @@ def lab_color_transfer(
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
class VideoAutoencoderKL(nn.Module):
@ -2140,10 +2140,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
input = input.to(x.device)
o_h, o_w = self.img_dims
x = x[..., :o_h, :o_w]
input = input[..., :o_h, :o_w ]
x = lab_color_transfer(x, input)
x = x.unsqueeze(0)
x = x[..., :o_h, :o_w]
x = rearrange(x, "b t c h w -> b c t h w")
# ensure even dims for save video
@ -2154,8 +2155,12 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
return x
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"):
set_norm_limit(norm_max_mem)
for m in self.modules():
if isinstance(m, InflatedCausalConv3d):
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
for module in self.modules():
if isinstance(module, InflatedCausalConv3d):
module.set_memory_device(memory_device)

View File

@ -1307,9 +1307,6 @@ class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
sampling_settings = {
"shift": 1.0,
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]