diff --git a/comfy/context_windows.py b/comfy/context_windows.py index b54f7f39a..cb44ee6e8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -93,6 +93,50 @@ class IndexListCallbacks: return {} +def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]): + if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)): + return None + cond_tensor = cond_value.cond + if temporal_dim >= cond_tensor.ndim: + return None + + cond_size = cond_tensor.size(temporal_dim) + + if temporal_scale == 1: + expected_size = x_in.size(window.dim) - temporal_offset + if cond_size != expected_size: + return None + + if temporal_offset == 0 and temporal_scale == 1: + sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list) + return cond_value._copy_with(sliced) + + # skip leading latent positions that have no corresponding conditioning (e.g. reference frames) + if temporal_offset > 0: + indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] + indices = [i for i in indices if 0 <= i] + else: + indices = list(window.index_list) + + if not indices: + return None + + if temporal_scale > 1: + scaled = [] + for i in indices: + for k in range(temporal_scale): + si = i * temporal_scale + k + if si < cond_size: + scaled.append(si) + indices = scaled + if not indices: + return None + + idx = tuple([slice(None)] * temporal_dim + [indices]) + sliced = cond_tensor[idx].to(device) + return cond_value._copy_with(sliced) + + @dataclass class ContextSchedule: name: str @@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC): new_cond_item[cond_key] = result handled = True break + if not handled and self._model is not None: + result = self._model.resize_cond_for_context_window( + cond_key, cond_value, window, x_in, device, + retain_index_list=self.cond_retain_index_list) + if result is not None: + new_cond_item[cond_key] = result + handled = True if handled: continue if isinstance(cond_value, torch.Tensor): - if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) # Handle audio_embed (temporal dim is 1) @@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC): return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self._model = model self.set_step(timestep, model_options) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 356394239..7515f0d4e 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -23,6 +23,11 @@ class CausalConv3d(nn.Module): self.in_channels = in_channels self.out_channels = out_channels + if isinstance(stride, int): + self.time_stride = stride + else: + self.time_stride = stride[0] + kernel_size = (kernel_size, kernel_size, kernel_size) self.time_kernel_size = kernel_size[0] @@ -58,18 +63,23 @@ class CausalConv3d(nn.Module): pieces = [ cached, x ] if is_end and not causal: pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + input_length = sum([piece.shape[2] for piece in pieces]) + cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride) needs_caching = not is_end - if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + if needs_caching and cache_length == 0: + self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False) needs_caching = False - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + if needs_caching and x.shape[2] >= cache_length: + needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) x = torch.cat(pieces, dim=2) del pieces del cached if needs_caching: - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) elif is_end: self.temporal_cache_state[tid] = (None, True) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index f7aae26da..998122c85 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -233,10 +233,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]: sample = self.conv_in(sample) checkpoint_fn = ( @@ -247,10 +244,14 @@ class Encoder(nn.Module): for down_block in self.down_blocks: sample = checkpoint_fn(down_block)(sample) + if sample is None or sample.shape[2] == 0: + return None sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) + if sample is None or sample.shape[2] == 0: + return None if self.latent_log_var == "uniform": last_channel = sample[:, -1:, ...] @@ -282,9 +283,35 @@ class Encoder(nn.Module): return sample + def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder + frame_size = sample[:, :, :1, :, :].numel() * sample.element_size() + frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels)) + + outputs = [] + samples = [sample[:, :, :1, :, :]] + if sample.shape[2] > 1: + chunk_t = max(2, max_chunk_size // frame_size) + if chunk_t < 4: + chunk_t = 2 + elif chunk_t < 8: + chunk_t = 4 + else: + chunk_t = (chunk_t // 8) * 8 + samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2)) + for chunk_idx, chunk in enumerate(samples): + if chunk_idx == len(samples) - 1: + mark_conv3d_ended(self) + chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device) + output = self._forward_chunk(chunk) + if output is not None: + outputs.append(output) + + return torch_cat_if_needed(outputs, dim=2) + def forward(self, *args, **kwargs): - #No encoder support so just flag the end so it doesnt use the cache. - mark_conv3d_ended(self) try: return self.forward_orig(*args, **kwargs) finally: @@ -509,6 +536,53 @@ class Decoder(nn.Module): c, (ts, hs, ws), to = self._output_scale return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) + def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size): + sample = sample_ref[0] + sample_ref[0] = None + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + t = sample.shape[2] + output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + output_offset[0] += t + return + + up_block = self.up_blocks[idx] + if ended: + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size + + if num_chunks == 1: + # when we are not chunking, detach our x so the callee can free it as soon as they are done + next_sample_ref = [sample] + del sample + self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + return + else: + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, sample1 in enumerate(samples): + self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + def forward_orig( self, sample: torch.FloatTensor, @@ -528,6 +602,7 @@ class Decoder(nn.Module): ) timestep_shift_scale = None + scaled_timestep = None if self.timestep_conditioning: assert ( timestep is not None @@ -564,54 +639,7 @@ class Decoder(nn.Module): max_chunk_size = get_max_chunk_size(sample.device) - def run_up(idx, sample_ref, ended): - sample = sample_ref[0] - sample_ref[0] = None - if idx >= len(self.up_blocks): - sample = self.conv_norm_out(sample) - if timestep_shift_scale is not None: - shift, scale = timestep_shift_scale - sample = sample * (1 + scale) + shift - sample = self.conv_act(sample) - if ended: - mark_conv3d_ended(self.conv_out) - sample = self.conv_out(sample, causal=self.causal) - if sample is not None and sample.shape[2] > 0: - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - t = sample.shape[2] - output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) - output_offset[0] += t - return - - up_block = self.up_blocks[idx] - if (ended): - mark_conv3d_ended(up_block) - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - if sample is None or sample.shape[2] == 0: - return - - total_bytes = sample.numel() * sample.element_size() - num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size - - if num_chunks == 1: - # when we are not chunking, detach our x so the callee can free it as soon as they are done - next_sample_ref = [sample] - del sample - run_up(idx + 1, next_sample_ref, ended) - return - else: - samples = torch.chunk(sample, chunks=num_chunks, dim=2) - - for chunk_idx, sample1 in enumerate(samples): - run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1) - - run_up(0, [sample], True) + self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) return output_buffer @@ -737,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module): causal=True, spatial_padding_mode=spatial_padding_mode, ) + self.temporal_cache_state = {} def forward(self, x, causal: bool = True): - if self.stride[0] == 2: + tid = threading.get_ident() + cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None)) + if cached_input is not None: + x = torch_cat_if_needed([cached_input, x], dim=2) + cached_input = None + + if self.stride[0] == 2 and pad_first: x = torch.cat( [x[:, :, :1, :, :], x], dim=2 ) # duplicate first frames for padding + pad_first = False + + if x.shape[2] < self.stride[0]: + cached_input = x + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) + return None # skip connection x_in = rearrange( @@ -757,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module): # conv x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) + if self.stride[0] == 2 and x.shape[2] == 1: + if cached_x is not None: + x = torch_cat_if_needed([cached_x, x], dim=2) + cached_x = None + else: + cached_x = x + x = None - x = x + x_in + if x is not None: + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + cached = add_exchange_cache(x, cached, x_in, dim=2) + + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) return x @@ -1098,6 +1150,8 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): + comfy_has_chunked_io = True + def __init__(self, version=0, config=None): super().__init__() @@ -1240,11 +1294,9 @@ class VideoVAE(nn.Module): } return config - def encode(self, x): - frames_count = x.shape[2] - if ((frames_count - 1) % 8) != 0: - raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.") - means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + def encode(self, x, device=None): + x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :] + means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1) return self.per_channel_statistics.normalize(means) def decode_output_shape(self, input_shape): diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index a96b83c6c..deeb8695b 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -360,6 +360,43 @@ class Decoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, output_channels, 3, padding=1)) + def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + self.run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_cache, + feat_idx.copy(), + out_chunks, + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + next_x_ref = [x] + del x + self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) + def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: @@ -380,42 +417,7 @@ class Decoder3d(nn.Module): out_chunks = [] - def run_up(layer_idx, x_ref, feat_idx): - x = x_ref[0] - x_ref[0] = None - if layer_idx >= len(self.upsamples): - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - cache_x = x[:, :, -CACHE_T:, :, :] - x = layer(x, feat_cache[feat_idx[0]]) - feat_cache[feat_idx[0]] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - out_chunks.append(x) - return - - layer = self.upsamples[layer_idx] - if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: - for frame_idx in range(x.shape[2]): - run_up( - layer_idx, - [x[:, :, frame_idx:frame_idx + 1, :, :]], - feat_idx.copy(), - ) - del x - return - - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - next_x_ref = [x] - del x - run_up(layer_idx + 1, next_x_ref, feat_idx) - - run_up(0, [x], feat_idx) + self.run_up(0, [x], feat_cache, feat_idx, out_chunks) return out_chunks diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 563224098..f9078fe7c 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination): if (destination.device.type != "cpu" or file_obj is None or threading.get_ident() != info.thread_id - or destination.numel() * destination.element_size() < info.size): + or destination.numel() * destination.element_size() < info.size + or tensor.numel() * tensor.element_size() != info.size + or tensor.storage_offset() != 0 + or not tensor.is_contiguous()): return False if info.size == 0: diff --git a/comfy/model_base.py b/comfy/model_base.py index d9d5a9293..88905e191 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -285,6 +285,12 @@ class BaseModel(torch.nn.Module): return data return None + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Override in subclasses to handle model-specific cond slicing for context windows. + Return a sliced cond object, or None to fall through to default handling. + Use comfy.context_windows.slice_cond() for common cases.""" + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1375,6 +1381,12 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "vace_context": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1427,6 +1439,12 @@ class WAN21_HuMo(WAN21): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1444,6 +1462,14 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + import comfy.context_windows + if cond_key == "face_pixel_values": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) + if cond_key == "pose_latents": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1480,6 +1506,12 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5f2e6ef67..2c250dacc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1003,7 +1003,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: diff --git a/comfy/sd.py b/comfy/sd.py index 1f9510959..e207bb0fd 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -953,7 +953,7 @@ class VAE: # Pre-allocate output for VAEs that support direct buffer writes preallocated = False - if hasattr(self.first_stage_model, 'decode_output_shape'): + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) preallocated = True @@ -978,6 +978,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -1038,8 +1039,13 @@ class VAE: batch_number = max(1, batch_number) samples = None for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) - out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype()) + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out @@ -1054,6 +1060,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 diff --git a/comfy/utils.py b/comfy/utils.py index 13b7ca6c8..78c491b98 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) + out = output[b:b+1].zero_() + out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device) positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] @@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) - mask = torch.ones_like(ps) + mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) @@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if pbar is not None: pbar.update(1) - output[b:b+1] = out/out_div + out.div_(out_div) return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 6dbd5984e..de0c22e70 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -47,6 +47,10 @@ SEEDREAM_MODELS = { BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} +DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} + +logger = logging.getLogger(__name__) + def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: @@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode): price_badge=IO.PriceBadge( expr="""{"type":"usd","usd":0.03}""", ), + is_deprecated=True, ) @classmethod @@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ] return await process_video_task( cls, - payload=Image2VideoTaskCreationRequest(model=model, content=x), + payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -952,6 +957,12 @@ async def process_video_task( payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, estimated_duration: int | None, ) -> IO.NodeOutput: + if payload.model in DEPRECATED_MODELS: + logger.warning( + "Model '%s' is deprecated and will be deactivated on May 13, 2026. " + "Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.", + payload.model, + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 93a5204e1..0e43f2e44 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode): io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."),