mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 14:20:27 +08:00
fp16 check + changed to lab
This commit is contained in:
parent
72ca18acc2
commit
3c149dd543
@ -768,9 +768,9 @@ class NaSwinAttention(NaMMAttention):
|
|||||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||||
|
|
||||||
out = optimized_attention(
|
out = optimized_attention(
|
||||||
q=concat_win(vid_q, txt_q).bfloat16(),
|
q=concat_win(vid_q, txt_q),
|
||||||
k=concat_win(vid_k, txt_k).bfloat16(),
|
k=concat_win(vid_k, txt_k),
|
||||||
v=concat_win(vid_v, txt_v).bfloat16(),
|
v=concat_win(vid_v, txt_v),
|
||||||
heads=self.heads, skip_reshape=True, var_length = True,
|
heads=self.heads, skip_reshape=True, var_length = True,
|
||||||
cu_seqlens_q=cache_win(
|
cu_seqlens_q=cache_win(
|
||||||
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||||
|
|||||||
@ -1626,6 +1626,231 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
|||||||
|
|
||||||
return content_high_freq.clamp_(-1.0, 1.0)
|
return content_high_freq.clamp_(-1.0, 1.0)
|
||||||
|
|
||||||
|
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||||
|
original_shape = source.shape
|
||||||
|
|
||||||
|
# Flatten
|
||||||
|
source_flat = source.flatten()
|
||||||
|
reference_flat = reference.flatten()
|
||||||
|
|
||||||
|
# Sort both arrays
|
||||||
|
source_sorted, source_indices = torch.sort(source_flat)
|
||||||
|
reference_sorted, _ = torch.sort(reference_flat)
|
||||||
|
del reference_flat
|
||||||
|
|
||||||
|
# Quantile mapping
|
||||||
|
n_source = len(source_sorted)
|
||||||
|
n_reference = len(reference_sorted)
|
||||||
|
|
||||||
|
if n_source == n_reference:
|
||||||
|
matched_sorted = reference_sorted
|
||||||
|
else:
|
||||||
|
# Interpolate reference to match source quantiles
|
||||||
|
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||||
|
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||||
|
ref_indices.clamp_(0, n_reference - 1)
|
||||||
|
matched_sorted = reference_sorted[ref_indices]
|
||||||
|
del source_quantiles, ref_indices, reference_sorted
|
||||||
|
|
||||||
|
del source_sorted, source_flat
|
||||||
|
|
||||||
|
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||||
|
inverse_indices = torch.argsort(source_indices)
|
||||||
|
del source_indices
|
||||||
|
matched_flat = matched_sorted[inverse_indices]
|
||||||
|
del matched_sorted, inverse_indices
|
||||||
|
|
||||||
|
return matched_flat.reshape(original_shape)
|
||||||
|
|
||||||
|
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||||
|
"""Convert batch of CIELAB images to RGB color space."""
|
||||||
|
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||||
|
|
||||||
|
# LAB to XYZ
|
||||||
|
fy = (L + 16.0) / 116.0
|
||||||
|
fx = a.div(500.0).add_(fy)
|
||||||
|
fz = fy - b / 200.0
|
||||||
|
del L, a, b
|
||||||
|
|
||||||
|
# XYZ transformation
|
||||||
|
x = torch.where(
|
||||||
|
fx > epsilon,
|
||||||
|
torch.pow(fx, 3.0),
|
||||||
|
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
|
)
|
||||||
|
y = torch.where(
|
||||||
|
fy > epsilon,
|
||||||
|
torch.pow(fy, 3.0),
|
||||||
|
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
|
)
|
||||||
|
z = torch.where(
|
||||||
|
fz > epsilon,
|
||||||
|
torch.pow(fz, 3.0),
|
||||||
|
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
|
)
|
||||||
|
del fx, fy, fz
|
||||||
|
|
||||||
|
# Apply D65 white point (in-place)
|
||||||
|
x.mul_(0.95047)
|
||||||
|
# y *= 1.00000 # (no-op, skip)
|
||||||
|
z.mul_(1.08883)
|
||||||
|
|
||||||
|
xyz = torch.stack([x, y, z], dim=1)
|
||||||
|
del x, y, z
|
||||||
|
|
||||||
|
# Matrix multiplication: XYZ -> RGB
|
||||||
|
B, C, H, W = xyz.shape
|
||||||
|
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||||
|
del xyz
|
||||||
|
|
||||||
|
# Ensure dtype consistency for matrix multiplication
|
||||||
|
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||||
|
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||||
|
del xyz_flat
|
||||||
|
|
||||||
|
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||||
|
del rgb_linear_flat
|
||||||
|
|
||||||
|
# Apply inverse gamma correction (delinearize)
|
||||||
|
mask = rgb_linear > 0.0031308
|
||||||
|
rgb = torch.where(
|
||||||
|
mask,
|
||||||
|
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||||
|
rgb_linear * 12.92
|
||||||
|
)
|
||||||
|
del mask, rgb_linear
|
||||||
|
|
||||||
|
return torch.clamp(rgb, 0.0, 1.0)
|
||||||
|
|
||||||
|
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||||
|
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||||
|
# Apply sRGB gamma correction (linearize)
|
||||||
|
mask = rgb > 0.04045
|
||||||
|
rgb_linear = torch.where(
|
||||||
|
mask,
|
||||||
|
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||||
|
rgb / 12.92
|
||||||
|
)
|
||||||
|
del mask
|
||||||
|
|
||||||
|
# Matrix multiplication: RGB -> XYZ
|
||||||
|
B, C, H, W = rgb_linear.shape
|
||||||
|
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||||
|
del rgb_linear
|
||||||
|
|
||||||
|
# Ensure dtype consistency for matrix multiplication
|
||||||
|
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||||
|
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||||
|
del rgb_flat
|
||||||
|
|
||||||
|
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||||
|
del xyz_flat
|
||||||
|
|
||||||
|
# Normalize by D65 white point (in-place)
|
||||||
|
xyz[:, 0].div_(0.95047) # X
|
||||||
|
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||||
|
xyz[:, 2].div_(1.08883) # Z
|
||||||
|
|
||||||
|
# XYZ to LAB transformation
|
||||||
|
epsilon_cubed = epsilon ** 3
|
||||||
|
mask = xyz > epsilon_cubed
|
||||||
|
f_xyz = torch.where(
|
||||||
|
mask,
|
||||||
|
torch.pow(xyz, 1.0 / 3.0),
|
||||||
|
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||||
|
)
|
||||||
|
del xyz, mask
|
||||||
|
|
||||||
|
# Extract channels and compute LAB
|
||||||
|
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||||
|
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||||
|
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||||
|
del f_xyz
|
||||||
|
|
||||||
|
return torch.stack([L, a, b], dim=1)
|
||||||
|
|
||||||
|
def lab_color_transfer(
|
||||||
|
content_feat: Tensor,
|
||||||
|
style_feat: Tensor,
|
||||||
|
luminance_weight: float = 0.8
|
||||||
|
) -> Tensor:
|
||||||
|
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||||
|
|
||||||
|
if content_feat.shape != style_feat.shape:
|
||||||
|
style_feat = safe_interpolate_operation(
|
||||||
|
style_feat,
|
||||||
|
size=content_feat.shape[-2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
device = content_feat.device
|
||||||
|
|
||||||
|
def ensure_float32_precision(c):
|
||||||
|
orig_dtype = c.dtype
|
||||||
|
c = c.float()
|
||||||
|
return c, orig_dtype
|
||||||
|
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||||
|
style_feat, _ = ensure_float32_precision(style_feat)
|
||||||
|
|
||||||
|
rgb_to_xyz_matrix = torch.tensor([
|
||||||
|
[0.4124564, 0.3575761, 0.1804375],
|
||||||
|
[0.2126729, 0.7151522, 0.0721750],
|
||||||
|
[0.0193339, 0.1191920, 0.9503041]
|
||||||
|
], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
xyz_to_rgb_matrix = torch.tensor([
|
||||||
|
[ 3.2404542, -1.5371385, -0.4985314],
|
||||||
|
[-0.9692660, 1.8760108, 0.0415560],
|
||||||
|
[ 0.0556434, -0.2040259, 1.0572252]
|
||||||
|
], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
epsilon = 6.0 / 29.0
|
||||||
|
kappa = (29.0 / 3.0) ** 3
|
||||||
|
|
||||||
|
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||||
|
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||||
|
|
||||||
|
# Convert to LAB color space
|
||||||
|
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||||
|
del content_feat
|
||||||
|
|
||||||
|
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||||
|
del style_feat, rgb_to_xyz_matrix
|
||||||
|
|
||||||
|
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||||
|
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||||
|
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||||
|
|
||||||
|
# Handle luminance with weighted blending
|
||||||
|
if luminance_weight < 1.0:
|
||||||
|
# Partially match luminance for better overall color accuracy
|
||||||
|
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||||
|
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||||
|
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||||
|
del matched_L
|
||||||
|
else:
|
||||||
|
# Fully preserve content luminance
|
||||||
|
result_L = content_lab[:, 0]
|
||||||
|
|
||||||
|
del content_lab, style_lab
|
||||||
|
|
||||||
|
# Reconstruct LAB with corrected channels
|
||||||
|
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||||
|
del result_L, matched_a, matched_b
|
||||||
|
|
||||||
|
# Convert back to RGB
|
||||||
|
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||||
|
del result_lab, xyz_to_rgb_matrix
|
||||||
|
|
||||||
|
# Convert back to [-1, 1] range (in-place)
|
||||||
|
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||||
|
del result_rgb
|
||||||
|
|
||||||
|
result = result.to(original_dtype)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
class VideoAutoencoderKL(nn.Module):
|
class VideoAutoencoderKL(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1914,10 +2139,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
|||||||
x = rearrange(x, exp)
|
x = rearrange(x, exp)
|
||||||
|
|
||||||
input = input.to(x.device)
|
input = input.to(x.device)
|
||||||
x = wavelet_reconstruction(x, input)
|
o_h, o_w = self.img_dims
|
||||||
|
x = lab_color_transfer(x, input)
|
||||||
|
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
o_h, o_w = self.img_dims
|
|
||||||
x = x[..., :o_h, :o_w]
|
x = x[..., :o_h, :o_w]
|
||||||
x = rearrange(x, "b t c h w -> b c t h w")
|
x = rearrange(x, "b t c h w -> b c t h w")
|
||||||
|
|
||||||
|
|||||||
@ -1292,11 +1292,14 @@ class SeedVR2(supported_models_base.BASE):
|
|||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "seedvr2"
|
"image_model": "seedvr2"
|
||||||
}
|
}
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 1.0,
|
||||||
|
}
|
||||||
latent_format = comfy.latent_formats.SeedVR2
|
latent_format = comfy.latent_formats.SeedVR2
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
sampling_settings = {
|
sampling_settings = {
|
||||||
"shift": 1.0,
|
"shift": 1.0,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -407,6 +407,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
io.Latent.Input("vae_conditioning"),
|
io.Latent.Input("vae_conditioning"),
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
|
io.Float.Input("latent_noise_scale", default=0.0, step=0.001)
|
||||||
],
|
],
|
||||||
outputs=[io.Conditioning.Output(display_name = "positive"),
|
outputs=[io.Conditioning.Output(display_name = "positive"),
|
||||||
io.Conditioning.Output(display_name = "negative"),
|
io.Conditioning.Output(display_name = "negative"),
|
||||||
@ -414,7 +415,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, vae_conditioning, model) -> io.NodeOutput:
|
def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput:
|
||||||
|
|
||||||
vae_conditioning = vae_conditioning["samples"]
|
vae_conditioning = vae_conditioning["samples"]
|
||||||
device = vae_conditioning.device
|
device = vae_conditioning.device
|
||||||
@ -425,7 +426,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
noises = torch.randn_like(vae_conditioning).to(device)
|
noises = torch.randn_like(vae_conditioning).to(device)
|
||||||
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
||||||
aug_noises = noises * 0.1 + aug_noises * 0.05
|
aug_noises = noises * 0.1 + aug_noises * 0.05
|
||||||
cond_noise_scale = 0.0
|
cond_noise_scale = latent_noise_scale
|
||||||
t = (
|
t = (
|
||||||
torch.tensor([1000.0])
|
torch.tensor([1000.0])
|
||||||
* cond_noise_scale
|
* cond_noise_scale
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user