mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
Merge branch 'Comfy-Org:master' into fix/jobs-preview-fallback-priority
This commit is contained in:
commit
5931332726
@ -136,16 +136,7 @@ class ResBlock(nn.Module):
|
|||||||
ops.Linear(c_hidden, c),
|
ops.Linear(c_hidden, c),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||||
|
|
||||||
# Init weights
|
|
||||||
def _basic_init(module):
|
|
||||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.constant_(module.bias, 0)
|
|
||||||
|
|
||||||
self.apply(_basic_init)
|
|
||||||
|
|
||||||
def _norm(self, x, norm):
|
def _norm(self, x, norm):
|
||||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
|||||||
@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
self.time_stride = stride
|
||||||
|
else:
|
||||||
|
self.time_stride = stride[0]
|
||||||
|
|
||||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||||
self.time_kernel_size = kernel_size[0]
|
self.time_kernel_size = kernel_size[0]
|
||||||
|
|
||||||
@ -58,18 +63,23 @@ class CausalConv3d(nn.Module):
|
|||||||
pieces = [ cached, x ]
|
pieces = [ cached, x ]
|
||||||
if is_end and not causal:
|
if is_end and not causal:
|
||||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||||
|
input_length = sum([piece.shape[2] for piece in pieces])
|
||||||
|
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||||
|
|
||||||
needs_caching = not is_end
|
needs_caching = not is_end
|
||||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
if needs_caching and cache_length == 0:
|
||||||
|
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||||
needs_caching = False
|
needs_caching = False
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
if needs_caching and x.shape[2] >= cache_length:
|
||||||
|
needs_caching = False
|
||||||
|
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||||
|
|
||||||
x = torch.cat(pieces, dim=2)
|
x = torch.cat(pieces, dim=2)
|
||||||
del pieces
|
del pieces
|
||||||
del cached
|
del cached
|
||||||
|
|
||||||
if needs_caching:
|
if needs_caching:
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||||
elif is_end:
|
elif is_end:
|
||||||
self.temporal_cache_state[tid] = (None, True)
|
self.temporal_cache_state[tid] = (None, True)
|
||||||
|
|
||||||
|
|||||||
@ -233,10 +233,7 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||||
r"""The forward method of the `Encoder` class."""
|
|
||||||
|
|
||||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
|
||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
checkpoint_fn = (
|
checkpoint_fn = (
|
||||||
@ -247,10 +244,14 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
for down_block in self.down_blocks:
|
for down_block in self.down_blocks:
|
||||||
sample = checkpoint_fn(down_block)(sample)
|
sample = checkpoint_fn(down_block)(sample)
|
||||||
|
if sample is None or sample.shape[2] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
sample = self.conv_act(sample)
|
sample = self.conv_act(sample)
|
||||||
sample = self.conv_out(sample)
|
sample = self.conv_out(sample)
|
||||||
|
if sample is None or sample.shape[2] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
if self.latent_log_var == "uniform":
|
if self.latent_log_var == "uniform":
|
||||||
last_channel = sample[:, -1:, ...]
|
last_channel = sample[:, -1:, ...]
|
||||||
@ -282,9 +283,35 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||||
|
r"""The forward method of the `Encoder` class."""
|
||||||
|
|
||||||
|
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||||
|
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||||
|
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
samples = [sample[:, :, :1, :, :]]
|
||||||
|
if sample.shape[2] > 1:
|
||||||
|
chunk_t = max(2, max_chunk_size // frame_size)
|
||||||
|
if chunk_t < 4:
|
||||||
|
chunk_t = 2
|
||||||
|
elif chunk_t < 8:
|
||||||
|
chunk_t = 4
|
||||||
|
else:
|
||||||
|
chunk_t = (chunk_t // 8) * 8
|
||||||
|
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||||
|
for chunk_idx, chunk in enumerate(samples):
|
||||||
|
if chunk_idx == len(samples) - 1:
|
||||||
|
mark_conv3d_ended(self)
|
||||||
|
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||||
|
output = self._forward_chunk(chunk)
|
||||||
|
if output is not None:
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
return torch_cat_if_needed(outputs, dim=2)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
#No encoder support so just flag the end so it doesnt use the cache.
|
|
||||||
mark_conv3d_ended(self)
|
|
||||||
try:
|
try:
|
||||||
return self.forward_orig(*args, **kwargs)
|
return self.forward_orig(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
@ -473,6 +500,17 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||||
|
ts, hs, ws, to = 1, 1, 1, 0
|
||||||
|
for block in self.up_blocks:
|
||||||
|
if isinstance(block, DepthToSpaceUpsample):
|
||||||
|
ts *= block.stride[0]
|
||||||
|
hs *= block.stride[1]
|
||||||
|
ws *= block.stride[2]
|
||||||
|
if block.stride[0] > 1:
|
||||||
|
to = to * block.stride[0] + 1
|
||||||
|
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||||
|
|
||||||
self.timestep_conditioning = timestep_conditioning
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
if timestep_conditioning:
|
if timestep_conditioning:
|
||||||
@ -494,11 +532,15 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
def decode_output_shape(self, input_shape):
|
||||||
|
c, (ts, hs, ws), to = self._output_scale
|
||||||
|
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
timestep: Optional[torch.Tensor] = None,
|
timestep: Optional[torch.Tensor] = None,
|
||||||
|
output_buffer: Optional[torch.Tensor] = None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
r"""The forward method of the `Decoder` class."""
|
r"""The forward method of the `Decoder` class."""
|
||||||
batch_size = sample.shape[0]
|
batch_size = sample.shape[0]
|
||||||
@ -540,7 +582,13 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
output = []
|
if output_buffer is None:
|
||||||
|
output_buffer = torch.empty(
|
||||||
|
self.decode_output_shape(sample.shape),
|
||||||
|
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
|
output_offset = [0]
|
||||||
|
|
||||||
max_chunk_size = get_max_chunk_size(sample.device)
|
max_chunk_size = get_max_chunk_size(sample.device)
|
||||||
|
|
||||||
def run_up(idx, sample_ref, ended):
|
def run_up(idx, sample_ref, ended):
|
||||||
@ -556,7 +604,10 @@ class Decoder(nn.Module):
|
|||||||
mark_conv3d_ended(self.conv_out)
|
mark_conv3d_ended(self.conv_out)
|
||||||
sample = self.conv_out(sample, causal=self.causal)
|
sample = self.conv_out(sample, causal=self.causal)
|
||||||
if sample is not None and sample.shape[2] > 0:
|
if sample is not None and sample.shape[2] > 0:
|
||||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
t = sample.shape[2]
|
||||||
|
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||||
|
output_offset[0] += t
|
||||||
return
|
return
|
||||||
|
|
||||||
up_block = self.up_blocks[idx]
|
up_block = self.up_blocks[idx]
|
||||||
@ -588,11 +639,8 @@ class Decoder(nn.Module):
|
|||||||
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||||
|
|
||||||
run_up(0, [sample], True)
|
run_up(0, [sample], True)
|
||||||
sample = torch.cat(output, dim=2)
|
|
||||||
|
|
||||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
return output_buffer
|
||||||
|
|
||||||
return sample
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
@ -716,12 +764,25 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
self.temporal_cache_state = {}
|
||||||
|
|
||||||
def forward(self, x, causal: bool = True):
|
def forward(self, x, causal: bool = True):
|
||||||
if self.stride[0] == 2:
|
tid = threading.get_ident()
|
||||||
|
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||||
|
if cached_input is not None:
|
||||||
|
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||||
|
cached_input = None
|
||||||
|
|
||||||
|
if self.stride[0] == 2 and pad_first:
|
||||||
x = torch.cat(
|
x = torch.cat(
|
||||||
[x[:, :, :1, :, :], x], dim=2
|
[x[:, :, :1, :, :], x], dim=2
|
||||||
) # duplicate first frames for padding
|
) # duplicate first frames for padding
|
||||||
|
pad_first = False
|
||||||
|
|
||||||
|
if x.shape[2] < self.stride[0]:
|
||||||
|
cached_input = x
|
||||||
|
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||||
|
return None
|
||||||
|
|
||||||
# skip connection
|
# skip connection
|
||||||
x_in = rearrange(
|
x_in = rearrange(
|
||||||
@ -736,15 +797,26 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
|
|
||||||
# conv
|
# conv
|
||||||
x = self.conv(x, causal=causal)
|
x = self.conv(x, causal=causal)
|
||||||
x = rearrange(
|
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||||
x,
|
if cached_x is not None:
|
||||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||||
p1=self.stride[0],
|
cached_x = None
|
||||||
p2=self.stride[1],
|
else:
|
||||||
p3=self.stride[2],
|
cached_x = x
|
||||||
)
|
x = None
|
||||||
|
|
||||||
x = x + x_in
|
if x is not None:
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||||
|
p1=self.stride[0],
|
||||||
|
p2=self.stride[1],
|
||||||
|
p3=self.stride[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||||
|
|
||||||
|
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -1077,6 +1149,8 @@ class processor(nn.Module):
|
|||||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
|
|
||||||
class VideoVAE(nn.Module):
|
class VideoVAE(nn.Module):
|
||||||
|
comfy_has_chunked_io = True
|
||||||
|
|
||||||
def __init__(self, version=0, config=None):
|
def __init__(self, version=0, config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -1219,14 +1293,15 @@ class VideoVAE(nn.Module):
|
|||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x, device=None):
|
||||||
frames_count = x.shape[2]
|
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||||
if ((frames_count - 1) % 8) != 0:
|
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
|
||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def decode(self, x):
|
def decode_output_shape(self, input_shape):
|
||||||
|
return self.decoder.decode_output_shape(input_shape)
|
||||||
|
|
||||||
|
def decode(self, x, output_buffer=None):
|
||||||
if self.timestep_conditioning: #TODO: seed
|
if self.timestep_conditioning: #TODO: seed
|
||||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||||
|
|||||||
@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
|
|||||||
if (destination.device.type != "cpu"
|
if (destination.device.type != "cpu"
|
||||||
or file_obj is None
|
or file_obj is None
|
||||||
or threading.get_ident() != info.thread_id
|
or threading.get_ident() != info.thread_id
|
||||||
or destination.numel() * destination.element_size() < info.size):
|
or destination.numel() * destination.element_size() < info.size
|
||||||
|
or tensor.numel() * tensor.element_size() != info.size
|
||||||
|
or tensor.storage_offset() != 0
|
||||||
|
or not tensor.is_contiguous()):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if info.size == 0:
|
if info.size == 0:
|
||||||
|
|||||||
@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
|
|||||||
def text_encoder_device():
|
def text_encoder_device():
|
||||||
if args.gpu_only:
|
if args.gpu_only:
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
|
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
|
||||||
if should_use_fp16(prioritize_performance=False):
|
if should_use_fp16(prioritize_performance=False):
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
|||||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
|
|
||||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return samples
|
return samples
|
||||||
|
|||||||
32
comfy/sd.py
32
comfy/sd.py
@ -455,7 +455,7 @@ class VAE:
|
|||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.pad_channel_value = None
|
self.pad_channel_value = None
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
self.disable_offload = False
|
self.disable_offload = False
|
||||||
self.not_video = False
|
self.not_video = False
|
||||||
@ -951,12 +951,23 @@ class VAE:
|
|||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
|
# Pre-allocate output for VAEs that support direct buffer writes
|
||||||
|
preallocated = False
|
||||||
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
|
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
preallocated = True
|
||||||
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
if preallocated:
|
||||||
if pixel_samples is None:
|
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
else:
|
||||||
pixel_samples[x:x+batch_number] = out
|
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||||
|
if pixel_samples is None:
|
||||||
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
pixel_samples[x:x+batch_number].copy_(out)
|
||||||
|
del out
|
||||||
|
self.process_output(pixel_samples[x:x+batch_number])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_management.raise_non_oom(e)
|
model_management.raise_non_oom(e)
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
@ -1027,8 +1038,13 @@ class VAE:
|
|||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = None
|
samples = None
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
|
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||||
|
else:
|
||||||
|
pixels_in = pixels_in.to(self.device)
|
||||||
|
out = self.first_stage_model.encode(pixels_in)
|
||||||
|
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||||
if samples is None:
|
if samples is None:
|
||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
samples[x:x + batch_number] = out
|
samples[x:x + batch_number] = out
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
|||||||
out, pooled = o[:2]
|
out, pooled = o[:2]
|
||||||
|
|
||||||
if pooled is not None:
|
if pooled is not None:
|
||||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||||
else:
|
else:
|
||||||
first_pooled = pooled
|
first_pooled = pooled
|
||||||
|
|
||||||
@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
|||||||
output.append(z)
|
output.append(z)
|
||||||
|
|
||||||
if (len(output) == 0):
|
if (len(output) == 0):
|
||||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||||
else:
|
else:
|
||||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||||
|
|
||||||
if len(o) > 2:
|
if len(o) > 2:
|
||||||
extra = {}
|
extra = {}
|
||||||
for k in o[2]:
|
for k in o[2]:
|
||||||
v = o[2][k]
|
v = o[2][k]
|
||||||
if k == "attention_mask":
|
if k == "attention_mask":
|
||||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||||
extra[k] = v
|
extra[k] = v
|
||||||
|
|
||||||
r = r + (extra,)
|
r = r + (extra,)
|
||||||
|
|||||||
@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
out = output[b:b+1].zero_()
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||||
|
|
||||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
upscaled.append(round(get_pos(d, pos)))
|
upscaled.append(round(get_pos(d, pos)))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||||
|
|
||||||
for d in range(2, dims + 2):
|
for d in range(2, dims + 2):
|
||||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||||
@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
output[b:b+1] = out/out_div
|
out.div_(out_div)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
|
|||||||
@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
|
|||||||
inlineData: GeminiInlineData | None = Field(None)
|
inlineData: GeminiInlineData | None = Field(None)
|
||||||
fileData: GeminiFileData | None = Field(None)
|
fileData: GeminiFileData | None = Field(None)
|
||||||
text: str | None = Field(None)
|
text: str | None = Field(None)
|
||||||
|
thought: bool | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiTextPart(BaseModel):
|
class GeminiTextPart(BaseModel):
|
||||||
|
|||||||
@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
|||||||
$m := widgets.model;
|
$m := widgets.model;
|
||||||
$r := widgets.resolution;
|
$r := widgets.resolution;
|
||||||
$isFlash := $contains($m, "nano banana 2");
|
$isFlash := $contains($m, "nano banana 2");
|
||||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||||
@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
|||||||
return "\n".join([part.text for part in parts])
|
return "\n".join([part.text for part in parts])
|
||||||
|
|
||||||
|
|
||||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||||
image_tensors: list[Input.Image] = []
|
image_tensors: list[Input.Image] = []
|
||||||
parts = get_parts_by_type(response, "image/*")
|
parts = get_parts_by_type(response, "image/*")
|
||||||
for part in parts:
|
for part in parts:
|
||||||
|
if (part.thought is True) != thought:
|
||||||
|
continue
|
||||||
if part.inlineData:
|
if part.inlineData:
|
||||||
image_data = base64.b64decode(part.inlineData.data)
|
image_data = base64.b64decode(part.inlineData.data)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
|||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
IO.String.Output(),
|
IO.String.Output(),
|
||||||
|
IO.Image.Output(
|
||||||
|
display_name="thought_image",
|
||||||
|
tooltip="First image from the model's thinking process. "
|
||||||
|
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
|||||||
response_model=GeminiGenerateContentResponse,
|
response_model=GeminiGenerateContentResponse,
|
||||||
price_extractor=calculate_tokens_price,
|
price_extractor=calculate_tokens_price,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
return IO.NodeOutput(
|
||||||
|
await get_image_from_response(response),
|
||||||
|
get_text_from_response(response),
|
||||||
|
await get_image_from_response(response, thought=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GeminiExtension(ComfyExtension):
|
class GeminiExtension(ComfyExtension):
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.41.20
|
comfyui-frontend-package==1.41.21
|
||||||
comfyui-workflow-templates==0.9.26
|
comfyui-workflow-templates==0.9.26
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user