diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 9c9edadce..eb2237eee 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1353,7 +1353,8 @@ class NaDiT(nn.Module):
             pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
             txt, txt_shape = flatten([pos_cond, neg_cond])
         except:
-            txt, txt_shape = flatten([context.squeeze(0)])
+            context = self.positive_conditioning
+            txt, txt_shape = flatten([context])
 
         vid, vid_shape = flatten(x)
         cond_latent, _ = flatten(conditions)
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 2a45cf450..18812dde8 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -1628,20 +1628,20 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
 
 def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
     original_shape = source.shape
-    
+
     # Flatten
     source_flat = source.flatten()
     reference_flat = reference.flatten()
-    
+
     # Sort both arrays
     source_sorted, source_indices = torch.sort(source_flat)
     reference_sorted, _ = torch.sort(reference_flat)
     del reference_flat
-    
+
     # Quantile mapping
     n_source = len(source_sorted)
     n_reference = len(reference_sorted)
-    
+
     if n_source == n_reference:
         matched_sorted = reference_sorted
     else:
@@ -1651,27 +1651,27 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
         ref_indices.clamp_(0, n_reference - 1)
         matched_sorted = reference_sorted[ref_indices]
         del source_quantiles, ref_indices, reference_sorted
-    
+
     del source_sorted, source_flat
-    
+
     # Reconstruct using argsort (portable across CUDA/ROCm/MPS)
     inverse_indices = torch.argsort(source_indices)
     del source_indices
     matched_flat = matched_sorted[inverse_indices]
     del matched_sorted, inverse_indices
-    
+
     return matched_flat.reshape(original_shape)
 
 def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
     """Convert batch of CIELAB images to RGB color space."""
     L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
-    
+
     # LAB to XYZ
     fy = (L + 16.0) / 116.0
     fx = a.div(500.0).add_(fy)
     fz = fy - b / 200.0
     del L, a, b
-    
+
     # XYZ transformation
     x = torch.where(
         fx > epsilon,
@@ -1689,28 +1689,28 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
         fz.mul(116.0).sub_(16.0).div_(kappa)
     )
     del fx, fy, fz
-    
+
     # Apply D65 white point (in-place)
     x.mul_(0.95047)
     # y *= 1.00000  # (no-op, skip)
     z.mul_(1.08883)
-    
+
     xyz = torch.stack([x, y, z], dim=1)
     del x, y, z
-    
+
     # Matrix multiplication: XYZ -> RGB
     B, C, H, W = xyz.shape
     xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
     del xyz
-    
+
     # Ensure dtype consistency for matrix multiplication
     xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
     rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
     del xyz_flat
-    
+
     rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
     del rgb_linear_flat
-    
+
     # Apply inverse gamma correction (delinearize)
     mask = rgb_linear > 0.0031308
     rgb = torch.where(
@@ -1719,7 +1719,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
         rgb_linear * 12.92
     )
     del mask, rgb_linear
-    
+
     return torch.clamp(rgb, 0.0, 1.0)
 
 def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
@@ -1732,25 +1732,25 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
         rgb / 12.92
     )
     del mask
-    
+
     # Matrix multiplication: RGB -> XYZ
     B, C, H, W = rgb_linear.shape
     rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
     del rgb_linear
-    
+
     # Ensure dtype consistency for matrix multiplication
     rgb_flat = rgb_flat.to(dtype=matrix.dtype)
     xyz_flat = torch.matmul(rgb_flat, matrix.T)
     del rgb_flat
-    
+
     xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
     del xyz_flat
-    
+
     # Normalize by D65 white point (in-place)
     xyz[:, 0].div_(0.95047)  # X
     # xyz[:, 1] /= 1.00000  # Y (no-op, skip)
     xyz[:, 2].div_(1.08883)  # Z
-    
+
     # XYZ to LAB transformation
     epsilon_cubed = epsilon ** 3
     mask = xyz > epsilon_cubed
@@ -1760,13 +1760,13 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
         xyz.mul(kappa).add_(16.0).div_(116.0)
     )
     del xyz, mask
-    
+
     # Extract channels and compute LAB
     L = f_xyz[:, 1].mul(116.0).sub_(16.0)  # Lightness [0, 100]
     a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0)  # Green-Red [-128, 127]
     b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0)  # Blue-Yellow [-128, 127]
     del f_xyz
-    
+
     return torch.stack([L, a, b], dim=1)
 
 def lab_color_transfer(
@@ -1775,7 +1775,7 @@
     luminance_weight: float = 0.8
 ) -> Tensor:
     content_feat = wavelet_reconstruction(content_feat, style_feat)
-    
+
     if content_feat.shape != style_feat.shape:
         style_feat = safe_interpolate_operation(
             style_feat,
@@ -1783,45 +1783,45 @@ def lab_color_transfer(
             mode='bilinear',
            align_corners=False
        )
-    
+
    device = content_feat.device
-    
+
    def ensure_float32_precision(c):
        orig_dtype = c.dtype
        c = c.float()
        return c, orig_dtype
    content_feat, original_dtype = ensure_float32_precision(content_feat)
    style_feat, _ = ensure_float32_precision(style_feat)
-    
+
    rgb_to_xyz_matrix = torch.tensor([
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041]
    ], dtype=torch.float32, device=device)
-    
+
    xyz_to_rgb_matrix = torch.tensor([
        [ 3.2404542, -1.5371385, -0.4985314],
        [-0.9692660,  1.8760108,  0.0415560],
        [ 0.0556434, -0.2040259,  1.0572252]
    ], dtype=torch.float32, device=device)
-    
+
    epsilon = 6.0 / 29.0
    kappa = (29.0 / 3.0) ** 3
-    
+
    content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
    style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
-    
+
    # Convert to LAB color space
    content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
    del content_feat
-    
+
    style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
    del style_feat, rgb_to_xyz_matrix
-    
+
    # Match chrominance channels (a*, b*) for accurate color transfer
    matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
    matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
-    
+
    # Handle luminance with weighted blending
    if luminance_weight < 1.0:
        # Partially match luminance for better overall color accuracy
@@ -1832,23 +1832,23 @@
     else:
         # Fully preserve content luminance
         result_L = content_lab[:, 0]
-    
+
     del content_lab, style_lab
-    
+
     # Reconstruct LAB with corrected channels
     result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
     del result_L, matched_a, matched_b
-    
+
     # Convert back to RGB
     result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
     del result_lab, xyz_to_rgb_matrix
-    
+
     # Convert back to [-1, 1] range (in-place)
     result = result_rgb.mul_(2.0).sub_(1.0)
     del result_rgb
-    
+
     result = result.to(original_dtype)
-    
+
     return result
 
 class VideoAutoencoderKL(nn.Module):
@@ -2140,10 +2140,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         input = input.to(x.device)
 
         o_h, o_w = self.img_dims
+        x = x[..., :o_h, :o_w]
+        input = input[..., :o_h, :o_w]
         x = lab_color_transfer(x, input)
         x = x.unsqueeze(0)
-        x = x[..., :o_h, :o_w]
 
         x = rearrange(x, "b t c h w -> b c t h w")
 
         # ensure even dims for save video
@@ -2154,8 +2155,9 @@
 
         return x
 
-    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
+    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device="same"):
         set_norm_limit(norm_max_mem)
         for m in self.modules():
             if isinstance(m, InflatedCausalConv3d):
                 m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
+                m.set_memory_device(memory_device)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d047579c4..e846348b2 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1307,9 +1307,6 @@ class SeedVR2(supported_models_base.BASE):
     unet_config = {
         "image_model": "seedvr2"
     }
-    sampling_settings = {
-        "shift": 1.0,
-    }
 
     latent_format = comfy.latent_formats.SeedVR2
     vae_key_prefix = ["vae."]
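
Reviewer note (not part of the patch): a minimal, self-contained sketch of the quantile-mapping technique that `_histogram_matching_channel` implements, for trying the idea in isolation. The helper name `histogram_match` and the `torch.linspace` quantile grid are illustrative assumptions:

    import torch

    def histogram_match(source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        # Sort the source, map each source rank onto the corresponding
        # reference quantile, then invert the sort so the matched values
        # return to their original positions.
        src_sorted, src_idx = torch.sort(source.flatten())
        ref_sorted, _ = torch.sort(reference.flatten())
        n_src, n_ref = src_sorted.numel(), ref_sorted.numel()
        if n_src == n_ref:
            matched = ref_sorted
        else:
            quantiles = torch.linspace(0.0, 1.0, n_src, device=source.device)
            idx = (quantiles * (n_ref - 1)).round().long().clamp_(0, n_ref - 1)
            matched = ref_sorted[idx]
        # argsort of the sort indices is the inverse permutation
        return matched[torch.argsort(src_idx)].reshape(source.shape)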
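
Likewise, a hypothetical call site for the extended `set_memory_limit` signature. The `vae` instance, the limit values, and the reading that `memory_device` selects where `InflatedCausalConv3d` keeps its cached state are assumptions, not confirmed by the patch:

    # cap conv/norm memory use and route the conv cache to the CPU
    vae.set_memory_limit(conv_max_mem=1.0, norm_max_mem=1.0, memory_device="cpu")
    # defaults keep the previous behaviour: no limit, cache stays on the same device
    vae.set_memory_limit(conv_max_mem=None, norm_max_mem=None)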