mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 08:40:50 +08:00
fp16 check + changed to lab
This commit is contained in:
parent
72ca18acc2
commit
3c149dd543
@ -768,9 +768,9 @@ class NaSwinAttention(NaMMAttention):
|
||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||
|
||||
out = optimized_attention(
|
||||
q=concat_win(vid_q, txt_q).bfloat16(),
|
||||
k=concat_win(vid_k, txt_k).bfloat16(),
|
||||
v=concat_win(vid_v, txt_v).bfloat16(),
|
||||
q=concat_win(vid_q, txt_q),
|
||||
k=concat_win(vid_k, txt_k),
|
||||
v=concat_win(vid_v, txt_v),
|
||||
heads=self.heads, skip_reshape=True, var_length = True,
|
||||
cu_seqlens_q=cache_win(
|
||||
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||
|
||||
@ -1626,6 +1626,231 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
return content_high_freq.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
torch.pow(fx, 3.0),
|
||||
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
y = torch.where(
|
||||
fy > epsilon,
|
||||
torch.pow(fy, 3.0),
|
||||
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
z = torch.where(
|
||||
fz > epsilon,
|
||||
torch.pow(fz, 3.0),
|
||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(0.95047)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(1.08883)
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
mask,
|
||||
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||
rgb_linear * 12.92
|
||||
)
|
||||
del mask, rgb_linear
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||
# Apply sRGB gamma correction (linearize)
|
||||
mask = rgb > 0.04045
|
||||
rgb_linear = torch.where(
|
||||
mask,
|
||||
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||
rgb / 12.92
|
||||
)
|
||||
del mask
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(0.95047) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(1.08883) # Z
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
f_xyz = torch.where(
|
||||
mask,
|
||||
torch.pow(xyz, 1.0 / 3.0),
|
||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
del f_xyz
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
|
||||
def lab_color_transfer(
|
||||
content_feat: Tensor,
|
||||
style_feat: Tensor,
|
||||
luminance_weight: float = 0.8
|
||||
) -> Tensor:
|
||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
device = content_feat.device
|
||||
|
||||
def ensure_float32_precision(c):
|
||||
orig_dtype = c.dtype
|
||||
c = c.float()
|
||||
return c, orig_dtype
|
||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||
style_feat, _ = ensure_float32_precision(style_feat)
|
||||
|
||||
rgb_to_xyz_matrix = torch.tensor([
|
||||
[0.4124564, 0.3575761, 0.1804375],
|
||||
[0.2126729, 0.7151522, 0.0721750],
|
||||
[0.0193339, 0.1191920, 0.9503041]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
xyz_to_rgb_matrix = torch.tensor([
|
||||
[ 3.2404542, -1.5371385, -0.4985314],
|
||||
[-0.9692660, 1.8760108, 0.0415560],
|
||||
[ 0.0556434, -0.2040259, 1.0572252]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
epsilon = 6.0 / 29.0
|
||||
kappa = (29.0 / 3.0) ** 3
|
||||
|
||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
|
||||
# Convert to LAB color space
|
||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del content_feat
|
||||
|
||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del style_feat, rgb_to_xyz_matrix
|
||||
|
||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||
|
||||
# Handle luminance with weighted blending
|
||||
if luminance_weight < 1.0:
|
||||
# Partially match luminance for better overall color accuracy
|
||||
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||
del matched_L
|
||||
else:
|
||||
# Fully preserve content luminance
|
||||
result_L = content_lab[:, 0]
|
||||
|
||||
del content_lab, style_lab
|
||||
|
||||
# Reconstruct LAB with corrected channels
|
||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||
del result_L, matched_a, matched_b
|
||||
|
||||
# Convert back to RGB
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
del result_lab, xyz_to_rgb_matrix
|
||||
|
||||
# Convert back to [-1, 1] range (in-place)
|
||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||
del result_rgb
|
||||
|
||||
result = result.to(original_dtype)
|
||||
|
||||
return result
|
||||
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1914,10 +2139,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
x = rearrange(x, exp)
|
||||
|
||||
input = input.to(x.device)
|
||||
x = wavelet_reconstruction(x, input)
|
||||
o_h, o_w = self.img_dims
|
||||
x = lab_color_transfer(x, input)
|
||||
|
||||
x = x.unsqueeze(0)
|
||||
o_h, o_w = self.img_dims
|
||||
x = x[..., :o_h, :o_w]
|
||||
x = rearrange(x, "b t c h w -> b c t h w")
|
||||
|
||||
|
||||
@ -1292,11 +1292,14 @@ class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
latent_format = comfy.latent_formats.SeedVR2
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
@ -407,6 +407,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Latent.Input("vae_conditioning"),
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("latent_noise_scale", default=0.0, step=0.001)
|
||||
],
|
||||
outputs=[io.Conditioning.Output(display_name = "positive"),
|
||||
io.Conditioning.Output(display_name = "negative"),
|
||||
@ -414,7 +415,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae_conditioning, model) -> io.NodeOutput:
|
||||
def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput:
|
||||
|
||||
vae_conditioning = vae_conditioning["samples"]
|
||||
device = vae_conditioning.device
|
||||
@ -425,7 +426,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
noises = torch.randn_like(vae_conditioning).to(device)
|
||||
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
||||
aug_noises = noises * 0.1 + aug_noises * 0.05
|
||||
cond_noise_scale = 0.0
|
||||
cond_noise_scale = latent_noise_scale
|
||||
t = (
|
||||
torch.tensor([1000.0])
|
||||
* cond_noise_scale
|
||||
|
||||
Loading…
Reference in New Issue
Block a user