mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 00:00:26 +08:00
neg condition was passed as context
This commit is contained in:
parent
bc4fd2cd11
commit
c78bfda132
@ -1353,7 +1353,8 @@ class NaDiT(nn.Module):
|
||||
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
|
||||
txt, txt_shape = flatten([pos_cond, neg_cond])
|
||||
except:
|
||||
txt, txt_shape = flatten([context.squeeze(0)])
|
||||
context = self.positive_conditioning
|
||||
txt, txt_shape = flatten([context])
|
||||
|
||||
vid, vid_shape = flatten(x)
|
||||
cond_latent, _ = flatten(conditions)
|
||||
|
||||
@ -1628,20 +1628,20 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
@ -1651,27 +1651,27 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
@ -1689,28 +1689,28 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(0.95047)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(1.08883)
|
||||
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
|
||||
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
@ -1719,7 +1719,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
||||
rgb_linear * 12.92
|
||||
)
|
||||
del mask, rgb_linear
|
||||
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
@ -1732,25 +1732,25 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
||||
rgb / 12.92
|
||||
)
|
||||
del mask
|
||||
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
|
||||
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(0.95047) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(1.08883) # Z
|
||||
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
@ -1760,13 +1760,13 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
del f_xyz
|
||||
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
|
||||
def lab_color_transfer(
|
||||
@ -1775,7 +1775,7 @@ def lab_color_transfer(
|
||||
luminance_weight: float = 0.8
|
||||
) -> Tensor:
|
||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
@ -1783,45 +1783,45 @@ def lab_color_transfer(
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
|
||||
device = content_feat.device
|
||||
|
||||
|
||||
def ensure_float32_precision(c):
|
||||
orig_dtype = c.dtype
|
||||
c = c.float()
|
||||
return c, orig_dtype
|
||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||
style_feat, _ = ensure_float32_precision(style_feat)
|
||||
|
||||
|
||||
rgb_to_xyz_matrix = torch.tensor([
|
||||
[0.4124564, 0.3575761, 0.1804375],
|
||||
[0.2126729, 0.7151522, 0.0721750],
|
||||
[0.0193339, 0.1191920, 0.9503041]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
|
||||
xyz_to_rgb_matrix = torch.tensor([
|
||||
[ 3.2404542, -1.5371385, -0.4985314],
|
||||
[-0.9692660, 1.8760108, 0.0415560],
|
||||
[ 0.0556434, -0.2040259, 1.0572252]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
|
||||
epsilon = 6.0 / 29.0
|
||||
kappa = (29.0 / 3.0) ** 3
|
||||
|
||||
|
||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
|
||||
|
||||
# Convert to LAB color space
|
||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del content_feat
|
||||
|
||||
|
||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del style_feat, rgb_to_xyz_matrix
|
||||
|
||||
|
||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||
|
||||
|
||||
# Handle luminance with weighted blending
|
||||
if luminance_weight < 1.0:
|
||||
# Partially match luminance for better overall color accuracy
|
||||
@ -1832,23 +1832,23 @@ def lab_color_transfer(
|
||||
else:
|
||||
# Fully preserve content luminance
|
||||
result_L = content_lab[:, 0]
|
||||
|
||||
|
||||
del content_lab, style_lab
|
||||
|
||||
|
||||
# Reconstruct LAB with corrected channels
|
||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||
del result_L, matched_a, matched_b
|
||||
|
||||
|
||||
# Convert back to RGB
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
del result_lab, xyz_to_rgb_matrix
|
||||
|
||||
|
||||
# Convert back to [-1, 1] range (in-place)
|
||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||
del result_rgb
|
||||
|
||||
|
||||
result = result.to(original_dtype)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
@ -2140,10 +2140,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
|
||||
input = input.to(x.device)
|
||||
o_h, o_w = self.img_dims
|
||||
x = x[..., :o_h, :o_w]
|
||||
input = input[..., :o_h, :o_w ]
|
||||
x = lab_color_transfer(x, input)
|
||||
|
||||
x = x.unsqueeze(0)
|
||||
x = x[..., :o_h, :o_w]
|
||||
x = rearrange(x, "b t c h w -> b c t h w")
|
||||
|
||||
# ensure even dims for save video
|
||||
@ -2154,8 +2155,12 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
|
||||
return x
|
||||
|
||||
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
|
||||
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"):
|
||||
set_norm_limit(norm_max_mem)
|
||||
for m in self.modules():
|
||||
if isinstance(m, InflatedCausalConv3d):
|
||||
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, InflatedCausalConv3d):
|
||||
module.set_memory_device(memory_device)
|
||||
|
||||
@ -1307,9 +1307,6 @@ class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
latent_format = comfy.latent_formats.SeedVR2
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user