diff --git a/README.md b/README.md index ed857df9f..bae955b1b 100644 --- a/README.md +++ b/README.md @@ -320,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step 1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) 2. Launch ComfyUI by running `python main.py` + +## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4) + +**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI. + +### Setup + +1. Install the manager dependencies: + ```bash + pip install -r manager_requirements.txt + ``` + +2. Enable the manager with the `--enable-manager` flag when running ComfyUI: + ```bash + python main.py --enable-manager + ``` + +### Command Line Options + +| Flag | Description | +|------|-------------| +| `--enable-manager` | Enable ComfyUI-Manager | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | + + # Running ```python main.py``` diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 613b0f7c7..b224306da 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -58,8 +58,13 @@ class InternalRoutes: return web.json_response({"error": "Invalid directory type"}, status=400) directory = get_directory_by_type(directory_type) + + def is_visible_file(entry: os.DirEntry) -> bool: + """Filter out hidden files (e.g., .DS_Store on macOS).""" + return entry.is_file() and not entry.name.startswith('.') + sorted_files = sorted( - (entry for entry in os.scandir(directory) if entry.is_file()), + (entry for entry in os.scandir(directory) if is_visible_file(entry)), key=lambda entry: -entry.stat().st_mtime ) return web.json_response([entry.name for entry in sorted_files], status=200) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 041f380f9..5c412d1c2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -51,26 +51,36 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): self.index_list = index_list self.context_length = len(index_list) self.dim = dim + self.total_frames = total_frames + self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) - def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: dim = self.dim if dim == 0 and full.shape[dim] == 1: return full - idx = [slice(None)] * dim + [self.index_list] - return full[idx].to(device) + idx = tuple([slice(None)] * dim + [self.index_list]) + window = full[idx] + if retain_index_list: + idx = tuple([slice(None)] * dim + [retain_index_list]) + window[idx] = full[idx] + return window.to(device) def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: if dim is None: dim = self.dim - idx = [slice(None)] * dim + [self.index_list] + idx = tuple([slice(None)] * dim + [self.index_list]) full[idx] += to_add return full + def get_region_index(self, num_regions: int) -> int: + region_idx = int(self.center_ratio * num_regions) + return min(max(region_idx, 0), num_regions - 1) + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -94,7 +104,8 @@ class ContextFuseMethod: ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) class IndexListContextHandler(ContextHandlerABC): - def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC): self.closed_loop = closed_loop self.dim = dim self._step = 0 + self.freenoise = freenoise + self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] + self.split_conds_to_windows = split_conds_to_windows self.callbacks = {} def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: # for now, assume first dim is batch - should have stored on BaseModel in actual implementation if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.") + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + if self.cond_retain_index_list: + logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True return False @@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC): return None # reuse or resize cond items to match context requirements resized_cond = [] + # if multiple conds, split based on primary region + if self.split_conds_to_windows and len(cond_in) > 1: + region = window.get_region_index(len(cond_in)) + logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: resized_actual_cond = actual_cond.copy() @@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC): # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): if isinstance(cond_value, torch.Tensor): - if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim): + if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # Handle audio_embed (temporal dim is 1) + elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + audio_cond = cond_value.cond + if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): - new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ + (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list)) elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = window.context_length @@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: raise Exception("No sample_sigmas matched current timestep; something went wrong.") @@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): @@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC): prev_weight = (bias_total / (bias_total + bias)) new_weight = (bias / (bias_total + bias)) # account for dims of tensors - idx_window = [slice(None)] * self.dim + [idx] - pos_window = [slice(None)] * self.dim + [pos] + idx_window = tuple([slice(None)] * self.dim + [idx]) + pos_window = tuple([slice(None)] * self.dim + [pos]) # apply new values conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight biases_final[i][idx] = bias_total + bias @@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher): ) +def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs): + model_options = extra_args.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is None: + raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") + if not handler.freenoise: + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + + +def create_sampler_sample_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, + "ContextWindows_sampler_sample", + _sampler_sample_wrapper + ) + + def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int): for i in range(len(window)): # 2) add end_delta to each val to slide windows to end window[i] = window[i] + end_delta + + +# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 +def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int): + logging.info("Context windows: Applying FreeNoise") + generator = torch.Generator(device='cpu').manual_seed(seed) + latent_video_length = noise.shape[dim] + delta = context_length - context_overlap + + for start_idx in range(0, latent_video_length - context_length, delta): + place_idx = start_idx + context_length + + actual_delta = min(delta, latent_video_length - place_idx) + if actual_delta <= 0: + break + + list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx + + source_slice = [slice(None)] * noise.ndim + source_slice[dim] = list_idx + target_slice = [slice(None)] * noise.ndim + target_slice[dim] = slice(place_idx, place_idx + actual_delta) + + noise[tuple(target_slice)] = noise[tuple(source_slice)] + + return noise diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index e643b4414..70d173889 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams): nerf_final_head_type: str # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] - + use_x0: bool class ChromaRadiance(Chroma): """ @@ -159,6 +159,9 @@ class ChromaRadiance(Chroma): self.skip_dit = [] self.lite = False + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -276,6 +279,12 @@ class ChromaRadiance(Chroma): params_dict |= overrides return params.__class__(**params_dict) + def _apply_x0_residual(self, predicted, noisy, timesteps): + + # non zero during training to prevent 0 div + eps = 0.0 + return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps) + def _forward( self, x: Tensor, @@ -316,4 +325,11 @@ class ChromaRadiance(Chroma): transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + out = self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + # If x0 variant → v-pred, just return this instead + if hasattr(self, "__x0__"): + out = self._apply_x0_residual(out, img, timestep) + return out + diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 2749c53f5..55ab550f8 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -43,6 +43,7 @@ class HunyuanVideoParams: meanflow: bool use_cond_type_embedding: bool vision_in_dim: int + meanflow_sum: bool class SelfAttentionRef(nn.Module): @@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module): timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype) vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype)) - vec = (vec + vec_r) / 2 + vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2 if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py new file mode 100644 index 000000000..1509de2f8 --- /dev/null +++ b/comfy/ldm/kandinsky5/model.py @@ -0,0 +1,413 @@ +import torch +from torch import nn +import math + +import comfy.ldm.common_dit +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.layers import EmbedND + +def attention(q, k, v, heads, transformer_options={}): + return optimized_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + heads=heads, + skip_reshape=True, + transformer_options=transformer_options + ) + +def apply_scale_shift_norm(norm, x, scale, shift): + return torch.addcmul(shift, norm(x), scale + 1.0) + +def apply_gate_sum(x, out, gate): + return torch.addcmul(x, gate, out) + +def get_shift_scale_gate(params): + shift, scale, gate = torch.chunk(params, 3, dim=-1) + return tuple(x.unsqueeze(1) for x in (shift, scale, gate)) + +def get_freqs(dim, max_period=10000.0): + return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.SiLU() + self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, timestep, dtype): + args = torch.outer(timestep, self.freqs.to(device=timestep.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + x = x.movedim(1, -1) # B C T H W -> B T H W C + B, T, H, W, dim = x.shape + pt, ph, pw = self.patch_size + + x = x.view( + B, + T // pt, pt, + H // ph, ph, + W // pw, pw, + dim, + ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7) + + return self.in_layer(x) + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params, operation_settings=None): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class SelfAttention(nn.Module): + def __init__(self, num_channels, head_dim, operation_settings=None): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + self.head_dim = head_dim + + operations = operation_settings.get("operations") + self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 2 + + def _compute_qk(self, x, freqs, proj_fn, norm_fn): + result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) + return apply_rope1(norm_fn(result), freqs) + + def _forward(self, x, freqs, transformer_options={}): + q = self._compute_qk(x, freqs, self.to_query, self.query_norm) + k = self._compute_qk(x, freqs, self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def _forward_chunked(self, x, freqs, transformer_options={}): + def process_chunks(proj_fn, norm_fn): + x_chunks = torch.chunk(x, self.num_chunks, dim=1) + freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) + chunks = [] + for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks): + chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn)) + return torch.cat(chunks, dim=1) + + q = process_chunks(self.to_query, self.query_norm) + k = process_chunks(self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def forward(self, x, freqs, transformer_options={}): + if x.shape[1] > 8192: + return self._forward_chunked(x, freqs, transformer_options=transformer_options) + else: + return self._forward(x, freqs, transformer_options=transformer_options) + + +class CrossAttention(SelfAttention): + def get_qkv(self, x, context): + q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1) + k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1) + v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1) + return q, k, v + + def forward(self, x, context, transformer_options={}): + q, k, v = self.get_qkv(x, context) + out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.GELU() + self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 4 + + def _forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + def _forward_chunked(self, x): + chunks = torch.chunk(x, self.num_chunks, dim=1) + output_chunks = [] + for chunk in chunks: + output_chunks.append(self._forward(chunk)) + return torch.cat(output_chunks, dim=1) + + def forward(self, x): + if x.shape[1] > 8192: + return self._forward_chunked(x) + else: + return self._forward(x) + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings) + operations = operation_settings.get("operations") + self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, visual_embed, time_embed): + B, T, H, W, _ = visual_embed.shape + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + scale = scale[:, None, None, None, :] + shift = shift[:, None, None, None, :] + visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift) + x = self.out_layer(visual_embed) + + out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2]) + x = x.view( + B, T, H, W, + out_dim, + self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings) + operations = operation_settings.get("operations") + + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, x, time_embed, freqs, transformer_options={}): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = get_shift_scale_gate(self_attn_params) + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = self.self_attention(out, freqs, transformer_options=transformer_options) + x = apply_gate_sum(x, out, gate) + + shift, scale, gate = get_shift_scale_gate(ff_params) + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = self.feed_forward(out) + x = apply_gate_sum(x, out, gate) + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings) + + operations = operation_settings.get("operations") + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) + # self attention + shift, scale, gate = get_shift_scale_gate(self_attn_params) + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # cross attention + shift, scale, gate = get_shift_scale_gate(cross_attn_params) + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # feed forward + shift, scale, gate = get_shift_scale_gate(ff_params) + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = self.feed_forward(visual_out) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + return visual_embed + + +class Kandinsky5(nn.Module): + def __init__( + self, + in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512, + model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32, + axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0), + dtype=None, device=None, operations=None, **kwargs + ): + super().__init__() + head_dim = sum(axes_dims) + self.rope_scale_factor = rope_scale_factor + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_embed_dim = visual_embed_dim + self.dtype = dtype + self.device = device + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings) + + self.text_transformer_blocks = nn.ModuleList( + [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)] + ) + + self.visual_transformer_blocks = nn.ModuleList( + [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)] + ) + + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings) + + self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim]) + + def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}): + steps = seq_len if steps is None else steps + seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype) + seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1) + freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2) + return freqs + + def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + else: + rope_scale_factor = self.rope_scale_factor + if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions + if h * w >= 14080: + rope_scale_factor = (1.0, 3.16, 3.16) + + t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0 + h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0 + w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0 + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder_3d(img_ids).movedim(1, 2) + return freqs + + def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + context = self.text_embeddings(context) + time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y) + + for block in self.text_transformer_blocks: + context = block(context, time_embed, freqs_text, transformer_options=transformer_options) + + visual_embed = self.visual_embeddings(x) + visual_shape = visual_embed.shape[:-1] + visual_embed = visual_embed.flatten(1, -2) + + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.visual_transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.visual_transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) + visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + else: + visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + + visual_embed = visual_embed.reshape(*visual_shape, -1) + return self.out_layer(visual_embed, time_embed) + + def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + original_dims = x.ndim + if original_dims == 4: + x = x.unsqueeze(2) + bs, c, t_len, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if time_dim_replace is not None: + time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size) + x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace + + freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + if original_dims == 4: + out = out.squeeze(2) + return out + + def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f1c1a0ec3..c47df49ca 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -377,6 +377,7 @@ class NextDiT(nn.Module): z_image_modulation=False, time_scale=1.0, pad_tokens_multiple=None, + clip_text_dim=None, image_model=None, device=None, dtype=None, @@ -447,6 +448,31 @@ class NextDiT(nn.Module): ), ) + self.clip_text_pooled_proj = None + + if clip_text_dim is not None: + self.clip_text_dim = clip_text_dim + self.clip_text_pooled_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + clip_text_dim, + clip_text_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.time_text_embed = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024) + clip_text_dim, + min(dim, 1024), + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.layers = nn.ModuleList( [ JointTransformerBlock( @@ -585,8 +611,16 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + patches = transformer_options.get("patches", {}) - transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) diff --git a/comfy/lora.py b/comfy/lora.py index 3a9077869..2ed0acb9d 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -320,8 +320,16 @@ def model_lora_keys_unet(model, key_map={}): to = diffusers_keys[k] key_lora = k[:-len(".weight")] key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + if isinstance(model, comfy.model_base.Kandinsky5): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + key_map["transformer.{}".format(key_lora)] = k + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285e..6b8a8454d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model +import comfy.ldm.kandinsky5.model import comfy.model_management import comfy.patcher_extension @@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -329,18 +330,6 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) unet_state_dict = self.diffusion_model.state_dict() - - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - - # Save mixed precision metadata - if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: - metadata = { - "format_version": "1.0", - "layers": self.model_config.layer_quant_config - } - unet_state_dict["_quantization_metadata"] = metadata - unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: @@ -1121,6 +1110,10 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + clip_text_pooled = kwargs["pooled_output"] # Newbie + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + return out class WAN21(BaseModel): @@ -1642,3 +1635,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15): out = super().extra_conds(**kwargs) out['disable_time_r'] = comfy.conds.CONDConstant(False) return out + +class Kandinsky5(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + device = kwargs["device"] + image = torch.zeros_like(noise) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + time_dim_replace = kwargs.get("time_dim_replace", None) + if time_dim_replace is not None: + out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace)) + + return out + +class Kandinsky5Image(Kandinsky5): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + return None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0517e61..dd6a703f6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,20 +6,6 @@ import math import logging import torch - -def detect_layer_quantization(metadata): - quant_key = "_quantization_metadata" - if metadata is not None and quant_key in metadata: - quant_metadata = metadata.pop(quant_key) - quant_metadata = json.loads(quant_metadata) - if isinstance(quant_metadata, dict) and "layers" in quant_metadata: - logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") - return quant_metadata["layers"] - else: - raise ValueError("Invalid quantization metadata format") - return None - - def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -194,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["use_cond_type_embedding"] = False if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + dit_config["meanflow_sum"] = True else: dit_config["vision_in_dim"] = None + dit_config["meanflow_sum"] = False return dit_config if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) @@ -271,6 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 + if "__x0__" in state_dict_keys: # x0 pred + dit_config["use_x0"] = True + else: + dit_config["use_x0"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys @@ -437,6 +429,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [300, 512, 512] dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 + ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) + if ctd_weight is not None: + dit_config["clip_text_dim"] = ctd_weight.shape[0] elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 @@ -625,6 +620,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 + dit_config = {} + model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["model_dim"] = model_dim + if model_dim in [4096, 2560]: # pro video and lite image + dit_config["axes_dims"] = (32, 48, 48) + if model_dim == 2560: # lite image + dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) + elif model_dim == 1792: # lite video + dit_config["axes_dims"] = (16, 24, 24) + dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["image_model"] = "kandinsky5" + dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] + dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') + dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -767,22 +780,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True - # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(metadata) - if layer_quant_config: - model_config.layer_quant_config = layer_quant_config - logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix) + if quant_config: + model_config.quant_config = quant_config + logging.info("Detected mixed precision quantization") return model_config diff --git a/comfy/model_management.py b/comfy/model_management.py index aeddbaefe..40717b1e4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1492,6 +1492,20 @@ def extended_fp16_support(): return True +LORA_COMPUTE_DTYPES = {} +def lora_compute_dtype(device): + dtype = LORA_COMPUTE_DTYPES.get(device, None) + if dtype is not None: + return dtype + + if should_use_fp16(device): + dtype = torch.float16 + else: + dtype = torch.float32 + + LORA_COMPUTE_DTYPES[device] = dtype + return dtype + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3dcac3eef..a486c2723 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -126,36 +127,23 @@ class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - self.convert_func = convert_func + self.convert_func = convert_func # TODO: remove self.set_func = set_func def __call__(self, weight): - intermediate_dtype = weight.dtype - if self.convert_func is not None: - weight = self.convert_func(weight, inplace=False) + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is None: - return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) - else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - - out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) - else: - return out - -#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 -LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 def low_vram_patch_estimate_vram(model, key): weight, set_func, convert_func = get_key_weight(model, key) if weight is None: return 0 - return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + model_dtype = getattr(model, "manual_cast_dtype", torch.float32) + if model_dtype is None: + model_dtype = weight.dtype + + return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR def get_key_weight(model, key): set_func = None @@ -630,10 +618,11 @@ class ModelPatcher: if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: - temp_weight = weight.to(torch.float32, copy=True) + temp_weight = weight.to(temp_dtype, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) @@ -677,12 +666,18 @@ class ModelPatcher: module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if weight_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) - if bias_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + def check_module_offload_mem(key): + if key in self.patches: + return low_vram_patch_estimate_vram(self.model, key) + model_dtype = getattr(self.model, "manual_cast_dtype", None) + weight, _, _ = get_key_weight(self.model, key) + if model_dtype is None or weight is None: + return 0 + if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)): + return weight.numel() * model_dtype.itemsize + return 0 + module_offload_mem += check_module_offload_mem("{}.weight".format(n)) + module_offload_mem += check_module_offload_mem("{}.bias".format(n)) loading.append((module_offload_mem, module_mem, n, m, params)) return loading @@ -777,6 +772,8 @@ class ModelPatcher: key = "{}.{}".format(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) + if comfy.model_management.is_device_cuda(device_to): + torch.cuda.synchronize() logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True @@ -933,7 +930,7 @@ class ModelPatcher: patch_counter += 1 cast_weight = True - if cast_weight: + if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False diff --git a/comfy/ops.py b/comfy/ops.py index eae434e68..6f34d50fc 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,7 +22,7 @@ import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm -import contextlib +import json def run_every_op(): if torch.compiler.is_compiling(): @@ -93,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None - if offload_stream is not None: - wf_context = offload_stream - if hasattr(wf_context, "as_context"): - wf_context = wf_context.as_context(offload_stream) - else: - wf_context = contextlib.nullcontext() - non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 @@ -422,22 +415,12 @@ def fp8_linear(self, input): if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) - - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - else: - scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! @@ -458,7 +441,7 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - if not self.training: + if len(self.weight_function) == 0 and len(self.bias_function) == 0: try: out = fp8_linear(self, input) if out is not None: @@ -471,59 +454,6 @@ class fp8_ops(manual_cast): uncast_bias_weight(self, weight, bias, offload_stream) return x -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - class scaled_fp8_op(manual_cast): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: - out = fp8_linear(self, input) - if out is not None: - return out - - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - - if weight.numel() < input.numel(): #TODO: optimize - x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): - weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if return_weight: - return weight - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op - CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear @@ -550,9 +480,9 @@ if CUBLAS_IS_AVAILABLE: from .quant_ops import QuantizedTensor, QUANT_ALGOS -def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): class MixedPrecisionOps(manual_cast): - _layer_quant_config = layer_quant_config + _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm @@ -595,27 +525,38 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful manually_loaded_keys = [weight_key] - if layer_name not in MixedPrecisionOps._layer_quant_config: + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: + self.quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + + if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - qconfig = QUANT_ALGOS[quant_format] + qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) + if scale is not None: + scale = scale.to(device) layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), + 'scale': scale, 'orig_dtype': MixedPrecisionOps._compute_dtype, 'block_size': qconfig.get("group_size", None), } - if layout_params['scale'] is not None: + + if scale is not None: manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), requires_grad=False ) @@ -624,7 +565,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -633,6 +574,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if key in missing_keys: missing_keys.remove(key) + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + quant_conf = {"format": self.quant_format} + if self._full_precision_mm: + quant_conf["full_precision_matrix_mult"] = True + sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) + return sd + def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -648,9 +599,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) def convert_weight(self, weight, inplace=False, **kwargs): @@ -661,7 +611,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -670,17 +620,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful assert inplace_update is False # TODO: eventually remove the inplace_update stuff self.weight = torch.nn.Parameter(weight, requires_grad=False) + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + return MixedPrecisionOps -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular - if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) + if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: + logging.info("Using mixed precision operations") + return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) if ( fp8_compute and diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index bb1fb860c..cd96541d7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor): def is_contiguous(self, *arg, **kwargs): return self._qdata.is_contiguous(*arg, **kwargs) + def storage(self): + return self._qdata.storage() + # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) # ============================================================================== @@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn): def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_dtype is not None and target_dtype != qt.dtype: - logging.warning( - f"QuantizedTensor: dtype conversion requested to {target_dtype}, " - f"but not supported for quantized tensors. Ignoring dtype." - ) - if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " @@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) new_params = _move_layout_params_to_device(qt._layout_params, target_device) + if target_dtype is not None: + new_params["orig_dtype"] = target_dtype new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt @@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs): # Copy from another quantized tensor qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) qt_dest._layout_type = src._layout_type + orig_dtype = qt_dest._layout_params["orig_dtype"] _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) + qt_dest._layout_params["orig_dtype"] = orig_dtype else: # Copy from regular tensor - just copy raw data qt_dest._qdata.copy_(src) @@ -397,17 +398,23 @@ class TensorCoreFP8Layout(QuantizedLayout): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype - if scale is None: - scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max + if isinstance(scale, str) and scale == "recalculate": + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small + tensor_info = torch.finfo(tensor.dtype) + scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) - if inplace_ops: - tensor *= (1.0 / scale).to(tensor.dtype) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) else: - tensor = tensor * (1.0 / scale).to(tensor.dtype) + scale = torch.ones((), device=tensor.device, dtype=torch.float32) if stochastic_rounding > 0: tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) diff --git a/comfy/sd.py b/comfy/sd.py index 03bdb33d5..1cad98aef 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis +import comfy.text_encoders.kandinsky5 import comfy.model_patcher import comfy.lora @@ -98,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: return params = target.params.copy() @@ -126,9 +127,32 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + #Match torch.float32 hardcode upcast in TE implemention + self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None + if len(state_dict) > 0: + if isinstance(state_dict, list): + for c in state_dict: + m, u = self.load_sd(c) + if len(m) > 0: + logging.warning("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected: {}".format(u)) + else: + m, u = self.load_sd(state_dict, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None @@ -525,8 +549,10 @@ class VAE: ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] - self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype) + + # Hunyuan 3d v2 2.0 & 2.1 elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: @@ -745,6 +771,8 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = None do_tile = False + if self.latent_dim == 2 and samples_in.ndim == 5: + samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -962,16 +990,16 @@ class CLIPType(Enum): HUNYUAN_IMAGE = 19 HUNYUAN_VIDEO_15 = 20 OVIS = 21 + KANDINSKY5 = 22 + KANDINSKY5_IMAGE = 23 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) - if metadata is not None: - quant_metadata = metadata.get("_quantization_metadata", None) - if quant_metadata is not None: - sd["_quantization_metadata"] = quant_metadata + if model_options.get("custom_operations", None) is None: + sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1088,7 +1116,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -1112,7 +1140,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -1141,7 +1169,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) @@ -1169,7 +1197,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -1212,6 +1240,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO_15: clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer + elif clip_type == CLIPType.KANDINSKY5: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer + elif clip_type == CLIPType.KANDINSKY5_IMAGE: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -1224,19 +1258,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: - if "_quantization_metadata" in c: - c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options) - for c in clip_data: - m, u = clip.load_sd(c) - if len(m) > 0: - logging.warning("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected: {}".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) return clip def load_gligen(ckpt_path): @@ -1295,6 +1320,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1303,18 +1332,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' - unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None - model_config.custom_operations = model_options.get("custom_operations", None) + if custom_operations is not None: + model_config.custom_operations = custom_operations + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1332,22 +1365,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + if te_model_options.get("custom_operations", None) is None: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] + + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options) - m, u = clip.load_sd(clip_sd, full_model=True) - if len(m) > 0: - m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) - if len(m_filter) > 0: - logging.warning("clip missing: {}".format(m)) - else: - logging.debug("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected {}:".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") @@ -1394,6 +1438,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): if len(temp_sd) > 0: sd = temp_sd + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) @@ -1424,7 +1471,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None if dtype is None: @@ -1432,12 +1479,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - if model_config.layer_quant_config is not None: + if model_config.quant_config is not None: manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) + + if custom_operations is not None: + model_config.custom_operations = custom_operations + if model_options.get("fp8_optimizations", False): model_config.optimizations["fp8"] = True @@ -1476,6 +1526,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 503a51843..962948dae 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None - quantization_metadata = model_options.get("quantization_metadata", None) + quant_config = model_options.get("quantization_metadata", None) if operations is None: - layer_quant_config = None - if quantization_metadata is not None: - layer_quant_config = json.loads(quantization_metadata).get("layers", None) - - if layer_quant_config is not None: - operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) - logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") + if quant_config is not None: + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + logging.info("Using MixedPrecisionOps for text encoder") else: - # Fallback to scaled_fp8_ops for backward compatibility - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) - else: - operations = comfy.ops.manual_cast + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers diff --git a/comfy/supported_models.py b/comfy/supported_models.py index afd97160b..ef8c75c09 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -21,6 +21,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image from . import supported_models_base @@ -540,7 +541,7 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 - memory_usage_factor = 1.2 + memory_usage_factor = 1.6 text_encoder_key_prefix = ["text_encoders."] @@ -1025,7 +1026,7 @@ class ZImage(Lumina2): "shift": 3.0, } - memory_usage_factor = 1.7 + memory_usage_factor = 2.0 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] @@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] +class Kandinsky5(supported_models_base.BASE): + unet_config = { + "image_model": "kandinsky5", + } + + sampling_settings = { + "shift": 10.0, + } + + unet_extra_config = {} + latent_format = latent_formats.HunyuanVideo + + memory_usage_factor = 1.1 #TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +class Kandinsky5Image(Kandinsky5): + unet_config = { + "image_model": "kandinsky5", + "model_dim": 2560, + "visual_embed_dim": 64, + } + + sampling_settings = { + "shift": 3.0, + } + + latent_format = latent_formats.Flux + memory_usage_factor = 1.1 #TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5Image(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index e4bd74514..0e7a829ba 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -17,6 +17,7 @@ """ import torch +import logging from . import model_base from . import utils from . import latent_formats @@ -49,8 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = None - layer_quant_config = None # Per-layer quantization configuration for mixed precision + quant_config = None # quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod @@ -118,3 +118,7 @@ class BASE: def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype + + def __getattr__(self, name): + logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name)) + return None diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index a1adb5242..448381fa9 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -7,10 +7,10 @@ from transformers import T5TokenizerFast class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 99f4812bb..21d93d757 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ @@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel): out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra -def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Flux2TEModel_(Flux2TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata if pruned: model_options = model_options.copy() diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 9dcf190a2..5daea8135 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index dbcf52784..600b34480 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e1..cd198036c 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel): else: return super().load_sd(sd) -def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(HunyuanImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 0110517bb..a9a6c525e 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast import torch import os import numbers - +import comfy.utils def llama_detect(state_dict, prefix=""): out = {} @@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype - - if "_quantization_metadata" in state_dict: - out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["llama_quantization_metadata"] = quant return out @@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = {} vocab_size = model_options.get("vocab_size", None) @@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py new file mode 100644 index 000000000..be086458c --- /dev/null +++ b/comfy/text_encoders/kandinsky5.py @@ -0,0 +1,68 @@ +from comfy import sd1_clip +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from .llama import Qwen25_7BVLI + + +class Kandinsky5Tokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + + return out + + +class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class Kandinsky5TEModel(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"]) + + return cond, l_pooled, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + self.clip_l.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + self.clip_l.reset_clip_options() + + def load_sd(self, sd): + if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + else: + return super().load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class Kandinsky5TEModel_(Kandinsky5TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Kandinsky5TEModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index fd986e2c1..7a6cfdab2 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): if model_type == "gemma2_2b": model = Gemma2_2BModel elif model_type == "gemma3_4b": @@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1a01b2dd4..50aa4121f 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 81c9bd51c..5754424d2 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -55,12 +55,9 @@ class OvisTEModel(sd1_clip.SD1ClipModel): return out, pooled, {} -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class OvisTEModel_(OvisTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 5f383de07..e5e5f18be 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index c0d32a6ef..5c14dec23 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): return out, pooled, extra -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(QwenImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ff5d412db..8b153c72b 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -6,14 +6,15 @@ import torch import os import comfy.model_management import logging +import comfy.utils class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["t5_quantization_metadata"] = quant return out @@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index d50fa4b28..164a57edd 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index bb9273b20..19adde0b7 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class ZImageTEModel_(ZImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/utils.py b/comfy/utils.py index 37485e497..9dc0d76ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,6 +29,7 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args +import json MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -802,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +ATTR_UNSET={} + def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], value) + prev = getattr(obj, attrs[-1], ATTR_UNSET) + if value is ATTR_UNSET: + delattr(obj, attrs[-1]) + else: + setattr(obj, attrs[-1], value) return prev def set_attr_param(obj, attr, value): @@ -1194,3 +1200,68 @@ def unpack_latents(combined_latent, latent_shapes): else: output_tensors = combined_latent return output_tensors + +def detect_layer_quantization(state_dict, prefix): + for k in state_dict: + if k.startswith(prefix) and k.endswith(".comfy_quant"): + logging.info("Found quantization metadata version 1") + return {"mixed_ops": True} + return None + +def convert_old_quants(state_dict, model_prefix="", metadata={}): + if metadata is None: + metadata = {} + + quant_metadata = None + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + quant_metadata = {"layers": layers} + else: + quant_metadata = json.loads(metadata["_quantization_metadata"]) + + if quant_metadata is not None: + layers = quant_metadata["layers"] + for k, v in layers.items(): + state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) + + return state_dict, metadata diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 0fa01d1e7..35e1ac853 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput -from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents -from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from ._input_impl import VideoFromFile, VideoFromComponents +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 @@ -80,7 +80,7 @@ class ComfyExtension(ABC): async def on_load(self) -> None: """ Called when an extension is loaded. - This should be used to initialize any global resources neeeded by the extension. + This should be used to initialize any global resources needed by the extension. """ @abstractmethod diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 87c81d73a..e634a0311 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -4,7 +4,7 @@ from fractions import Fraction from typing import Optional, Union, IO import io import av -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents class VideoInput(ABC): """ diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a4cd3737d..ea35c6062 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -3,14 +3,14 @@ from av.container import InputContainer from av.subtitles.stream import SubtitleStream from fractions import Fraction from typing import Optional -from comfy_api.latest._input import AudioInput, VideoInput +from .._input import AudioInput, VideoInput import av import io import json import numpy as np import math import torch -from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents def container_to_output_format(container_format: str | None) -> str | None: diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 866c3e0eb..79217c813 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from comfy_api.latest._resources import Resources, ResourcesLocal +from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker from ._util import MESH, VOXEL @@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO): '''Used by WAN Camera.''' time_dim_concat: NotRequired[torch.Tensor] '''Used by WAN Phantom Subject.''' + time_dim_replace: NotRequired[torch.Tensor] + '''Used by Kandinsky5 I2V.''' CondList = list[tuple[torch.Tensor, PooledDict]] Type = CondList @@ -1813,7 +1815,7 @@ class NodeOutput(_NodeOutputInternal): ui = data["ui"] if "expand" in data: expand = data["expand"] - return cls(args=args, ui=ui, expand=expand) + return cls(*args, ui=ui, expand=expand) def __getitem__(self, index) -> Any: return self.args[index] diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 5a75a3aae..2babe209a 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -22,7 +22,7 @@ import folder_paths # used for image preview from comfy.cli_args import args -from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput +from ._io import ComfyNode, FolderType, Image, _UIOutput class SavedResult(dict): diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index c3e3d8e3a..fd3b5a510 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from comfy_api.latest._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput class VideoCodec(str, Enum): AUTO = "auto" diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py new file mode 100644 index 000000000..77cd76f9b --- /dev/null +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -0,0 +1,144 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + size: str | None = Field(None) + seed: int | None = Field(0, ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: str | None = Field("adaptive") + seed: int | None = Field(..., ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str = Field("url") + image: list[str] | None = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) + created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) + error: TaskStatusError | None = Field(None) + content: TaskStatusResult | None = Field(None) + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + +# The time in this dictionary are given for 10 seconds duration. +VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, + "seedance-1-0-pro-fast-251015": { + "480p": 50, + "720p": 65, + "1080p": 100, + }, +} diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index a380ecc86..f8edc38c9 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel): description="A list of ordered parts that make up a single message. " "Different parts may have different IANA MIME types.", ) - role: GeminiRole = Field( - ..., - description="The identity of the entity that creates the message. " - "The following values are supported: " - "user: This indicates that the message is sent by a real person, typically a user-generated message. " - "model: This indicates that the message is generated by the model. " - "The model value is used to insert messages from model into the conversation during multi-turn conversations. " - "For non-multi-turn conversations, this field can be left blank or unset.", - ) + role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.") class GeminiFunctionDeclaration(BaseModel): diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py index 8328d1aa4..23ca725b7 100644 --- a/comfy_api_nodes/apis/veo_api.py +++ b/comfy_api_nodes/apis/veo_api.py @@ -85,7 +85,7 @@ class Response1(BaseModel): raiMediaFilteredReasons: Optional[list[str]] = Field( None, description='Reasons why media was filtered by responsible AI policies' ) - videos: Optional[list[Video]] = None + videos: Optional[list[Video]] = Field(None) class VeoGenVidPollResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index caced471e..57c0218d0 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,13 +1,27 @@ import logging import math -from enum import Enum -from typing import Literal, Optional, Union import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bytedance_api import ( + RECOMMENDED_PRESETS, + RECOMMENDED_PRESETS_SEEDREAM_4, + VIDEO_TASKS_EXECUTION_TIME, + Image2ImageTaskCreationRequest, + Image2VideoTaskCreationRequest, + ImageTaskCreationResponse, + Seedream4Options, + Seedream4TaskCreationRequest, + TaskCreationResponse, + TaskImageContent, + TaskImageContentUrl, + TaskStatusResponse, + TaskTextContent, + Text2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} -class Text2ImageModelName(str, Enum): - seedream_3 = "seedream-3-0-t2i-250415" - - -class Image2ImageModelName(str, Enum): - seededit_3 = "seededit-3-0-i2i-250628" - - -class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-t2v-250428" - - -class Image2VideoModelName(str, Enum): - """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-i2v-250428" - - -class Text2ImageTaskCreationRequest(BaseModel): - model: Text2ImageModelName = Text2ImageModelName.seedream_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - size: Optional[str] = Field(None) - seed: Optional[int] = Field(0, ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Image2ImageTaskCreationRequest(BaseModel): - model: Image2ImageModelName = Image2ImageModelName.seededit_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: Optional[str] = Field("adaptive") - seed: Optional[int] = Field(..., ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Seedream4Options(BaseModel): - max_images: int = Field(15) - - -class Seedream4TaskCreationRequest(BaseModel): - model: str = Field("seedream-4-0-250828") - prompt: str = Field(...) - response_format: str = Field("url") - image: Optional[list[str]] = Field(None, description="Image URLs") - size: str = Field(...) - seed: int = Field(..., ge=0, le=2147483647) - sequential_image_generation: str = Field("disabled") - sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) - - -class ImageTaskCreationResponse(BaseModel): - model: str = Field(...) - created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") - data: list = Field([], description="Contains information about the generated image(s).") - error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") - - -class TaskTextContent(BaseModel): - type: str = Field("text") - text: str = Field(...) - - -class TaskImageContentUrl(BaseModel): - url: str = Field(...) - - -class TaskImageContent(BaseModel): - type: str = Field("image_url") - image_url: TaskImageContentUrl = Field(...) - role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None) - - -class Text2VideoTaskCreationRequest(BaseModel): - model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro - content: list[TaskTextContent] = Field(..., min_length=1) - - -class Image2VideoTaskCreationRequest(BaseModel): - model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro - content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2) - - -class TaskCreationResponse(BaseModel): - id: str = Field(...) - - -class TaskStatusError(BaseModel): - code: str = Field(...) - message: str = Field(...) - - -class TaskStatusResult(BaseModel): - video_url: str = Field(...) - - -class TaskStatusResponse(BaseModel): - id: str = Field(...) - model: str = Field(...) - status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) - error: Optional[TaskStatusError] = Field(None) - content: Optional[TaskStatusResult] = Field(None) - - -RECOMMENDED_PRESETS = [ - ("1024x1024 (1:1)", 1024, 1024), - ("864x1152 (3:4)", 864, 1152), - ("1152x864 (4:3)", 1152, 864), - ("1280x720 (16:9)", 1280, 720), - ("720x1280 (9:16)", 720, 1280), - ("832x1248 (2:3)", 832, 1248), - ("1248x832 (3:2)", 1248, 832), - ("1512x648 (21:9)", 1512, 648), - ("2048x2048 (1:1)", 2048, 2048), - ("Custom", None, None), -] - -RECOMMENDED_PRESETS_SEEDREAM_4 = [ - ("2048x2048 (1:1)", 2048, 2048), - ("2304x1728 (4:3)", 2304, 1728), - ("1728x2304 (3:4)", 1728, 2304), - ("2560x1440 (16:9)", 2560, 1440), - ("1440x2560 (9:16)", 1440, 2560), - ("2496x1664 (3:2)", 2496, 1664), - ("1664x2496 (2:3)", 1664, 2496), - ("3024x1296 (21:9)", 3024, 1296), - ("4096x4096 (1:1)", 4096, 4096), - ("Custom", None, None), -] - -# The time in this dictionary are given for 10 seconds duration. -VIDEO_TASKS_EXECUTION_TIME = { - "seedance-1-0-lite-t2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-lite-i2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-pro-250528": { - "480p": 70, - "720p": 85, - "1080p": 115, - }, -} - - def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" @@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: return response.data[0]["url"] -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: - """Returns the video URL from the task status response if it exists.""" - if hasattr(response, "content") and response.content: - return response.content.video_url - return None - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode): category="api node/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Text2ImageModelName, - default=Text2ImageModelName.seedream_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), IO.String.Input( "prompt", multiline=True, @@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): category="api node/image/ByteDance", description="Edit images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Image2ImageModelName, - default=Image2ImageModelName.seededit_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), IO.Image.Input( "image", tooltip="The base image to edit", @@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, seed: int, guidance_scale: float, @@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedream-4-0-250828"], + options=["seedream-4-5-251128", "seedream-4-0-250828"], tooltip="Model name", ), IO.String.Input( @@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor = None, + image: Input.Image | None = None, size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], width: int = 2048, height: int = 2048, @@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode): raise ValueError( f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) + out_num_pixels = w * h + mp_provided = out_num_pixels / 1_000_000.0 + if "seedream-4-5" in model and out_num_pixels < 3686400: + raise ValueError( + f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"but {mp_provided:.2f}MP provided." + ) + if "seedream-4-0" in model and out_num_pixels < 921600: + raise ValueError( + f"Minimum image resolution that the selected model can generate is 0.92MP, " + f"but {mp_provided:.2f}MP provided." + ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") @@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Text2VideoModelName, - default=Text2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Image2VideoModelName, - default=Image2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor, + image: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[model.value for model in Image2VideoModelName], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): cls, model: str, prompt: str, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[Image2VideoModelName.seedance_1_lite.value], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): cls, model: str, prompt: str, - images: torch.Tensor, + images: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): async def process_video_task( cls: type[IO.ComfyNode], - payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - estimated_duration: Optional[int], + payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, + estimated_duration: int | None, ) -> IO.NodeOutput: initial_response = await sync_op( cls, @@ -1085,7 +934,7 @@ async def process_video_task( estimated_duration=estimated_duration, response_model=TaskStatusResponse, ) - return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) def raise_if_text_params(prompt: str, text_params: list[str]) -> None: diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 08f7b0f64..ad0f4b4d1 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -13,8 +13,7 @@ import torch from typing_extensions import override import folder_paths -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.gemini_api import ( GeminiContent, GeminiFileData, @@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import ( GeminiMimeType, GeminiPart, GeminiRole, + GeminiSystemInstructionContent, + GeminiTextPart, Modality, ) from comfy_api_nodes.util import ( @@ -43,6 +44,14 @@ from comfy_api_nodes.util import ( GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_IMAGE_SYS_PROMPT = ( + "You are an expert image-generation engine. You must ALWAYS produce an image.\n" + "Interpret all user input—regardless of " + "format, intent, or abstraction—as literal visual directives for image composition.\n" + "If a prompt is conversational or lacks specific visual details, " + "you must creatively invent a concrete visual scenario that depicts the concept.\n" + "Prioritize generating the visual representation above any text, formatting, or conversational requests." +) class GeminiModel(str, Enum): @@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum): async def create_image_parts( cls: type[IO.ComfyNode], - images: torch.Tensor, + images: Input.Image, image_limit: int = 0, ) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] @@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor: - image_tensors: list[torch.Tensor] = [] +def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: + image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: image_data = base64.b64decode(part.inlineData.data) @@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode): tooltip="Optional file(s) to use as context for the model. " "Accepts inputs from the Gemini Generate Content Input Files node.", ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.String.Output(), @@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode): def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: """Convert video input to Gemini API compatible parts.""" - base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264) + base_64_string = video_to_base64_string( + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) return [ GeminiPart( inlineData=GeminiInlineData( @@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode): prompt: str, model: str, seed: int, - images: torch.Tensor | None = None, + images: Input.Image | None = None, audio: Input.Audio | None = None, video: Input.Video | None = None, files: list[GeminiPart] | None = None, + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode): if files is not None: parts.extend(files) - # Create response + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode): role=GeminiRole.user, parts=parts, ) - ] + ], + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, @@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode): "'IMAGE+TEXT' to return both the generated image and a text response.", optional=True, ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.Image.Output(), @@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode): prompt: str, model: str, seed: int, - images: torch.Tensor | None = None, + images: Input.Image | None = None, files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", response_modalities: str = "IMAGE+TEXT", + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) parts: list[GeminiPart] = [GeminiPart(text=prompt)] @@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode): if files is not None: parts.extend(files) + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=None if aspect_ratio == "auto" else image_config, ), + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, @@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode): tooltip="Optional file(s) to use as context for the model. " "Accepts inputs from the Gemini Generate Content Input Files node.", ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), ], outputs=[ IO.Image.Output(), @@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode): aspect_ratio: str, resolution: str, response_modalities: str, - images: torch.Tensor | None = None, + images: Input.Image | None = None, files: list[GeminiPart] | None = None, + system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode): if aspect_ratio != "auto": image_config.aspectRatio = aspect_ratio + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + response = await sync_op( cls, ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), @@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=image_config, ), + systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 6c840dc47..a2cc87d84 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2056,7 +2056,7 @@ class KlingExtension(ComfyExtension): OmniProImageToVideoNode, OmniProVideoToVideoNode, OmniProEditVideoNode, - # OmniProImageNode, # need support from backend + OmniProImageNode, ] diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 0b757a62b..7e61560dc 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -1,12 +1,9 @@ from io import BytesIO -from typing import Optional -import torch from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.util import ( ApiEndpoint, get_number_of_images, @@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel): model: str = Field(...) duration: int = Field(...) resolution: str = Field(...) - fps: Optional[int] = Field(25) - generate_audio: Optional[bool] = Field(True) - image_uri: Optional[str] = Field(None) + fps: int | None = Field(25) + generate_audio: bool | None = Field(True) + image_uri: str | None = Field(None) class TextToVideoNode(IO.ComfyNode): @@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode): as_binary=True, max_retries=1, ) - return IO.NodeOutput(VideoFromFile(BytesIO(response))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) class ImageToVideoNode(IO.ComfyNode): @@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, model: str, prompt: str, duration: int, @@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode): as_binary=True, max_retries=1, ) - return IO.NodeOutput(VideoFromFile(BytesIO(response))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) class LtxvApiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7c31d95b3..2771e4790 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,11 +1,8 @@ import logging -from typing import Optional -import torch from typing_extensions import override -from comfy_api.input import VideoInput -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, @@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None: raise RuntimeError(error_msg) -def validate_video_to_video_input(video: VideoInput) -> VideoInput: +def validate_video_to_video_input(video: Input.Video) -> Input.Video: """ Validates and processes video input for Moonvalley Video-to-Video generation. @@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput: return _validate_and_trim_duration(video) -def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: +def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: """Extracts video dimensions with error handling.""" try: return video.get_dimensions() @@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None: raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") -def _validate_and_trim_duration(video: VideoInput) -> VideoInput: +def _validate_and_trim_duration(video: Input.Video) -> Input.Video: """Validates video duration and trims to 5 seconds if needed.""" duration = video.get_duration() _validate_minimum_duration(duration) @@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None: raise ValueError("Input video must be at least 5 seconds long.") -def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: +def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: """Trims video to 5 seconds if longer.""" if duration > 5: return trim_video(video, 5) @@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str, resolution: str, @@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): prompt: str, negative_prompt: str, seed: int, - video: Optional[VideoInput] = None, + video: Input.Video | None = None, control_type: str = "Motion Transfer", - motion_intensity: Optional[int] = 100, + motion_intensity: int | None = 100, steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 51148211b..acd88c391 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 2fdafbbfe..3c55039c9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,12 +11,11 @@ User Guides: """ -from typing import Union, Optional -from typing_extensions import override from enum import Enum -import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, @@ -44,8 +43,6 @@ from comfy_api_nodes.util import ( sync_op, poll_op, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum): field_1280_768 = "1280:768" -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the video URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N def extract_progress_from_task_status( response: TaskStatusResponse, -) -> Union[float, None]: +) -> float | None: if hasattr(response, "progress") and response.progress is not None: return response.progress * 100 return None -def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" return await poll_op( @@ -119,8 +116,8 @@ async def get_response( async def generate_video( cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - estimated_duration: Optional[int] = None, -) -> VideoFromFile: + estimated_duration: int | None = None, +) -> InputImpl.VideoFromFile: initial_response = await sync_op( cls, endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), @@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, - end_frame: torch.Tensor, + start_frame: Input.Image, + end_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode): cls, prompt: str, ratio: str, - reference_image: Optional[torch.Tensor] = None, + reference_image: Input.Image | None = None, ) -> IO.NodeOutput: validate_string(prompt, min_length=1) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index a54dc13ab..e165b8380 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,11 +1,9 @@ import base64 from io import BytesIO -import torch from typing_extensions import override -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, @@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Check if video is provided as base64 or URL if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: - return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) if hasattr(video, "gcsUri") and video.gcsUri: return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) @@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): aspect_ratio: str, duration: int, seed: int, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, model: str, generate_audio: bool, ): @@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): if response.videos: video = response.videos[0] if video.bytesBase64Encoded: - return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) if video.gcsUri: return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) raise Exception("Video returned but no data or URL was provided") diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 812301fb7..c7916443c 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode): step=0.01, tooltip="Duration of the empty audio clip in seconds", ), - IO.Float.Input( + IO.Int.Input( "sample_rate", default=44100, tooltip="Sample rate of the empty audio clip.", + min=1, + max=192000, ), - IO.Float.Input( + IO.Int.Input( "channels", default=2, min=1, diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 1c3d9e697..3799a9004 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode): io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode): ) @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode): context_overlap=context_overlap, context_stride=context_stride, closed_loop=closed_loop, - dim=dim) + dim=dim, + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list, + split_conds_to_windows=split_conds_to_windows + ) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model) + if freenoise: # no other use for this wrapper at this time + comfy.context_windows.create_sampler_sample_wrapper(model) return io.NodeOutput(model) class WanContextWindowsManualNode(ContextWindowsManualNode): @@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ] return schema @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index e3ac58447..3429b731e 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -2,6 +2,8 @@ import torch import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO def Fourier_filter(x, threshold, scale): # FFT @@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale): return x_filtered.to(x.dtype) -class FreeU: +class FreeU(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -59,23 +66,31 @@ class FreeU: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -class FreeU_V2: + patch = execute # TODO: remove + + +class FreeU_V2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU_V2", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -105,9 +120,19 @@ class FreeU_V2: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreeU": FreeU, - "FreeU_V2": FreeU_V2, -} + patch = execute # TODO: remove + + +class FreelunchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FreeU, + FreeU_V2, + ] + + +async def comfy_entrypoint() -> FreelunchExtension: + return FreelunchExtension() diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py new file mode 100644 index 000000000..9cb234be1 --- /dev/null +++ b/comfy_extras/nodes_kandinsky5.py @@ -0,0 +1,136 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class Kandinsky5ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Kandinsky5ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond_latent_out = {} + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent_out["samples"] = encoded + + mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, cond_latent_out) + + +def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): + source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W + source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W + + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) + + # normalization + normalized = (source - source_mean) / (source_std + 1e-8) + normalized = normalized * reference_std + reference_mean + + return normalized + + +class NormalizeVideoLatentStart(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="NormalizeVideoLatentStart", + category="conditioning/video_models", + description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.", + inputs=[ + io.Latent.Input("latent"), + io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"), + io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput: + if latent["samples"].shape[2] <= 1: + return io.NodeOutput(latent) + s = latent.copy() + samples = latent["samples"].clone() + + first_frames = samples[:, :, :start_frame_count] + reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)] + normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data) + + samples[:, :, :start_frame_count] = normalized_first_frames + s["samples"] = samples + return io.NodeOutput(s) + + +class CLIPTextEncodeKandinsky5(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeKandinsky5", + category="advanced/conditioning/kandinsky5", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput: + tokens = clip.tokenize(clip_l) + tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +class Kandinsky5Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Kandinsky5ImageToVideo, + NormalizeVideoLatentStart, + CLIPTextEncodeKandinsky5, + ] + +async def comfy_entrypoint() -> Kandinsky5Extension: + return Kandinsky5Extension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index d2df07ff9..e439b18ef 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -4,7 +4,7 @@ import torch import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import logging def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode): return luminance * sharpened return io.NodeOutput(sharpen) +class ReplaceVideoLatentFrames(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReplaceVideoLatentFrames", + category="latent/batch", + inputs=[ + io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), + io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, destination, index, source=None) -> io.NodeOutput: + if source is None: + return io.NodeOutput(destination) + dest_frames = destination["samples"].shape[2] + source_frames = source["samples"].shape[2] + if index < 0: + index = dest_frames + index + if index > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.") + return io.NodeOutput(destination) + if index + source_frames > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.") + return io.NodeOutput(destination) + s = source.copy() + s_source = source["samples"] + s_destination = destination["samples"].clone() + s_destination[:, :, index:index + s_source.shape[2]] = s_source + s["samples"] = s_destination + return io.NodeOutput(s) class LatentExtension(ComfyExtension): @override @@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension): LatentApplyOperationCFG, LatentOperationTonemapReinhard, LatentOperationSharpen, + ReplaceVideoLatentFrames ] diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index a5e405008..290e6f55e 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -3,11 +3,10 @@ import scipy.ndimage import torch import comfy.utils import node_helpers -import folder_paths -import random +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI import nodes -from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) @@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou destination[..., top:bottom, left:right] = source_portion + destination_portion return destination -class LatentCompositeMasked: +class LatentCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("LATENT",), - "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("LATENT",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="LatentCompositeMasked", + category="latent", + inputs=[ + IO.Latent.Input("destination"), + IO.Latent.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: output = destination.copy() destination = destination["samples"].clone() source = source["samples"] output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) - return (output,) + return IO.NodeOutput(output) -class ImageCompositeMasked: + composite = execute # TODO: remove + + +class ImageCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("IMAGE",), - "source": ("IMAGE",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("IMAGE",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="ImageCompositeMasked", + category="image", + inputs=[ + IO.Image.Input("destination"), + IO.Image.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) - return (output,) + return IO.NodeOutput(output) -class MaskToImage: + composite = execute # TODO: remove + + +class MaskToImage(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskToImage", + display_name="Convert Mask to Image", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return (result,) + return IO.NodeOutput(result) -class ImageToMask: + mask_to_image = execute # TODO: remove + + +class ImageToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue", "alpha"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageToMask", + display_name="Convert Image to Mask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): + @classmethod + def execute(cls, image, channel) -> IO.NodeOutput: channels = ["red", "green", "blue", "alpha"] mask = image[:, :, :, channels.index(channel)] - return (mask,) + return IO.NodeOutput(mask) -class ImageColorToMask: + image_to_mask = execute # TODO: remove + + +class ImageColorToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageColorToMask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, color): + @classmethod + def execute(cls, image, color) -> IO.NodeOutput: temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] mask = torch.where(temp == color, 1.0, 0).float() - return (mask,) + return IO.NodeOutput(mask) -class SolidMask: + image_to_mask = execute # TODO: remove + + +class SolidMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="SolidMask", + category="mask", + inputs=[ + IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "solid" - - def solid(self, value, width, height): + @classmethod + def execute(cls, value, width, height) -> IO.NodeOutput: out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") - return (out,) + return IO.NodeOutput(out) -class InvertMask: + solid = execute # TODO: remove + + +class InvertMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="InvertMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "invert" - - def invert(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: out = 1.0 - mask - return (out,) + return IO.NodeOutput(out) -class CropMask: + invert = execute # TODO: remove + + +class CropMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="CropMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "crop" - - def crop(self, mask, x, y, width, height): + @classmethod + def execute(cls, mask, x, y, width, height) -> IO.NodeOutput: mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) out = mask[:, y:y + height, x:x + width] - return (out,) + return IO.NodeOutput(out) -class MaskComposite: + crop = execute # TODO: remove + + +class MaskComposite(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "destination": ("MASK",), - "source": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskComposite", + category="mask", + inputs=[ + IO.Mask.Input("destination"), + IO.Mask.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "combine" - - def combine(self, destination, source, x, y, operation): + @classmethod + def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput: output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() source = source.reshape((-1, source.shape[-2], source.shape[-1])) @@ -267,28 +277,29 @@ class MaskComposite: output = torch.clamp(output, 0.0, 1.0) - return (output,) + return IO.NodeOutput(output) -class FeatherMask: + combine = execute # TODO: remove + + +class FeatherMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="FeatherMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "feather" - - def feather(self, mask, left, top, right, bottom): + @classmethod + def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput: output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() left = min(left, output.shape[-1]) @@ -312,26 +323,28 @@ class FeatherMask: feather_rate = (y + 1) / bottom output[:, -y, :] *= feather_rate - return (output,) + return IO.NodeOutput(output) -class GrowMask: + feather = execute # TODO: remove + + +class GrowMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), - "tapered_corners": ("BOOLEAN", {"default": True}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="GrowMask", + display_name="Grow Mask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("tapered_corners", default=True), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "expand_mask" - - def expand_mask(self, mask, expand, tapered_corners): + @classmethod + def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput: c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], [1, 1, 1], @@ -347,69 +360,74 @@ class GrowMask: output = scipy.ndimage.grey_dilation(output, footprint=kernel) output = torch.from_numpy(output) out.append(output) - return (torch.stack(out, dim=0),) + return IO.NodeOutput(torch.stack(out, dim=0)) -class ThresholdMask: + expand_mask = execute # TODO: remove + + +class ThresholdMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ThresholdMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, mask, value): + @classmethod + def execute(cls, mask, value) -> IO.NodeOutput: mask = (mask > value).float() - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove + # Mask Preview - original implement from # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 # upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes -class MaskPreview(nodes.SaveImage): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) - self.compress_level = 4 +class MaskPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MaskPreview", + display_name="Preview Mask", + category="mask", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + IO.Mask.Input("mask"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": {"mask": ("MASK",), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - FUNCTION = "execute" - CATEGORY = "mask" - - def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewMask(mask)) -NODE_CLASS_MAPPINGS = { - "LatentCompositeMasked": LatentCompositeMasked, - "ImageCompositeMasked": ImageCompositeMasked, - "MaskToImage": MaskToImage, - "ImageToMask": ImageToMask, - "ImageColorToMask": ImageColorToMask, - "SolidMask": SolidMask, - "InvertMask": InvertMask, - "CropMask": CropMask, - "MaskComposite": MaskComposite, - "FeatherMask": FeatherMask, - "GrowMask": GrowMask, - "ThresholdMask": ThresholdMask, - "MaskPreview": MaskPreview -} +class MaskExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LatentCompositeMasked, + ImageCompositeMasked, + MaskToImage, + ImageToMask, + ImageColorToMask, + SolidMask, + InvertMask, + CropMask, + MaskComposite, + FeatherMask, + GrowMask, + ThresholdMask, + MaskPreview, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} + +async def comfy_entrypoint() -> MaskExtension: + return MaskExtension() diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f7ca9699d..dec2ae841 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode): return io.NodeOutput(m) -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "PatchModelAddDownscale": "", -} - class ModelDownscaleExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 6cf6e39bf..c609e03da 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -8,10 +8,7 @@ import json from typing import Optional from typing_extensions import override from fractions import Fraction -from comfy_api.input import AudioInput, ImageInput, VideoInput -from comfy_api.input_impl import VideoFromComponents, VideoFromFile -from comfy_api.util import VideoCodec, VideoComponents, VideoContainer -from comfy_api.latest import ComfyExtension, io, ui +from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types from comfy.cli_args import args class SaveWEBM(io.ComfyNode): @@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode): io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode): inputs=[ io.Video.Input("video", tooltip="The video to save."), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), - io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), - io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput: + def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode): metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=VideoContainer(format), + format=Types.VideoContainer(format), codec=codec, metadata=saved_metadata ) @@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode): ) @classmethod - def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: + def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: return io.NodeOutput( - VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) ) class GetVideoComponents(io.ComfyNode): @@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode): ) @classmethod - def execute(cls, video: VideoInput) -> io.NodeOutput: + def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() - return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) + class LoadVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode): @classmethod def execute(cls, file) -> io.NodeOutput: video_path = folder_paths.get_annotated_filepath(file) - return io.NodeOutput(VideoFromFile(video_path)) + return io.NodeOutput(InputImpl.VideoFromFile(video_path)) @classmethod def fingerprint_inputs(s, file): diff --git a/comfyui_version.py b/comfyui_version.py index 4b039356e..2f083edaf 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.76" +__version__ = "0.4.0" diff --git a/cuda_malloc.py b/cuda_malloc.py index 6520d5123..ee2bc4b69 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -63,18 +63,22 @@ def cuda_malloc_supported(): return True +version = "" + +try: + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ +except: + pass + if not args.cuda_malloc: try: - version = "" - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - ver_file = os.path.join(folder, "version.py") - if os.path.isfile(ver_file): - spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - version = module.__version__ - if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc args.cuda_malloc = cuda_malloc_supported() @@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + +def get_torch_version_noimport(): + return str(version) diff --git a/main.py b/main.py index 0cd815d9e..0d02a087b 100644 --- a/main.py +++ b/main.py @@ -167,6 +167,9 @@ if __name__ == "__main__": os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") diff --git a/manager_requirements.txt b/manager_requirements.txt index 52cc5389c..b95cefb74 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b3 +comfyui_manager==4.0.3b4 diff --git a/nodes.py b/nodes.py index 356aa63df..8d28a725d 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes(): "nodes_rope.py", "nodes_logic.py", "nodes_nop.py", + "nodes_kandinsky5.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 02b94a0ce..e4d3d616a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.76" +version = "0.4.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index f98848e20..9e9b25328 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.33.10 -comfyui-workflow-templates==0.7.25 +comfyui-frontend-package==1.34.8 +comfyui-workflow-templates==0.7.54 comfyui-embedded-docs==0.3.1 torch torchsde diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 63361309f..3a54941e6 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -2,6 +2,7 @@ import unittest import torch import sys import os +import json # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -15,6 +16,7 @@ if not has_gpu(): from comfy import ops from comfy.quant_ops import QuantizedTensor +import comfy.utils class SimpleModel(torch.nn.Module): @@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) + with torch.inference_mode(): + output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) @@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False)