Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
Deep Mehta 2026-03-18 17:14:08 -07:00 committed by GitHub
commit b20cb7892e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 70 additions and 23 deletions

View File

@ -473,6 +473,17 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
self.timestep_conditioning = timestep_conditioning self.timestep_conditioning = timestep_conditioning
if timestep_conditioning: if timestep_conditioning:
@ -494,11 +505,15 @@ class Decoder(nn.Module):
) )
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def forward_orig( def forward_orig(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class.""" r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0] batch_size = sample.shape[0]
@ -540,7 +555,13 @@ class Decoder(nn.Module):
) )
timestep_shift_scale = ada_values.unbind(dim=1) timestep_shift_scale = ada_values.unbind(dim=1)
output = [] if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
max_chunk_size = get_max_chunk_size(sample.device) max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample_ref, ended): def run_up(idx, sample_ref, ended):
@ -556,7 +577,10 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out) mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal) sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0: if sample is not None and sample.shape[2] > 0:
output.append(sample.to(comfy.model_management.intermediate_device())) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return return
up_block = self.up_blocks[idx] up_block = self.up_blocks[idx]
@ -588,11 +612,8 @@ class Decoder(nn.Module):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1) run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True) run_up(0, [sample], True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) return output_buffer
return sample
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
try: try:
@ -1226,7 +1247,10 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1) means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means) return self.per_channel_statistics.normalize(means)
def decode(self, x): def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)

View File

@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples return samples

View File

@ -951,12 +951,23 @@ class VAE:
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if hasattr(self.first_stage_model, 'decode_output_shape'):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) if preallocated:
if pixel_samples is None: self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) else:
pixel_samples[x:x+batch_number] = out out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e: except Exception as e:
model_management.raise_non_oom(e) model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")

View File

@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2] out, pooled = o[:2]
if pooled is not None: if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device()) first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
else: else:
first_pooled = pooled first_pooled = pooled
@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
r = (out[-1:].to(model_management.intermediate_device()), first_pooled) r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
else: else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
if len(o) > 2: if len(o) > 2:
extra = {} extra = {}
for k in o[2]: for k in o[2]:
v = o[2][k] v = o[2][k]
if k == "attention_mask": if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
extra[k] = v extra[k] = v
r = r + (extra,) r = r + (extra,)

View File

@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None) inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None) fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None) text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel): class GeminiTextPart(BaseModel):

View File

@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model; $m := widgets.model;
$r := widgets.resolution; $r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2"); $isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices; $prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts]) return "\n".join([part.text for part in parts])
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
image_tensors: list[Input.Image] = [] image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*") parts = get_parts_by_type(response, "image/*")
for part in parts: for part in parts:
if (part.thought is True) != thought:
continue
if part.inlineData: if part.inlineData:
image_data = base64.b64decode(part.inlineData.data) image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data)) returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
IO.String.Output(), IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
class GeminiExtension(ComfyExtension): class GeminiExtension(ComfyExtension):

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.41.20 comfyui-frontend-package==1.41.21
comfyui-workflow-templates==0.9.26 comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch