diff --git a/comfy/client/client_types.py b/comfy/client/client_types.py
index f7336d7fd..22bd35ad4 100644
--- a/comfy/client/client_types.py
+++ b/comfy/client/client_types.py
@@ -1,7 +1,7 @@
 import dataclasses
 from typing import List
 
-from typing_extensions import TypedDict, Literal, NotRequired, Dict
+from typing_extensions import TypedDict, Literal, NotRequired
 
 
 class FileOutput(TypedDict, total=False):
@@ -21,4 +21,4 @@ class Output(TypedDict, total=False):
 @dataclasses.dataclass
 class V1QueuePromptResponse:
     urls: List[str]
-    outputs: Dict[str, Output]
+    outputs: dict[str, Output]
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 3821bdb60..b80705c05 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -500,19 +500,21 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
         logging.error("An error occurred while executing a workflow", exc_info=ex)
         logging.error(traceback.format_exc())
 
+        tips = ""
+
+        if isinstance(ex, model_management.OOM_EXCEPTION):
+            tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
+            logging.error("Got an OOM, unloading all loaded models.")
+            model_management.unload_all_models()
+
         error_details: RecursiveExecutionErrorDetails = {
             "node_id": real_node_id,
-            "exception_message": str(ex),
+            "exception_message": "{}\n{}".format(ex, tips),
             "exception_type": exception_type,
             "traceback": traceback.format_tb(tb),
             "current_inputs": input_data_formatted
         }
 
-        if isinstance(ex, model_management.OOM_EXCEPTION):
-            logging.error("Got an OOM, unloading all loaded models.")
-            model_management.unload_all_models()
-
         if should_panic_on_exception(ex, args.panic_when):
             logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index 6ef0b1f8b..a0786e93f 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -82,7 +82,12 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS
             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
-            logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))
+            # Log Time in a more readable way after 10 minutes
+            if execution_time > 600:
+                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
+                logger.info(f"Prompt executed in {execution_time}")
+            else:
+                logger.info("Prompt executed in {:.2f} seconds".format(execution_time))
 
         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)
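A note on the new logging branch in prompt_worker: past the 600-second threshold the raw float is replaced by an HH:MM:SS string via time.gmtime, which folds the duration into a single day (a run longer than 24 hours would wrap around). A minimal standalone sketch of the same logic, with an illustrative helper name:

    import time

    def format_execution_time(seconds: float) -> str:
        # Over 10 minutes, HH:MM:SS reads better than a large float of seconds;
        # note time.gmtime wraps at 24 hours.
        if seconds > 600:
            return time.strftime("%H:%M:%S", time.gmtime(seconds))
        return "{:.2f} seconds".format(seconds)

    print(format_execution_time(75.3))    # 75.30 seconds
    print(format_execution_time(3725.0))  # 01:02:05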
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 11207354f..31e8b9a0a 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -726,38 +726,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
         seed = extra_args.get("seed", None)
         noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
 
     for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
-            # Euler method
-            d = to_d(x, sigmas[i], denoised)
-            dt = sigmas[i + 1] - sigmas[i]
-            x = x + d * dt
+            # Denoising step
+            x = denoised
        else:
            # DPM-Solver++
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
-            h = t_next - t
-            s = t + h * r
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
+            lambda_s_1 = lambda_s + r * h
            fac = 1 / (2 * r)
+            sigma_s_1 = sigma_fn(lambda_s_1)
+
+            alpha_s = sigmas[i] * lambda_s.exp()
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
 
            # Step 1
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
-            s_ = t_fn(sd)
-            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
-            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
-            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
+            lambda_s_1_ = sd.log().neg()
+            h_ = lambda_s_1_ - lambda_s
+            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
+            if eta > 0 and s_noise > 0:
+                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
+            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
            # Step 2
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
-            t_next_ = t_fn(sd)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
+            lambda_t_ = sd.log().neg()
+            h_ = lambda_t_ - lambda_s
            denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
-            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
     return x
diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py
index 2395398e5..13ebe3a86 100644
--- a/comfy/ldm/flux/controlnet.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -126,6 +126,8 @@ class ControlNetFlux(Flux):
 
         if y is None:
             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+        else:
+            y = y[:, :self.params.vec_in_dim]
 
         # running on sequences img
         img = self.img_in(img)
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index db6d3fb88..50a5933cc 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -117,7 +117,7 @@ class Modulation(nn.Module):
 def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
     if modulation_dims is None:
         if m_add is not None:
-            return tensor * m_mult + m_add
+            return torch.addcmul(m_add, tensor, m_mult)
         else:
             return tensor * m_mult
     else:
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index c26c21693..986260494 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -31,7 +31,7 @@ def dynamic_slice(
     starts: List[int],
     sizes: List[int],
 ) -> Tensor:
-    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
     return x[slicing]
 
 class AttnChunk(NamedTuple):
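The dynamic_slice change above swaps a list of slice objects for a tuple. A tuple of slices performs basic indexing (returning a view), while a list is interpreted as advanced indexing; newer PyTorch versions warn on the list form and eventually reject it, which is what this one-line fix addresses. A small demonstration:

    import torch

    x = torch.arange(16).reshape(4, 4)

    # Tuple of slices: basic indexing, identical to x[0:2, 1:3].
    print(x[(slice(0, 2), slice(1, 3))])

    # A list of slices historically meant the same thing, but it signals
    # advanced indexing; recent PyTorch builds warn and then raise on it:
    # x[[slice(0, 2), slice(1, 3)]]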
diff --git a/comfy/model_base.py b/comfy/model_base.py
index e012d0dfc..e5ab48e03 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1067,6 +1067,8 @@ class CosmosPredict2(BaseModel):
     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
         if denoise_mask is None:
             return timestep
+        if denoise_mask.ndim <= 4:
+            return timestep
         condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
         c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
         out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
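The new ndim guard in CosmosPredict2.process_timestep exists because the mean below it reduces over dims [1, 3, 4], which only exist on a 5-D video-latent mask (B, C, T, H, W); a 4-D image-style mask would raise a dimension-out-of-range error. A quick illustration:

    import torch

    mask_5d = torch.ones(1, 4, 8, 64, 64)   # video latent mask (B, C, T, H, W)
    print(mask_5d.mean(dim=[1, 3, 4], keepdim=True).shape)  # torch.Size([1, 1, 8, 1, 1])

    mask_4d = torch.ones(1, 4, 64, 64)      # image-style mask: dim 4 does not exist
    # mask_4d.mean(dim=[1, 3, 4], keepdim=True)  # IndexError: dimension out of range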
diff --git a/comfy_config/config_parser.py b/comfy_config/config_parser.py
index a9cbd94dd..8da7bd901 100644
--- a/comfy_config/config_parser.py
+++ b/comfy_config/config_parser.py
@@ -11,6 +11,43 @@ from comfy_config.types import (
     PyProjectSettings
 )
 
+def validate_and_extract_os_classifiers(classifiers: list) -> list:
+    os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
+    if not os_classifiers:
+        return []
+
+    os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
+    valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
+
+    for os_value in os_values:
+        if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
+            return []
+
+    return os_values
+
+
+def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
+    accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
+    if not accelerator_classifiers:
+        return []
+
+    accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
+
+    valid_accelerators = {
+        "GPU :: NVIDIA CUDA",
+        "GPU :: AMD ROCm",
+        "GPU :: Intel Arc",
+        "NPU :: Huawei Ascend",
+        "GPU :: Apple Metal",
+    }
+
+    for accelerator_value in accelerator_values:
+        if accelerator_value not in valid_accelerators:
+            return []
+
+    return accelerator_values
+
+
 """
 Extract configuration from a custom node directory's pyproject.toml file or a Python file.
@@ -78,6 +115,24 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
     tool_data = raw_settings.tool
     comfy_data = tool_data.get("comfy", {}) if tool_data else {}
 
+    dependencies = project_data.get("dependencies", [])
+    supported_comfyui_frontend_version = ""
+    for dep in dependencies:
+        if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
+            supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
+            break
+
+    supported_comfyui_version = comfy_data.get("requires-comfyui", "")
+
+    classifiers = project_data.get('classifiers', [])
+    supported_os = validate_and_extract_os_classifiers(classifiers)
+    supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
+
+    project_data['supported_os'] = supported_os
+    project_data['supported_accelerators'] = supported_accelerators
+    project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
+    project_data['supported_comfyui_version'] = supported_comfyui_version
+
     return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
diff --git a/comfy_config/types.py b/comfy_config/types.py
index 5222cc59b..59448466b 100644
--- a/comfy_config/types.py
+++ b/comfy_config/types.py
@@ -51,7 +51,7 @@ class ComfyConfig(BaseModel):
     models: List[Model] = Field(default_factory=list, alias="Models")
     includes: List[str] = Field(default_factory=list)
     web: Optional[str] = None
-
+    banner_url: str = ""
 
 class License(BaseModel):
     file: str = ""
@@ -66,6 +66,10 @@ class ProjectConfig(BaseModel):
     dependencies: List[str] = Field(default_factory=list)
     license: License = Field(default_factory=License)
     urls: URLs = Field(default_factory=URLs)
+    supported_os: List[str] = Field(default_factory=list)
+    supported_accelerators: List[str] = Field(default_factory=list)
+    supported_comfyui_version: str = ""
+    supported_comfyui_frontend_version: str = ""
 
     @field_validator('license', mode='before')
     @classmethod
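For reference, the two classifier validators are all-or-nothing: a single unrecognized value invalidates the whole list rather than being silently dropped. A sketch of their behavior on a hypothetical classifiers list, assuming the module is importable under the path shown in the diff:

    from comfy_config.config_parser import (
        validate_and_extract_os_classifiers,
        validate_and_extract_accelerator_classifiers,
    )

    classifiers = [
        "Operating System :: OS Independent",
        "Environment :: GPU :: NVIDIA CUDA",
    ]

    print(validate_and_extract_os_classifiers(classifiers))           # ['OS Independent']
    print(validate_and_extract_accelerator_classifiers(classifiers))  # ['GPU :: NVIDIA CUDA']

    # One invalid value rejects the entire list:
    print(validate_and_extract_os_classifiers(["Operating System :: Amiga"]))  # []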
+""" - def stitch(self, image1, direction, match_image_size, spacing_width, spacing_color, image2=None): + def stitch( + self, + image1, + direction, + match_image_size, + spacing_width, + spacing_color, + image2=None, + ): if image2 is None: return (image1,) + # Handle batch size differences if image1.shape[0] != image2.shape[0]: max_batch = max(image1.shape[0], image2.shape[0]) if image1.shape[0] < max_batch: - image1 = torch.cat([image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]) + image1 = torch.cat( + [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] + ) if image2.shape[0] < max_batch: - image2 = torch.cat([image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]) + image2 = torch.cat( + [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] + ) + # Match image sizes if requested if match_image_size: h1, w1 = image1.shape[1:3] h2, w2 = image2.shape[1:3] aspect_ratio = w2 / h2 + if direction in ["left", "right"]: target_h, target_w = h1, int(h1 * aspect_ratio) - else: + else: # up, down target_w, target_h = w1, int(w1 / aspect_ratio) - image2 = utils.common_upscale(image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled").movedim(1, -1) - else: + + image2 = common_upscale( + image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" + ).movedim(1, -1) + + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + + color_val = color_map[spacing_color] + + # When not matching sizes, pad to align non-concat dimensions + if not match_image_size: h1, w1 = image1.shape[1:3] h2, w2 = image2.shape[1:3] + pad_value = 0.0 + if not isinstance(color_val, tuple): + pad_value = color_val + if direction in ["left", "right"]: + # For horizontal concat, pad heights to match if h1 != h2: target_h = max(h1, h2) if h1 < target_h: pad_h = target_h - h1 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) if h2 < target_h: pad_h = target_h - h2 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) - else: + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) + else: # up, down + # For vertical concat, pad widths to match if w1 != w2: target_w = max(w1, w2) if w1 < target_w: pad_w = target_w - w1 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) if w2 < target_w: pad_w = target_w - w2 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) - images_to_stitch = [image2, image1] if direction in ["left", "up"] else [image1, image2] + # Ensure same number of channels + if image1.shape[-1] != image2.shape[-1]: + max_channels = max(image1.shape[-1], image2.shape[-1]) + if image1.shape[-1] < max_channels: + image1 = torch.cat( + [ + image1, + torch.ones( + *image1.shape[:-1], + max_channels - 
diff --git a/comfy_extras/nodes/nodes_model_merging_model_specific.py b/comfy_extras/nodes/nodes_model_merging_model_specific.py
index 6ba68a34c..0c81367fa 100644
--- a/comfy_extras/nodes/nodes_model_merging_model_specific.py
+++ b/comfy_extras/nodes/nodes_model_merging_model_specific.py
@@ -276,6 +276,52 @@ class ModelMergeWAN2_1(nodes_model_merging.ModelMergeBlocks):
         return {"required": arg_dict}
 
 
+class ModelMergeCosmosPredict2_2B(nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["t_embedding_norm."] = argument
+
+
+        for i in range(28):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
+
+class ModelMergeCosmosPredict2_14B(nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["t_embedding_norm."] = argument
+
+
+        for i in range(36):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1,  # SD1 and SD2 have the same blocks
@@ -289,4 +335,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelMergeCosmos7B": ModelMergeCosmos7B,
     "ModelMergeCosmos14B": ModelMergeCosmos14B,
     "ModelMergeWAN2_1": ModelMergeWAN2_1,
+    "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
+    "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
 }
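Like the other model-specific merge nodes, the two CosmosPredict2 classes only declare one FLOAT ratio per key prefix (28 blocks for the 2B model, 36 for the 14B); the blending itself happens in the ModelMergeBlocks base class, which picks the most specific matching prefix for each parameter name. A sketch of that selection rule, with an illustrative helper name, mirroring the upstream longest-prefix behavior:

    def ratio_for_key(key: str, ratios: dict, default: float) -> float:
        # The longest matching prefix wins, so "blocks.0." beats a shorter match.
        best_len, best = 0, default
        for prefix, ratio in ratios.items():
            if key.startswith(prefix) and len(prefix) > best_len:
                best_len, best = len(prefix), ratio
        return best

    ratios = {"blocks.0.": 0.25, "final_layer.": 1.0}
    print(ratio_for_key("blocks.0.attn.qkv.weight", ratios, 1.0))  # 0.25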