diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 145e6e69a..e4e30cacd 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -136,16 +136,7 @@ class ResBlock(nn.Module): ops.Linear(c_hidden, c), ) - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False) def _norm(self, x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 356394239..7515f0d4e 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -23,6 +23,11 @@ class CausalConv3d(nn.Module): self.in_channels = in_channels self.out_channels = out_channels + if isinstance(stride, int): + self.time_stride = stride + else: + self.time_stride = stride[0] + kernel_size = (kernel_size, kernel_size, kernel_size) self.time_kernel_size = kernel_size[0] @@ -58,18 +63,23 @@ class CausalConv3d(nn.Module): pieces = [ cached, x ] if is_end and not causal: pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + input_length = sum([piece.shape[2] for piece in pieces]) + cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride) needs_caching = not is_end - if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + if needs_caching and cache_length == 0: + self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False) needs_caching = False - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + if needs_caching and x.shape[2] >= cache_length: + needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) x = torch.cat(pieces, dim=2) del pieces del cached if needs_caching: - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) elif is_end: self.temporal_cache_state[tid] = (None, True) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 0504140ef..1a15cafd0 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -233,10 +233,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]: sample = self.conv_in(sample) checkpoint_fn = ( @@ -247,10 +244,14 @@ class Encoder(nn.Module): for down_block in self.down_blocks: sample = checkpoint_fn(down_block)(sample) + if sample is None or sample.shape[2] == 0: + return None sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) + if sample is None or sample.shape[2] == 0: + return None if self.latent_log_var == "uniform": last_channel = sample[:, -1:, ...] @@ -282,9 +283,35 @@ class Encoder(nn.Module): return sample + def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder + frame_size = sample[:, :, :1, :, :].numel() * sample.element_size() + frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels)) + + outputs = [] + samples = [sample[:, :, :1, :, :]] + if sample.shape[2] > 1: + chunk_t = max(2, max_chunk_size // frame_size) + if chunk_t < 4: + chunk_t = 2 + elif chunk_t < 8: + chunk_t = 4 + else: + chunk_t = (chunk_t // 8) * 8 + samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2)) + for chunk_idx, chunk in enumerate(samples): + if chunk_idx == len(samples) - 1: + mark_conv3d_ended(self) + chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device) + output = self._forward_chunk(chunk) + if output is not None: + outputs.append(output) + + return torch_cat_if_needed(outputs, dim=2) + def forward(self, *args, **kwargs): - #No encoder support so just flag the end so it doesnt use the cache. - mark_conv3d_ended(self) try: return self.forward_orig(*args, **kwargs) finally: @@ -473,6 +500,17 @@ class Decoder(nn.Module): self.gradient_checkpointing = False + # Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset) + ts, hs, ws, to = 1, 1, 1, 0 + for block in self.up_blocks: + if isinstance(block, DepthToSpaceUpsample): + ts *= block.stride[0] + hs *= block.stride[1] + ws *= block.stride[2] + if block.stride[0] > 1: + to = to * block.stride[0] + 1 + self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to) + self.timestep_conditioning = timestep_conditioning if timestep_conditioning: @@ -494,11 +532,15 @@ class Decoder(nn.Module): ) - # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + def decode_output_shape(self, input_shape): + c, (ts, hs, ws), to = self._output_scale + return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) + def forward_orig( self, sample: torch.FloatTensor, timestep: Optional[torch.Tensor] = None, + output_buffer: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] @@ -540,7 +582,13 @@ class Decoder(nn.Module): ) timestep_shift_scale = ada_values.unbind(dim=1) - output = [] + if output_buffer is None: + output_buffer = torch.empty( + self.decode_output_shape(sample.shape), + dtype=sample.dtype, device=comfy.model_management.intermediate_device(), + ) + output_offset = [0] + max_chunk_size = get_max_chunk_size(sample.device) def run_up(idx, sample_ref, ended): @@ -556,7 +604,10 @@ class Decoder(nn.Module): mark_conv3d_ended(self.conv_out) sample = self.conv_out(sample, causal=self.causal) if sample is not None and sample.shape[2] > 0: - output.append(sample.to(comfy.model_management.intermediate_device())) + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + t = sample.shape[2] + output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + output_offset[0] += t return up_block = self.up_blocks[idx] @@ -588,11 +639,8 @@ class Decoder(nn.Module): run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1) run_up(0, [sample], True) - sample = torch.cat(output, dim=2) - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - - return sample + return output_buffer def forward(self, *args, **kwargs): try: @@ -716,12 +764,25 @@ class SpaceToDepthDownsample(nn.Module): causal=True, spatial_padding_mode=spatial_padding_mode, ) + self.temporal_cache_state = {} def forward(self, x, causal: bool = True): - if self.stride[0] == 2: + tid = threading.get_ident() + cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None)) + if cached_input is not None: + x = torch_cat_if_needed([cached_input, x], dim=2) + cached_input = None + + if self.stride[0] == 2 and pad_first: x = torch.cat( [x[:, :, :1, :, :], x], dim=2 ) # duplicate first frames for padding + pad_first = False + + if x.shape[2] < self.stride[0]: + cached_input = x + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) + return None # skip connection x_in = rearrange( @@ -736,15 +797,26 @@ class SpaceToDepthDownsample(nn.Module): # conv x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) + if self.stride[0] == 2 and x.shape[2] == 1: + if cached_x is not None: + x = torch_cat_if_needed([cached_x, x], dim=2) + cached_x = None + else: + cached_x = x + x = None - x = x + x_in + if x is not None: + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + cached = add_exchange_cache(x, cached, x_in, dim=2) + + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) return x @@ -1077,6 +1149,8 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): + comfy_has_chunked_io = True + def __init__(self, version=0, config=None): super().__init__() @@ -1219,14 +1293,15 @@ class VideoVAE(nn.Module): } return config - def encode(self, x): - frames_count = x.shape[2] - if ((frames_count - 1) % 8) != 0: - raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.") - means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + def encode(self, x, device=None): + x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :] + means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x): + def decode_output_shape(self, input_shape): + return self.decoder.decode_output_shape(input_shape) + + def decode(self, x, output_buffer=None): if self.timestep_conditioning: #TODO: seed x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 563224098..f9078fe7c 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination): if (destination.device.type != "cpu" or file_obj is None or threading.get_ident() != info.thread_id - or destination.numel() * destination.element_size() < info.size): + or destination.numel() * destination.element_size() < info.size + or tensor.numel() * tensor.element_size() != info.size + or tensor.storage_offset() != 0 + or not tensor.is_contiguous()): return False if info.size == 0: diff --git a/comfy/model_management.py b/comfy/model_management.py index 2c250dacc..5f2e6ef67 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1003,7 +1003,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: diff --git a/comfy/sample.py b/comfy/sample.py index a2a39b527..e9c2259ab 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples diff --git a/comfy/sd.py b/comfy/sd.py index 4d427bb9a..b5e7c93a9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -455,7 +455,7 @@ class VAE: self.output_channels = 3 self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 - self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False @@ -951,12 +951,23 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True + for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype())) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - pixel_samples[x:x+batch_number] = out + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) except Exception as e: model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") @@ -1027,8 +1038,13 @@ class VAE: batch_number = max(1, batch_number) samples = None for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) - out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype()) + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d89550840..0eb30df27 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -46,7 +46,7 @@ class ClipTokenWeightEncoder: out, pooled = o[:2] if pooled is not None: - first_pooled = pooled[0:1].to(model_management.intermediate_device()) + first_pooled = pooled[0:1].to(device=model_management.intermediate_device()) else: first_pooled = pooled @@ -63,16 +63,16 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - r = (out[-1:].to(model_management.intermediate_device()), first_pooled) + r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled) else: - r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) + r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled) if len(o) > 2: extra = {} for k in o[2]: v = o[2][k] if k == "attention_mask": - v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) + v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device()) extra[k] = v r = r + (extra,) diff --git a/comfy/utils.py b/comfy/utils.py index 13b7ca6c8..78c491b98 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) + out = output[b:b+1].zero_() + out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device) positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] @@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) - mask = torch.ones_like(ps) + mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) @@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if pbar is not None: pbar.update(1) - output[b:b+1] = out/out_div + out.div_(out_div) return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 639035fef..22879fe18 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -67,6 +67,7 @@ class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) + thought: bool | None = Field(None) class GeminiTextPart(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 8225ea67e..25d747e76 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( $m := widgets.model; $r := widgets.resolution; $isFlash := $contains($m, "nano banana 2"); - $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}; $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; $prices := $isFlash ? $flashPrices : $proPrices; {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} @@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/*") for part in parts: + if (part.thought is True) != thought: + continue if part.inlineData: image_data = base64.b64decode(part.inlineData.data) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) @@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode): outputs=[ IO.Image.Output(), IO.String.Output(), + IO.Image.Output( + display_name="thought_image", + tooltip="First image from the model's thinking process. " + "Only available with thinking_level HIGH and IMAGE+TEXT modality.", + ), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput( + await get_image_from_response(response), + get_text_from_response(response), + await get_image_from_response(response, thought=True), + ) class GeminiExtension(ComfyExtension): diff --git a/requirements.txt b/requirements.txt index 0ce163f71..ad0344ed4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.20 +comfyui-frontend-package==1.41.21 comfyui-workflow-templates==0.9.26 comfyui-embedded-docs==0.4.3 torch