mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 08:10:21 +08:00
neg condition was passed as context
This commit is contained in:
parent
bc4fd2cd11
commit
c78bfda132
@ -1353,7 +1353,8 @@ class NaDiT(nn.Module):
|
|||||||
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
|
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
|
||||||
txt, txt_shape = flatten([pos_cond, neg_cond])
|
txt, txt_shape = flatten([pos_cond, neg_cond])
|
||||||
except:
|
except:
|
||||||
txt, txt_shape = flatten([context.squeeze(0)])
|
context = self.positive_conditioning
|
||||||
|
txt, txt_shape = flatten([context])
|
||||||
|
|
||||||
vid, vid_shape = flatten(x)
|
vid, vid_shape = flatten(x)
|
||||||
cond_latent, _ = flatten(conditions)
|
cond_latent, _ = flatten(conditions)
|
||||||
|
|||||||
@ -1628,20 +1628,20 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
|||||||
|
|
||||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||||
original_shape = source.shape
|
original_shape = source.shape
|
||||||
|
|
||||||
# Flatten
|
# Flatten
|
||||||
source_flat = source.flatten()
|
source_flat = source.flatten()
|
||||||
reference_flat = reference.flatten()
|
reference_flat = reference.flatten()
|
||||||
|
|
||||||
# Sort both arrays
|
# Sort both arrays
|
||||||
source_sorted, source_indices = torch.sort(source_flat)
|
source_sorted, source_indices = torch.sort(source_flat)
|
||||||
reference_sorted, _ = torch.sort(reference_flat)
|
reference_sorted, _ = torch.sort(reference_flat)
|
||||||
del reference_flat
|
del reference_flat
|
||||||
|
|
||||||
# Quantile mapping
|
# Quantile mapping
|
||||||
n_source = len(source_sorted)
|
n_source = len(source_sorted)
|
||||||
n_reference = len(reference_sorted)
|
n_reference = len(reference_sorted)
|
||||||
|
|
||||||
if n_source == n_reference:
|
if n_source == n_reference:
|
||||||
matched_sorted = reference_sorted
|
matched_sorted = reference_sorted
|
||||||
else:
|
else:
|
||||||
@ -1651,27 +1651,27 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
|
|||||||
ref_indices.clamp_(0, n_reference - 1)
|
ref_indices.clamp_(0, n_reference - 1)
|
||||||
matched_sorted = reference_sorted[ref_indices]
|
matched_sorted = reference_sorted[ref_indices]
|
||||||
del source_quantiles, ref_indices, reference_sorted
|
del source_quantiles, ref_indices, reference_sorted
|
||||||
|
|
||||||
del source_sorted, source_flat
|
del source_sorted, source_flat
|
||||||
|
|
||||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||||
inverse_indices = torch.argsort(source_indices)
|
inverse_indices = torch.argsort(source_indices)
|
||||||
del source_indices
|
del source_indices
|
||||||
matched_flat = matched_sorted[inverse_indices]
|
matched_flat = matched_sorted[inverse_indices]
|
||||||
del matched_sorted, inverse_indices
|
del matched_sorted, inverse_indices
|
||||||
|
|
||||||
return matched_flat.reshape(original_shape)
|
return matched_flat.reshape(original_shape)
|
||||||
|
|
||||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||||
"""Convert batch of CIELAB images to RGB color space."""
|
"""Convert batch of CIELAB images to RGB color space."""
|
||||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||||
|
|
||||||
# LAB to XYZ
|
# LAB to XYZ
|
||||||
fy = (L + 16.0) / 116.0
|
fy = (L + 16.0) / 116.0
|
||||||
fx = a.div(500.0).add_(fy)
|
fx = a.div(500.0).add_(fy)
|
||||||
fz = fy - b / 200.0
|
fz = fy - b / 200.0
|
||||||
del L, a, b
|
del L, a, b
|
||||||
|
|
||||||
# XYZ transformation
|
# XYZ transformation
|
||||||
x = torch.where(
|
x = torch.where(
|
||||||
fx > epsilon,
|
fx > epsilon,
|
||||||
@ -1689,28 +1689,28 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
|||||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
)
|
)
|
||||||
del fx, fy, fz
|
del fx, fy, fz
|
||||||
|
|
||||||
# Apply D65 white point (in-place)
|
# Apply D65 white point (in-place)
|
||||||
x.mul_(0.95047)
|
x.mul_(0.95047)
|
||||||
# y *= 1.00000 # (no-op, skip)
|
# y *= 1.00000 # (no-op, skip)
|
||||||
z.mul_(1.08883)
|
z.mul_(1.08883)
|
||||||
|
|
||||||
xyz = torch.stack([x, y, z], dim=1)
|
xyz = torch.stack([x, y, z], dim=1)
|
||||||
del x, y, z
|
del x, y, z
|
||||||
|
|
||||||
# Matrix multiplication: XYZ -> RGB
|
# Matrix multiplication: XYZ -> RGB
|
||||||
B, C, H, W = xyz.shape
|
B, C, H, W = xyz.shape
|
||||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||||
del xyz
|
del xyz
|
||||||
|
|
||||||
# Ensure dtype consistency for matrix multiplication
|
# Ensure dtype consistency for matrix multiplication
|
||||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||||
del xyz_flat
|
del xyz_flat
|
||||||
|
|
||||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||||
del rgb_linear_flat
|
del rgb_linear_flat
|
||||||
|
|
||||||
# Apply inverse gamma correction (delinearize)
|
# Apply inverse gamma correction (delinearize)
|
||||||
mask = rgb_linear > 0.0031308
|
mask = rgb_linear > 0.0031308
|
||||||
rgb = torch.where(
|
rgb = torch.where(
|
||||||
@ -1719,7 +1719,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
|||||||
rgb_linear * 12.92
|
rgb_linear * 12.92
|
||||||
)
|
)
|
||||||
del mask, rgb_linear
|
del mask, rgb_linear
|
||||||
|
|
||||||
return torch.clamp(rgb, 0.0, 1.0)
|
return torch.clamp(rgb, 0.0, 1.0)
|
||||||
|
|
||||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||||
@ -1732,25 +1732,25 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
|||||||
rgb / 12.92
|
rgb / 12.92
|
||||||
)
|
)
|
||||||
del mask
|
del mask
|
||||||
|
|
||||||
# Matrix multiplication: RGB -> XYZ
|
# Matrix multiplication: RGB -> XYZ
|
||||||
B, C, H, W = rgb_linear.shape
|
B, C, H, W = rgb_linear.shape
|
||||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||||
del rgb_linear
|
del rgb_linear
|
||||||
|
|
||||||
# Ensure dtype consistency for matrix multiplication
|
# Ensure dtype consistency for matrix multiplication
|
||||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||||
del rgb_flat
|
del rgb_flat
|
||||||
|
|
||||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||||
del xyz_flat
|
del xyz_flat
|
||||||
|
|
||||||
# Normalize by D65 white point (in-place)
|
# Normalize by D65 white point (in-place)
|
||||||
xyz[:, 0].div_(0.95047) # X
|
xyz[:, 0].div_(0.95047) # X
|
||||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||||
xyz[:, 2].div_(1.08883) # Z
|
xyz[:, 2].div_(1.08883) # Z
|
||||||
|
|
||||||
# XYZ to LAB transformation
|
# XYZ to LAB transformation
|
||||||
epsilon_cubed = epsilon ** 3
|
epsilon_cubed = epsilon ** 3
|
||||||
mask = xyz > epsilon_cubed
|
mask = xyz > epsilon_cubed
|
||||||
@ -1760,13 +1760,13 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
|||||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||||
)
|
)
|
||||||
del xyz, mask
|
del xyz, mask
|
||||||
|
|
||||||
# Extract channels and compute LAB
|
# Extract channels and compute LAB
|
||||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||||
del f_xyz
|
del f_xyz
|
||||||
|
|
||||||
return torch.stack([L, a, b], dim=1)
|
return torch.stack([L, a, b], dim=1)
|
||||||
|
|
||||||
def lab_color_transfer(
|
def lab_color_transfer(
|
||||||
@ -1775,7 +1775,7 @@ def lab_color_transfer(
|
|||||||
luminance_weight: float = 0.8
|
luminance_weight: float = 0.8
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||||
|
|
||||||
if content_feat.shape != style_feat.shape:
|
if content_feat.shape != style_feat.shape:
|
||||||
style_feat = safe_interpolate_operation(
|
style_feat = safe_interpolate_operation(
|
||||||
style_feat,
|
style_feat,
|
||||||
@ -1783,45 +1783,45 @@ def lab_color_transfer(
|
|||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=False
|
align_corners=False
|
||||||
)
|
)
|
||||||
|
|
||||||
device = content_feat.device
|
device = content_feat.device
|
||||||
|
|
||||||
def ensure_float32_precision(c):
|
def ensure_float32_precision(c):
|
||||||
orig_dtype = c.dtype
|
orig_dtype = c.dtype
|
||||||
c = c.float()
|
c = c.float()
|
||||||
return c, orig_dtype
|
return c, orig_dtype
|
||||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||||
style_feat, _ = ensure_float32_precision(style_feat)
|
style_feat, _ = ensure_float32_precision(style_feat)
|
||||||
|
|
||||||
rgb_to_xyz_matrix = torch.tensor([
|
rgb_to_xyz_matrix = torch.tensor([
|
||||||
[0.4124564, 0.3575761, 0.1804375],
|
[0.4124564, 0.3575761, 0.1804375],
|
||||||
[0.2126729, 0.7151522, 0.0721750],
|
[0.2126729, 0.7151522, 0.0721750],
|
||||||
[0.0193339, 0.1191920, 0.9503041]
|
[0.0193339, 0.1191920, 0.9503041]
|
||||||
], dtype=torch.float32, device=device)
|
], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
xyz_to_rgb_matrix = torch.tensor([
|
xyz_to_rgb_matrix = torch.tensor([
|
||||||
[ 3.2404542, -1.5371385, -0.4985314],
|
[ 3.2404542, -1.5371385, -0.4985314],
|
||||||
[-0.9692660, 1.8760108, 0.0415560],
|
[-0.9692660, 1.8760108, 0.0415560],
|
||||||
[ 0.0556434, -0.2040259, 1.0572252]
|
[ 0.0556434, -0.2040259, 1.0572252]
|
||||||
], dtype=torch.float32, device=device)
|
], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
epsilon = 6.0 / 29.0
|
epsilon = 6.0 / 29.0
|
||||||
kappa = (29.0 / 3.0) ** 3
|
kappa = (29.0 / 3.0) ** 3
|
||||||
|
|
||||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||||
|
|
||||||
# Convert to LAB color space
|
# Convert to LAB color space
|
||||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||||
del content_feat
|
del content_feat
|
||||||
|
|
||||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||||
del style_feat, rgb_to_xyz_matrix
|
del style_feat, rgb_to_xyz_matrix
|
||||||
|
|
||||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||||
|
|
||||||
# Handle luminance with weighted blending
|
# Handle luminance with weighted blending
|
||||||
if luminance_weight < 1.0:
|
if luminance_weight < 1.0:
|
||||||
# Partially match luminance for better overall color accuracy
|
# Partially match luminance for better overall color accuracy
|
||||||
@ -1832,23 +1832,23 @@ def lab_color_transfer(
|
|||||||
else:
|
else:
|
||||||
# Fully preserve content luminance
|
# Fully preserve content luminance
|
||||||
result_L = content_lab[:, 0]
|
result_L = content_lab[:, 0]
|
||||||
|
|
||||||
del content_lab, style_lab
|
del content_lab, style_lab
|
||||||
|
|
||||||
# Reconstruct LAB with corrected channels
|
# Reconstruct LAB with corrected channels
|
||||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||||
del result_L, matched_a, matched_b
|
del result_L, matched_a, matched_b
|
||||||
|
|
||||||
# Convert back to RGB
|
# Convert back to RGB
|
||||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||||
del result_lab, xyz_to_rgb_matrix
|
del result_lab, xyz_to_rgb_matrix
|
||||||
|
|
||||||
# Convert back to [-1, 1] range (in-place)
|
# Convert back to [-1, 1] range (in-place)
|
||||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||||
del result_rgb
|
del result_rgb
|
||||||
|
|
||||||
result = result.to(original_dtype)
|
result = result.to(original_dtype)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class VideoAutoencoderKL(nn.Module):
|
class VideoAutoencoderKL(nn.Module):
|
||||||
@ -2140,10 +2140,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
|||||||
|
|
||||||
input = input.to(x.device)
|
input = input.to(x.device)
|
||||||
o_h, o_w = self.img_dims
|
o_h, o_w = self.img_dims
|
||||||
|
x = x[..., :o_h, :o_w]
|
||||||
|
input = input[..., :o_h, :o_w ]
|
||||||
x = lab_color_transfer(x, input)
|
x = lab_color_transfer(x, input)
|
||||||
|
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
x = x[..., :o_h, :o_w]
|
|
||||||
x = rearrange(x, "b t c h w -> b c t h w")
|
x = rearrange(x, "b t c h w -> b c t h w")
|
||||||
|
|
||||||
# ensure even dims for save video
|
# ensure even dims for save video
|
||||||
@ -2154,8 +2155,12 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
|
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"):
|
||||||
set_norm_limit(norm_max_mem)
|
set_norm_limit(norm_max_mem)
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, InflatedCausalConv3d):
|
if isinstance(m, InflatedCausalConv3d):
|
||||||
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
|
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
|
||||||
|
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, InflatedCausalConv3d):
|
||||||
|
module.set_memory_device(memory_device)
|
||||||
|
|||||||
@ -1307,9 +1307,6 @@ class SeedVR2(supported_models_base.BASE):
|
|||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "seedvr2"
|
"image_model": "seedvr2"
|
||||||
}
|
}
|
||||||
sampling_settings = {
|
|
||||||
"shift": 1.0,
|
|
||||||
}
|
|
||||||
latent_format = comfy.latent_formats.SeedVR2
|
latent_format = comfy.latent_formats.SeedVR2
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user