diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 1660ec8e3..adfc5dd32 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -5,6 +5,7 @@ on: push: branches: - master + - release/** paths-ignore: - 'app/**' - 'input/**' diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml index 00ef07ebf..9012633d8 100644 --- a/.github/workflows/test-execution.yml +++ b/.github/workflows/test-execution.yml @@ -2,9 +2,9 @@ name: Execution Tests on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index 1735fd83b..fd70aff23 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -2,9 +2,9 @@ name: Test server launches without errors on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index 00caf5b8a..d05179cd3 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -2,9 +2,9 @@ name: Unit Tests on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml index d9d488974..c2343cc39 100644 --- a/.github/workflows/update-version.yml +++ b/.github/workflows/update-version.yml @@ -6,6 +6,7 @@ on: - "pyproject.toml" branches: - master + - release/** jobs: update-version: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 753c66afa..1ba9edad7 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1618,6 +1618,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non x = x + sde_noise * sigmas[i + 1] * s_noise return x +@torch.no_grad() +def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"): + """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time.""" + return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type) + + +@torch.no_grad() +def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"): + """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time.""" + return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type) + @torch.no_grad() def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): @@ -1765,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Predictor if sigmas[i + 1] == 0: # Denoising step - x = denoised + x_pred = denoised else: tau_t = tau_func(sigmas[i + 1]) curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] @@ -1786,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if tau_t > 0 and s_noise > 0: noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise x_pred = x_pred + noise - return x + return x_pred @torch.no_grad() diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 96cb37fa6..5628e2ba3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -634,8 +634,11 @@ class NextDiT(nn.Module): img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" img_input = img for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 8c75670cd..902af30ed 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -218,9 +218,24 @@ class QwenImageTransformerBlock(nn.Module): operations=operations, ) - def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _apply_gate(self, x, y, gate, timestep_zero_index=None): + if timestep_zero_index is not None: + return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1) + else: + return torch.addcmul(y, gate, x) + + def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) - return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) + if timestep_zero_index is not None: + actual_batch = shift.size(0) // 2 + shift, shift_0 = shift[:actual_batch], shift[actual_batch:] + scale, scale_0 = scale[:actual_batch], scale[actual_batch:] + gate, gate_0 = gate[:actual_batch], gate[actual_batch:] + reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1)) + zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1)) + return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1)) + else: + return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) def forward( self, @@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + timestep_zero_index=None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) + + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + txt_mod_params = self.txt_mod(temb) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) - img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index) del img_mod1 txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) del txt_mod1 @@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module): del img_modulated del txt_modulated - hidden_states = hidden_states + img_gate1 * img_attn_output + hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index) encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output del img_attn_output del txt_attn_output del img_gate1 del txt_gate1 - img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) - hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index) + hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index) txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) @@ -302,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module): pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + default_ref_method="index", image_model=None, final_layer=True, dtype=None, @@ -314,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module): self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim + self.default_ref_method = default_ref_method self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) @@ -341,6 +363,9 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) + if self.default_ref_method == "index_timestep_zero": + self.register_buffer("__index_timestep_zero__", torch.tensor([])) + if final_layer: self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) @@ -391,11 +416,14 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] + timestep_zero_index = None if ref_latents is not None: h = 0 w = 0 index = 0 - index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + ref_method = kwargs.get("ref_latents_method", self.default_ref_method) + index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 @@ -415,6 +443,10 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = num_embeds txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) @@ -446,7 +478,7 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] @@ -458,6 +490,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, ) @@ -474,6 +507,9 @@ class QwenImageTransformer2DModel(nn.Module): if add is not None: hidden_states[:, :add.shape[1]] += add + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index a9d5e10d9..4216ce831 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -568,7 +568,10 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -763,7 +766,10 @@ class VaceWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -862,7 +868,10 @@ class CameraWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options) if audio_emb is not None: x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) # head @@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 7c87835d4..84d7adec4 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -523,7 +523,10 @@ class AnimateWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dd6a703f6..7148c77fd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -259,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 - if "__x0__" in state_dict_keys: # x0 pred + if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred dit_config["use_x0"] = True else: dit_config["use_x0"] = False @@ -618,6 +618,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "qwen_image" dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') + if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 + dit_config["default_ref_method"] = "index_timestep_zero" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index e46971afb..9134e6d71 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options) + return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) real_model = model.model return real_model, conds, models diff --git a/comfy/samplers.py b/comfy/samplers.py index fa4640842..8340d376c 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -720,7 +720,7 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 834dfcffc..1888f35ba 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -28,6 +28,7 @@ from . import supported_models_base from . import latent_formats from . import diffusers_convert +import comfy.model_management class SD15(supported_models_base.BASE): unet_config = { @@ -1028,7 +1029,13 @@ class ZImage(Lumina2): memory_usage_factor = 2.0 - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + if comfy.model_management.extended_fp16_support(): + self.supported_inference_dtypes = self.supported_inference_dtypes.copy() + self.supported_inference_dtypes.insert(1, torch.float16) def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 2b634d172..4b14e5ded 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1556,12 +1556,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) return type_clone @final diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py new file mode 100644 index 000000000..ae5bb2673 --- /dev/null +++ b/comfy_api_nodes/apis/openai_api.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = None + text_tokens: int | None = None + + +class Usage(BaseModel): + input_tokens: int | None = None + input_tokens_details: InputTokensDetails | None = None + output_tokens: int | None = None + total_tokens: int | None = None + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = None + usage: Usage | None = None + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) + moderation: str | None = Field(None) + n: int | None = Field(None, description="The number of images to generate") + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + size: str | None = Field(None, description="Size of the output image") + + +class OpenAIImageGenerationRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str | None = Field(None) + moderation: str | None = Field(None) + n: int | None = Field( + None, + description="The number of images to generate.", + ) + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="The quality of the generated image") + size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 713260e2a..ffaaa7dc1 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, RootModel class TripoModelVersion(str, Enum): + v3_0_20250812 = 'v3.0-20250812' v2_5_20250123 = 'v2.5-20250123' v2_0_20240919 = 'v2.0-20240919' v1_4_20240625 = 'v1.4-20240625' +class TripoGeometryQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + class TripoTextureQuality(str, Enum): standard = 'standard' detailed = 'detailed' @@ -61,14 +67,20 @@ class TripoSpec(str, Enum): class TripoAnimation(str, Enum): IDLE = "preset:idle" WALK = "preset:walk" + RUN = "preset:run" + DIVE = "preset:dive" CLIMB = "preset:climb" JUMP = "preset:jump" - RUN = "preset:run" SLASH = "preset:slash" SHOOT = "preset:shoot" HURT = "preset:hurt" FALL = "preset:fall" TURN = "preset:turn" + QUADRUPED_WALK = "preset:quadruped:walk" + HEXAPOD_WALK = "preset:hexapod:walk" + OCTOPOD_WALK = "preset:octopod:walk" + SERPENTINE_MARCH = "preset:serpentine:march" + AQUATIC_MARCH = "preset:aquatic:march" class TripoStylizeStyle(str, Enum): LEGO = "lego" @@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum): BANNED = "banned" EXPIRED = "expired" +class TripoFbxPreset(str, Enum): + BLENDER = "blender" + MIXAMO = "mixamo" + _3DSMAX = "3dsmax" + class TripoFileTokenReference(BaseModel): type: Optional[str] = Field(None, description='The type of the reference') file_token: str @@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard style: Optional[TripoStyle] = None auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') @@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') @@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') @@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel): type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') format: TripoConvertFormat = Field(..., description='The format to convert to') original_model_task_id: str = Field(..., description='The task ID of the original model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the model') - force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry') - face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to') - flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model') - flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom') - texture_size: Optional[int] = Field(4096, description='The size of the texture') + quad: Optional[bool] = Field(None, description='Whether to apply quad to the model') + force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry') + face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to') + flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model') + flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom') + texture_size: Optional[int] = Field(None, description='The size of the texture') texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') - pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom') + pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom') + scale_factor: Optional[float] = Field(None, description='The scale factor for the model') + with_animation: Optional[bool] = Field(None, description='Whether to include animations') + pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') + bake: Optional[bool] = Field(None, description='Whether to bake the model') + part_names: Optional[List[str]] = Field(None, description='The names of the parts to include') + fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') + export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') + export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export') + animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') + class TripoTaskRequest(RootModel): root: Union[ diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c8da5464b..a6205a34f 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,46 +1,45 @@ -from io import BytesIO +import base64 import os from enum import Enum -from inspect import cleandoc +from io import BytesIO + import numpy as np import torch from PIL import Image -import folder_paths -import base64 -from comfy_api.latest import IO, ComfyExtension from typing_extensions import override - +import folder_paths +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( - OpenAIImageGenerationRequest, - OpenAIImageEditRequest, - OpenAIImageGenerationResponse, - OpenAICreateResponse, - OpenAIResponse, CreateModelResponseProperties, - Item, - OutputContent, - InputImageContent, Detail, - InputTextContent, - InputMessage, - InputMessageContentList, InputContent, InputFileContent, + InputImageContent, + InputMessage, + InputMessageContentList, + InputTextContent, + Item, + OpenAICreateResponse, + OpenAIResponse, + OutputContent, +) +from comfy_api_nodes.apis.openai_api import ( + OpenAIImageEditRequest, + OpenAIImageGenerationRequest, + OpenAIImageGenerationResponse, ) - from comfy_api_nodes.util import ( - downscale_image_tensor, - download_url_to_bytesio, - validate_string, - tensor_to_base64_string, ApiEndpoint, - sync_op, + download_url_to_bytesio, + downscale_image_tensor, poll_op, + sync_op, + tensor_to_base64_string, text_filepath_to_data_uri, + validate_string, ) - RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" @@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten class OpenAIDalle2(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's DALL·E 2 endpoint. - """ @classmethod def define_schema(cls): @@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode): node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( "prompt", @@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode): class OpenAIDalle3(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's DALL·E 3 endpoint. - """ @classmethod def define_schema(cls): @@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode): node_id="OpenAIDalle3", display_name="OpenAI DALL·E 3", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( "prompt", @@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode): return IO.NodeOutput(await validate_and_cast_response(response)) +def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None: + # https://platform.openai.com/docs/pricing + return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0 + + +def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None: + return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 + + class OpenAIGPTImage1(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's GPT Image 1 endpoint. - """ @classmethod def define_schema(cls): @@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode): node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 1", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", inputs=[ IO.String.Input( "prompt", default="", multiline=True, - tooltip="Text prompt for GPT Image 1", + tooltip="Text prompt for GPT Image", ), IO.Int.Input( "seed", @@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode): ), IO.Combo.Input( "background", - default="opaque", - options=["opaque", "transparent"], + default="auto", + options=["auto", "opaque", "transparent"], tooltip="Return image with or without background", optional=True, ), @@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode): tooltip="Optional mask for inpainting (white areas will be replaced)", optional=True, ), + IO.Combo.Input( + "model", + options=["gpt-image-1", "gpt-image-1.5"], + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode): @classmethod async def execute( cls, - prompt, - seed=0, - quality="low", - background="opaque", - image=None, - mask=None, - n=1, - size="1024x1024", + prompt: str, + seed: int = 0, + quality: str = "low", + background: str = "opaque", + image: Input.Image | None = None, + mask: Input.Image | None = None, + n: int = 1, + size: str = "1024x1024", + model: str = "gpt-image-1", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - model = "gpt-image-1" - path = "/proxy/openai/images/generations" - content_type = "application/json" - request_class = OpenAIImageGenerationRequest - files = [] + + if mask is not None and image is None: + raise ValueError("Cannot use a mask without an input image") + + if model == "gpt-image-1": + price_extractor = calculate_tokens_price_image_1 + elif model == "gpt-image-1.5": + price_extractor = calculate_tokens_price_image_1_5 + else: + raise ValueError(f"Unknown model: {model}") if image is not None: - path = "/proxy/openai/images/edits" - request_class = OpenAIImageEditRequest - content_type = "multipart/form-data" - + files = [] batch_size = image.shape[0] - for i in range(batch_size): - single_image = image[i : i + 1] - scaled_image = downscale_image_tensor(single_image).squeeze() + single_image = image[i: i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode): else: files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) - if mask is not None: - if image is None: - raise Exception("Cannot use a mask without an input image") - if image.shape[0] != 1: - raise Exception("Cannot use a mask with multiple image") - if mask.shape[1:] != image.shape[1:-1]: - raise Exception("Mask and Image must be the same size") - batch, height, width = mask.shape - rgba_mask = torch.zeros(height, width, 4, device="cpu") - rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() + if mask is not None: + if image.shape[0] != 1: + raise Exception("Cannot use a mask with multiple image") + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + _, height, width = mask.shape + rgba_mask = torch.zeros(height, width, 4, device="cpu") + rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() - mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) - mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = BytesIO() - mask_img.save(mask_img_byte_arr, format="PNG") - mask_img_byte_arr.seek(0) - files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) - - # Build the operation - response = await sync_op( - cls, - ApiEndpoint(path=path, method="POST"), - response_model=OpenAIImageGenerationResponse, - data=request_class( - model=model, - prompt=prompt, - quality=quality, - background=background, - n=n, - seed=seed, - size=size, - ), - files=files if files else None, - content_type=content_type, - ) + mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_img_byte_arr = BytesIO() + mask_img.save(mask_img_byte_arr, format="PNG") + mask_img_byte_arr.seek(0) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/edits", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageEditRequest( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation="low", + ), + content_type="multipart/form-data", + files=files, + price_extractor=price_extractor, + ) + else: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageGenerationRequest( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation="low", + ), + price_extractor=price_extractor, + ) return IO.NodeOutput(await validate_and_cast_response(response)) diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 697100ff2..bd3c24fb3 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Int.Input("model_seed", default=42, optional=True), IO.Int.Input("texture_seed", default=42, optional=True), IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), - IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True), IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ IO.String.Output(display_name="model_file"), @@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode): model_seed: Optional[int] = None, texture_seed: Optional[int] = None, texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, face_limit: Optional[int] = None, quad: Optional[bool] = None, ) -> IO.NodeOutput: @@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode): texture_seed=texture_seed, texture_quality=texture_quality, face_limit=face_limit, + geometry_quality=geometry_quality, auto_size=True, quad=quad, ), @@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode): ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ IO.String.Output(display_name="model_file"), @@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode): orientation=None, texture_seed: Optional[int] = None, texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, texture_alignment: Optional[str] = None, face_limit: Optional[int] = None, quad: Optional[bool] = None, @@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode): pbr=pbr, model_seed=model_seed, orientation=orientation, + geometry_quality=geometry_quality, texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, @@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ IO.String.Output(display_name="model_file"), @@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): model_seed: Optional[int] = None, texture_seed: Optional[int] = None, texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, texture_alignment: Optional[str] = None, face_limit: Optional[int] = None, quad: Optional[bool] = None, @@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, + geometry_quality=geometry_quality, texture_alignment=texture_alignment, face_limit=face_limit, quad=quad, @@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode): options=[ "preset:idle", "preset:walk", + "preset:run", + "preset:dive", "preset:climb", "preset:jump", "preset:slash", @@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode): "preset:hurt", "preset:fall", "preset:turn", + "preset:quadruped:walk", + "preset:hexapod:walk", + "preset:octopod:walk", + "preset:serpentine:march", + "preset:aquatic:march" ], ), ], @@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode): "face_limit", default=-1, min=-1, - max=500000, + max=2000000, optional=True, ), IO.Int.Input( @@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode): default="JPEG", optional=True, ), + IO.Boolean.Input("force_symmetry", default=False, optional=True), + IO.Boolean.Input("flatten_bottom", default=False, optional=True), + IO.Float.Input( + "flatten_bottom_threshold", + default=0.0, + min=0.0, + max=1.0, + optional=True, + ), + IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True), + IO.Float.Input( + "scale_factor", + default=1.0, + min=0.0, + optional=True, + ), + IO.Boolean.Input("with_animation", default=False, optional=True), + IO.Boolean.Input("pack_uv", default=False, optional=True), + IO.Boolean.Input("bake", default=False, optional=True), + IO.String.Input("part_names", default="", optional=True), # comma-separated list + IO.Combo.Input( + "fbx_preset", + options=["blender", "mixamo", "3dsmax"], + default="blender", + optional=True, + ), + IO.Boolean.Input("export_vertex_colors", default=False, optional=True), + IO.Combo.Input( + "export_orientation", + options=["align_image", "default"], + default="default", + optional=True, + ), + IO.Boolean.Input("animate_in_place", default=False, optional=True), ], outputs=[], hidden=[ @@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode): original_model_task_id, format: str, quad: bool, + force_symmetry: bool, face_limit: int, + flatten_bottom: bool, + flatten_bottom_threshold: float, texture_size: int, texture_format: str, + pivot_to_center_bottom: bool, + scale_factor: float, + with_animation: bool, + pack_uv: bool, + bake: bool, + part_names: str, + fbx_preset: str, + export_vertex_colors: bool, + export_orientation: str, + animate_in_place: bool, ) -> IO.NodeOutput: if not original_model_task_id: raise RuntimeError("original_model_task_id is required") + + # Parse part_names from comma-separated string to list + part_names_list = None + if part_names and part_names.strip(): + part_names_list = [name.strip() for name in part_names.split(',') if name.strip()] + response = await sync_op( cls, endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), @@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode): original_model_task_id=original_model_task_id, format=format, quad=quad if quad else None, + force_symmetry=force_symmetry if force_symmetry else None, face_limit=face_limit if face_limit != -1 else None, + flatten_bottom=flatten_bottom if flatten_bottom else None, + flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None, texture_size=texture_size if texture_size != 4096 else None, texture_format=texture_format if texture_format != "JPEG" else None, + pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None, + scale_factor=scale_factor if scale_factor != 1.0 else None, + with_animation=with_animation if with_animation else None, + pack_uv=pack_uv if pack_uv else None, + bake=bake if bake else None, + part_names=part_names_list, + fbx_preset=fbx_preset if fbx_preset != "blender" else None, + export_vertex_colors=export_vertex_colors if export_vertex_colors else None, + export_orientation=export_orientation if export_orientation != "default" else None, + animate_in_place=animate_in_place if animate_in_place else None, ), ) return await poll_until_finished(cls, response, average_duration=30) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 2aab3c2ff..17b680e13 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1,7 +1,5 @@ import re -from typing import Optional -import torch from pydantic import BaseModel, Field from typing_extensions import override @@ -21,26 +19,26 @@ from comfy_api_nodes.util import ( class Text2ImageInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) class Image2ImageInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) images: list[str] = Field(..., min_length=1, max_length=2) class Text2VideoInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) - audio_url: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) + audio_url: str | None = Field(None) class Image2VideoInputField(BaseModel): prompt: str = Field(...) - negative_prompt: Optional[str] = Field(None) + negative_prompt: str | None = Field(None) img_url: str = Field(...) - audio_url: Optional[str] = Field(None) + audio_url: str | None = Field(None) class Txt2ImageParametersField(BaseModel): @@ -52,7 +50,7 @@ class Txt2ImageParametersField(BaseModel): class Image2ImageParametersField(BaseModel): - size: Optional[str] = Field(None) + size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) watermark: bool = Field(True) @@ -61,19 +59,21 @@ class Image2ImageParametersField(BaseModel): class Text2VideoParametersField(BaseModel): size: str = Field(...) seed: int = Field(..., ge=0, le=2147483647) - duration: int = Field(5, ge=5, le=10) + duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) watermark: bool = Field(True) - audio: bool = Field(False, description="Should be audio generated automatically") + audio: bool = Field(False, description="Whether to generate audio automatically.") + shot_type: str = Field("single") class Image2VideoParametersField(BaseModel): resolution: str = Field(...) seed: int = Field(..., ge=0, le=2147483647) - duration: int = Field(5, ge=5, le=10) + duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) watermark: bool = Field(True) - audio: bool = Field(False, description="Should be audio generated automatically") + audio: bool = Field(False, description="Whether to generate audio automatically.") + shot_type: str = Field("single") class Text2ImageTaskCreationRequest(BaseModel): @@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel): class TaskCreationResponse(BaseModel): - output: Optional[TaskCreationOutputField] = Field(None) + output: TaskCreationOutputField | None = Field(None) request_id: str = Field(...) - code: Optional[str] = Field(None, description="The error code of the failed request.") - message: Optional[str] = Field(None, description="Details of the failed request.") + code: str | None = Field(None, description="Error code for the failed request.") + message: str | None = Field(None, description="Details about the failed request.") class TaskResult(BaseModel): - url: Optional[str] = Field(None) - code: Optional[str] = Field(None) - message: Optional[str] = Field(None) + url: str | None = Field(None) + code: str | None = Field(None) + message: str | None = Field(None) class ImageTaskStatusOutputField(TaskCreationOutputField): task_id: str = Field(...) task_status: str = Field(...) - results: Optional[list[TaskResult]] = Field(None) + results: list[TaskResult] | None = Field(None) class VideoTaskStatusOutputField(TaskCreationOutputField): task_id: str = Field(...) task_status: str = Field(...) - video_url: Optional[str] = Field(None) - code: Optional[str] = Field(None) - message: Optional[str] = Field(None) + video_url: str | None = Field(None) + code: str | None = Field(None) + message: str | None = Field(None) class ImageTaskStatusResponse(BaseModel): - output: Optional[ImageTaskStatusOutputField] = Field(None) + output: ImageTaskStatusOutputField | None = Field(None) request_id: str = Field(...) class VideoTaskStatusResponse(BaseModel): - output: Optional[VideoTaskStatusOutputField] = Field(None) + output: VideoTaskStatusOutputField | None = Field(None) request_id: str = Field(...) @@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode): node_id="WanTextToImageApi", display_name="Wan Text to Image", category="api node/image/Wan", - description="Generates image based on text prompt.", + description="Generates an image based on a text prompt.", inputs=[ IO.Combo.Input( "model", @@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode): "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Int.Input( @@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), ], @@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode): ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), @@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode): display_name="Wan Image to Image", category="api node/image/Wan", description="Generates an image from one or two input images and a text prompt. " - "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", inputs=[ IO.Combo.Input( "model", @@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Image.Input( "image", - tooltip="Single-image editing or multi-image fusion, maximum 2 images.", + tooltip="Single-image editing or multi-image fusion. Maximum 2 images.", ), IO.String.Input( "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), # redo this later as an optional combo of recommended resolutions @@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), ], @@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str = "", # width: int = 1024, @@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode): ): n_images = get_number_of_images(image) if n_images not in (1, 2): - raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") + raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.") images = [] for i in image: images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) @@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode): ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), @@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode): node_id="WanTextToVideoApi", display_name="Wan Text to Video", category="api node/video/Wan", - description="Generates video based on text prompt.", + description="Generates a video based on a text prompt.", inputs=[ IO.Combo.Input( "model", - options=["wan2.5-t2v-preview"], - default="wan2.5-t2v-preview", + options=["wan2.5-t2v-preview", "wan2.6-t2v"], + default="wan2.6-t2v", tooltip="Model to use.", ), IO.String.Input( "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Combo.Input( @@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode): "1080p: 4:3 (1632x1248)", "1080p: 3:4 (1248x1632)", ], - default="480p: 1:1 (624x624)", + default="720p: 1:1 (960x960)", optional=True, ), IO.Int.Input( "duration", default=5, min=5, - max=10, + max=15, step=5, display_mode=IO.NumberDisplay.number, - tooltip="Available durations: 5 and 10 seconds", + tooltip="A 15-second duration is available only for the Wan 2.6 model.", optional=True, ), IO.Audio.Input( "audio", optional=True, - tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", ), IO.Int.Input( "seed", @@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode): "generate_audio", default=False, optional=True, - tooltip="If there is no audio input, generate audio automatically.", + tooltip="If no audio input is provided, generate audio automatically.", ), IO.Boolean.Input( "prompt_extend", @@ -477,7 +477,15 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. " + "This parameter takes effect only when prompt_extend is True.", optional=True, ), ], @@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode): model: str, prompt: str, negative_prompt: str = "", - size: str = "480p: 1:1 (624x624)", + size: str = "720p: 1:1 (960x960)", duration: int = 5, - audio: Optional[Input.Audio] = None, + audio: Input.Audio | None = None, seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, watermark: bool = True, + shot_type: str = "single", ): + if "480p" in size and model == "wan2.6-t2v": + raise ValueError("The Wan 2.6 model does not support 480p.") + if duration == 15 and model == "wan2.5-t2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") width, height = RES_IN_PARENS.search(size).groups() audio_url = None if audio is not None: @@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode): audio=generate_audio, prompt_extend=prompt_extend, watermark=watermark, + shot_type=shot_type, ), ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), @@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode): node_id="WanImageToVideoApi", display_name="Wan Image to Video", category="api node/video/Wan", - description="Generates video based on the first frame and text prompt.", + description="Generates a video from the first frame and a text prompt.", inputs=[ IO.Combo.Input( "model", - options=["wan2.5-i2v-preview"], - default="wan2.5-i2v-preview", + options=["wan2.5-i2v-preview", "wan2.6-i2v"], + default="wan2.6-i2v", tooltip="Model to use.", ), IO.Image.Input( @@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode): "prompt", multiline=True, default="", - tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", ), IO.String.Input( "negative_prompt", multiline=True, default="", - tooltip="Negative text prompt to guide what to avoid.", + tooltip="Negative prompt describing what to avoid.", optional=True, ), IO.Combo.Input( @@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode): "720P", "1080P", ], - default="480P", + default="720P", optional=True, ), IO.Int.Input( "duration", default=5, min=5, - max=10, + max=15, step=5, display_mode=IO.NumberDisplay.number, - tooltip="Available durations: 5 and 10 seconds", + tooltip="Duration 15 available only for WAN2.6 model.", optional=True, ), IO.Audio.Input( "audio", optional=True, - tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", ), IO.Int.Input( "seed", @@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode): "generate_audio", default=False, optional=True, - tooltip="If there is no audio input, generate audio automatically.", + tooltip="If no audio input is provided, generate audio automatically.", ), IO.Boolean.Input( "prompt_extend", @@ -624,7 +638,15 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip='Whether to add an "AI generated" watermark to the result.', + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. " + "This parameter takes effect only when prompt_extend is True.", optional=True, ), ], @@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str = "", - resolution: str = "480P", + resolution: str = "720P", duration: int = 5, - audio: Optional[Input.Audio] = None, + audio: Input.Audio | None = None, seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, watermark: bool = True, + shot_type: str = "single", ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") + if "480P" in resolution and model == "wan2.6-i2v": + raise ValueError("The Wan 2.6 model does not support 480P.") + if duration == 15 and model == "wan2.5-i2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: @@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode): audio=generate_audio, prompt_extend=prompt_extend, watermark=watermark, + shot_type=shot_type, ), ), ) if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index c57457580..d64239c86 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: return img_byte_arr -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: +def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: """Downscale input image tensor to roughly the specified total pixels.""" samples = image.movedim(-1, 1) total = int(total_pixels) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 71ea4e9ec..7ee4caac1 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -671,7 +671,16 @@ class SamplerSEEDS2(io.ComfyNode): io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"), io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"), ], - outputs=[io.Sampler.Output()] + outputs=[io.Sampler.Output()], + description=( + "This sampler node can represent multiple samplers:\n\n" + "seeds_2\n" + "- default setting\n\n" + "exp_heun_2_x0\n" + "- solver_type=phi_2, r=1.0, eta=0.0\n\n" + "exp_heun_2_x0_sde\n" + "- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0" + ) ) @classmethod diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 4789d7d53..513aecf3a 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1125,6 +1125,99 @@ class MergeTextListsNode(TextProcessingNode): # ========== Training Dataset Nodes ========== +class ResolutionBucket(io.ComfyNode): + """Bucket latents and conditions by resolution for efficient batch training.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionBucket", + display_name="Resolution Bucket", + category="dataset", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts to bucket by resolution.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists (must match latents length).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of batched latent dicts, one per resolution bucket.", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of condition lists, one per resolution bucket.", + ), + ], + ) + + @classmethod + def execute(cls, latents, conditioning): + # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1 + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})." + ) + + # Flatten latents and conditions to individual samples + flat_latents = [] # list of (C, H, W) tensors + flat_conditions = [] # list of condition lists + + for latent_dict, cond in zip(latents, conditioning): + samples = latent_dict["samples"] # (B, C, H, W) + batch_size = samples.shape[0] + + # cond is a list of conditions with length == batch_size + for i in range(batch_size): + flat_latents.append(samples[i]) # (C, H, W) + flat_conditions.append(cond[i]) # single condition + + # Group by resolution (H, W) + buckets = {} # (H, W) -> {"latents": list, "conditions": list} + + for latent, cond in zip(flat_latents, flat_conditions): + # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W) + h, w = latent.shape[-2], latent.shape[-1] + key = (h, w) + + if key not in buckets: + buckets[key] = {"latents": [], "conditions": []} + + buckets[key]["latents"].append(latent) + buckets[key]["conditions"].append(cond) + + # Convert buckets to output format + output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W) + output_conditions = [] # list[list[cond]] where each inner list has Bi conditions + + for (h, w), bucket_data in buckets.items(): + # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W) + stacked_latents = torch.stack(bucket_data["latents"], dim=0) + output_latents.append({"samples": stacked_latents}) + + # Conditions stay as list of condition lists + output_conditions.append(bucket_data["conditions"]) + + logging.info( + f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples" + ) + + logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples") + return io.NodeOutput(output_latents, output_conditions) + + class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" @@ -1373,7 +1466,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f, weights_only=True) + shard_data = torch.load(f) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1425,6 +1518,7 @@ class DatasetExtension(ComfyExtension): MakeTrainingDataset, SaveTrainingDataset, LoadTrainingDataset, + ResolutionBucket, ] diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index d9c4bba81..12c8ed3e6 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FluxKontextMultiReferenceLatentMethod", + display_name="Edit Model Reference Method", category="advanced/conditioning/flux", inputs=[ io.Conditioning.Input("conditioning"), io.Combo.Input( "reference_latents_method", - options=["offset", "index", "uxo/uno"], + options=["offset", "index", "uxo/uno", "index_timestep_zero"], ), ], outputs=[ diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index ec0e790dc..2a0cfcf18 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -248,7 +248,10 @@ class ModelPatchLoader: config['n_control_layers'] = 15 config['additional_in_dim'] = 17 config['refiner_control'] = True - config['broken'] = True + ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None) + if ref_weight is not None: + if torch.count_nonzero(ref_weight) == 0: + config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) model.load_state_dict(sd) @@ -310,22 +313,46 @@ class ZImageControlPatch: self.inpaint_image = inpaint_image self.mask = mask self.strength = strength - self.encoded_image = self.encode_latent_cond(image) - self.encoded_image_size = (image.shape[1], image.shape[2]) + self.is_inpaint = self.model_patch.model.additional_in_dim > 0 + + skip_encoding = False + if self.image is not None and self.inpaint_image is not None: + if self.image.shape != self.inpaint_image.shape: + skip_encoding = True + + if skip_encoding: + self.encoded_image = None + else: + self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image) + if self.image is None: + self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2]) + else: + self.encoded_image_size = (self.image.shape[1], self.image.shape[2]) self.temp_data = None - def encode_latent_cond(self, control_image, inpaint_image=None): - latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) - if self.model_patch.model.additional_in_dim > 0: - if self.mask is None: - mask_ = torch.zeros_like(latent_image)[:, :1] - else: - mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + def encode_latent_cond(self, control_image=None, inpaint_image=None): + latent_image = None + if control_image is not None: + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) + + if self.is_inpaint: if inpaint_image is None: inpaint_image = torch.ones_like(control_image) * 0.5 + if self.mask is not None: + mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center") + inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5 + inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image)) + if self.mask is None: + mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] + else: + mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") + + if latent_image is None: + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) + return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1) else: return latent_image @@ -341,13 +368,18 @@ class ZImageControlPatch: block_type = kwargs.get("block_type", "") spacial_compression = self.vae.spacial_compression_encode() if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): - image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + image_scaled = None + if self.image is not None: + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) + self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2]) + inpaint_scaled = None if self.inpaint_image is not None: inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) + self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2]) + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled) - self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled) comfy.model_management.load_models_gpu(loaded_models) cnet_blocks = self.model_patch.model.n_control_layers @@ -388,7 +420,8 @@ class ZImageControlPatch: def to(self, device_or_dtype): if isinstance(device_or_dtype, torch.device): - self.encoded_image = self.encoded_image.to(device_or_dtype) + if self.encoded_image is not None: + self.encoded_image = self.encoded_image.to(device_or_dtype) self.temp_data = None return self @@ -411,9 +444,12 @@ class QwenImageDiffsynthControlnet: CATEGORY = "advanced/loaders/qwen" - def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None): + def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None): model_patched = model.clone() - image = image[:, :, :, :3] + if image is not None: + image = image[:, :, :, :3] + if inpaint_image is not None: + inpaint_image = inpaint_image[:, :, :, :3] if mask is not None: if mask.ndim == 3: mask = mask.unsqueeze(1) @@ -422,13 +458,24 @@ class QwenImageDiffsynthControlnet: mask = 1.0 - mask if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): - patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask) + patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask) model_patched.set_model_noise_refiner_patch(patch) model_patched.set_model_double_block_patch(patch) else: model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) +class ZImageFunControlnet(QwenImageDiffsynthControlnet): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }, + "optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}} + + CATEGORY = "advanced/loaders/zimage" class UsoStyleProjectorPatch: def __init__(self, model_patch, encoded_image): @@ -476,5 +523,6 @@ class USOStyleReference: NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, + "ZImageFunControlnet": ZImageFunControlnet, "USOStyleReference": USOStyleReference, } diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 34c388a5a..ca2cdeb50 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode): io.Image.Input("image"), io.Combo.Input("upscale_method", options=cls.upscale_methods), io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + io.Int.Input("resolution_steps", default=1, min=1, max=256), ], outputs=[ io.Image.Output(), @@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode): ) @classmethod - def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: + def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput: samples = image.movedim(-1,1) - total = int(megapixels * 1024 * 1024) + total = megapixels * 1024 * 1024 scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) + width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps + height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps - s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled") s = s.movedim(1,-1) return io.NodeOutput(s) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 19b8baaf4..364804205 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont from typing_extensions import override import comfy.samplers +import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management @@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar +class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): + """ + CFGGuider with modifications for training specific logic + """ + def outer_sample( + self, + noise, + latent_image, + sampler, + sigmas, + denoise_mask=None, + callback=None, + disable_pbar=False, + seed=None, + latent_shapes=None, + ): + self.inner_model, self.conds, self.loaded_models = ( + comfy.sampler_helpers.prepare_sampling( + self.model_patcher, + noise.shape, + self.conds, + self.model_options, + force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded + ) + ) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = comfy.sampler_helpers.prepare_mask( + denoise_mask, noise.shape, device + ) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + comfy.samplers.cast_to_load_options( + self.model_options, device=device, dtype=self.model_patcher.model_dtype() + ) + + try: + self.model_patcher.pre_run() + output = self.inner_sample( + noise, + latent_image, + device, + sampler, + sigmas, + denoise_mask, + callback, + disable_pbar, + seed, + latent_shapes=latent_shapes, + ) + finally: + self.model_patcher.cleanup() + + comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.loaded_models + return output + + def make_batch_extra_option_dict(d, indicies, full_size=None): new_dict = {} for k, v in d.items(): @@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler): seed=0, training_dtype=torch.bfloat16, real_dataset=None, + bucket_latents=None, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler): self.seed = seed self.training_dtype = training_dtype self.real_dataset: list[torch.Tensor] | None = real_dataset + # Bucket mode data + self.bucket_latents: list[torch.Tensor] | None = ( + bucket_latents # list of (Bi, C, Hi, Wi) + ) + # Precompute bucket offsets and weights for sampling + if bucket_latents is not None: + self._init_bucket_data(bucket_latents) + else: + self.bucket_offsets = None + self.bucket_weights = None + self.num_images = None + + def _init_bucket_data(self, bucket_latents): + """Initialize bucket offsets and weights for sampling.""" + self.bucket_offsets = [0] + bucket_sizes = [] + for lat in bucket_latents: + bucket_sizes.append(lat.shape[0]) + self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0]) + self.num_images = self.bucket_offsets[-1] + # Weights for sampling buckets proportional to their size + self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) def fwd_bwd( self, @@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler): bwd_loss.backward() return loss + def _generate_batch_sigmas(self, model_wrap, batch_size, device): + """Generate random sigma values for a batch.""" + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(batch_size) + ] + return torch.tensor(batch_sigmas).to(device) + + def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar): + """Execute one training step in bucket mode.""" + # Sample bucket (weighted by size), then sample batch from bucket + bucket_idx = torch.multinomial(self.bucket_weights, 1).item() + bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi) + bucket_size = bucket_latent.shape[0] + bucket_offset = self.bucket_offsets[bucket_idx] + + # Sample indices from this bucket (use all if bucket_size < batch_size) + actual_batch_size = min(self.batch_size, bucket_size) + relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist() + # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) + absolute_indices = [bucket_offset + idx for idx in relative_indices] + + batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, # Use flattened cond with absolute indices + absolute_indices, + extra_args, + self.num_images, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx}) + + def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in standard (non-bucket, non-multi-res) mode.""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in multi-resolution mode (real_dataset is set).""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + def sample( self, model_wrap, @@ -142,70 +330,18 @@ class TrainSampler(comfy.samplers.Sampler): noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( self.seed + i * 1000 ) - indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - if self.real_dataset is None: - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( - batch_latent.device - ) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - batch_latent, - cond, - indicies, - extra_args, - dataset_size, - bwd=True, - ) - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + if self.bucket_latents is not None: + self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar) + elif self.real_dataset is None: + self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) else: - total_loss = 0 - for index in indicies: - single_latent = self.real_dataset[index].to(latent_image) - batch_noise = noisegen.generate_noise( - {"samples": single_latent} - ).to(single_latent.device) - batch_sigmas = ( - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - ) - batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - single_latent, - cond, - [index], - extra_args, - dataset_size, - bwd=False, - ) - total_loss += loss - total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() - if self.loss_callback: - self.loss_callback(total_loss.item()) - pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() - ui_pbar.update(1) + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -283,6 +419,364 @@ def unpatch(m): del m.org_forward +def _process_latents_bucket_mode(latents): + """Process latents for bucket mode training. + + Args: + latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) + + Returns: + list of latent tensors + """ + bucket_latents = [] + for latent_dict in latents: + bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) + return bucket_latents + + +def _process_latents_standard_mode(latents): + """Process latents for standard (non-bucket) mode training. + + Args: + latents: list of latent dicts or single latent dict + + Returns: + Processed latents (tensor or list of tensors) + """ + if len(latents) == 1: + return latents[0]["samples"] # Single latent dict + + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + return latent_list + + +def _process_conditioning(positive): + """Process conditioning - either single list or list of lists. + + Args: + positive: list of conditioning + + Returns: + Flattened conditioning list + """ + if len(positive) == 1: + return positive[0] # Single conditioning list + + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + return flat_positive + + +def _prepare_latents_and_count(latents, dtype, bucket_mode): + """Convert latents to dtype and compute image counts. + + Args: + latents: Latents (tensor, list of tensors, or bucket list) + dtype: Target dtype + bucket_mode: Whether bucket mode is enabled + + Returns: + tuple: (processed_latents, num_images, multi_res) + """ + if bucket_mode: + # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) + latents = [t.to(dtype) for t in latents] + num_buckets = len(latents) + num_images = sum(t.shape[0] for t in latents) + multi_res = False # Not using multi_res path in bucket mode + + logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + for i, lat in enumerate(latents): + logging.info(f" Bucket {i}: shape {lat.shape}") + return latents, num_images, multi_res + + # Non-bucket mode + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + multi_res = False + else: + logging.error(f"Invalid latents type: {type(latents)}") + num_images = 0 + multi_res = False + + return latents, num_images, multi_res + + +def _validate_and_expand_conditioning(positive, num_images, bucket_mode): + """Validate conditioning count matches image count, expand if needed. + + Args: + positive: Conditioning list + num_images: Number of images + bucket_mode: Whether bucket mode is enabled + + Returns: + Validated/expanded conditioning list + + Raises: + ValueError: If conditioning count doesn't match image count + """ + if bucket_mode: + return positive # Skip validation in bucket mode + + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + return positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) + return positive + + +def _load_existing_lora(existing_lora): + """Load existing LoRA weights if provided. + + Args: + existing_lora: LoRA filename or "[None]" + + Returns: + tuple: (existing_weights dict, existing_steps int) + """ + if existing_lora == "[None]": + return {}, 0 + + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + existing_weights = {} + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + return existing_weights, existing_steps + + +def _create_weight_adapter( + module, module_name, existing_weights, algorithm, lora_dtype, rank +): + """Create a weight adapter for a module with weight. + + Args: + module: The module to create adapter for + module_name: Name of the module + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (train_adapter, lora_params dict) + """ + key = f"{module_name}.weight" + shape = module.weight.shape + lora_params = {} + + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) + + # Try to load existing adapter + existing_adapter = None + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + module_name, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + + if existing_adapter is None: + adapter_cls = adapter_maps[algorithm] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + module.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + + for name, parameter in train_adapter.named_parameters(): + lora_params[f"{module_name}.{name}"] = parameter + + return train_adapter.train().requires_grad_(True), lora_params + else: + # 1D weight - use BiasDiff + diff = torch.nn.Parameter( + torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True) + ) + diff_module = BiasDiff(diff).train().requires_grad_(True) + lora_params[f"{module_name}.diff"] = diff + return diff_module, lora_params + + +def _create_bias_adapter(module, module_name, lora_dtype): + """Create a bias adapter for a module with bias. + + Args: + module: The module with bias + module_name: Name of the module + lora_dtype: dtype for LoRA weights + + Returns: + tuple: (bias_module, lora_params dict) + """ + bias = torch.nn.Parameter( + torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias).train().requires_grad_(True) + lora_params = {f"{module_name}.diff_b": bias} + return bias_module, lora_params + + +def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup all LoRA adapters on the model. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list) + """ + lora_sd = {} + all_weight_adapters = [] + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + key = f"{n}.weight" + mp.add_weight_wrapper(key, adapter) + all_weight_adapters.append(adapter) + + if hasattr(m, "bias") and m.bias is not None: + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + + return lora_sd, all_weight_adapters + + +def _create_optimizer(optimizer_name, parameters, learning_rate): + """Create optimizer based on name. + + Args: + optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop") + parameters: Parameters to optimize + learning_rate: Learning rate + + Returns: + Optimizer instance + """ + if optimizer_name == "Adam": + return torch.optim.Adam(parameters, lr=learning_rate) + elif optimizer_name == "AdamW": + return torch.optim.AdamW(parameters, lr=learning_rate) + elif optimizer_name == "SGD": + return torch.optim.SGD(parameters, lr=learning_rate) + elif optimizer_name == "RMSprop": + return torch.optim.RMSprop(parameters, lr=learning_rate) + + +def _create_loss_function(loss_function_name): + """Create loss function based on name. + + Args: + loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1") + + Returns: + Loss function instance + """ + if loss_function_name == "MSE": + return torch.nn.MSELoss() + elif loss_function_name == "L1": + return torch.nn.L1Loss() + elif loss_function_name == "Huber": + return torch.nn.HuberLoss() + elif loss_function_name == "SmoothL1": + return torch.nn.SmoothL1Loss() + + +def _run_training_loop( + guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res +): + """Execute the training loop. + + Args: + guider: The guider object + train_sampler: The training sampler + latents: Latent tensors + num_images: Number of images + seed: Random seed + bucket_mode: Whether bucket mode is enabled + multi_res: Whether multi-resolution mode is enabled + """ + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + + if bucket_mode: + # Use first bucket's first latent as dummy for guider + dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": dummy_latent}), + dummy_latent, + train_sampler, + sigmas, + seed=noise.seed, + ) + elif multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + else: + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + + class TrainLoraNode(io.ComfyNode): @classmethod def define_schema(cls): @@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode): default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), + io.Boolean.Input( + "bucket_mode", + default=False, + tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", + ), ], outputs=[ io.Model.Output( @@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode): algorithm, gradient_checkpointing, existing_lora, + bucket_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -427,215 +927,125 @@ class TrainLoraNode(io.ComfyNode): grad_accumulation_steps = grad_accumulation_steps[0] learning_rate = learning_rate[0] rank = rank[0] - optimizer = optimizer[0] - loss_function = loss_function[0] + optimizer_name = optimizer[0] + loss_function_name = loss_function[0] seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] + bucket_mode = bucket_mode[0] - # Handle latents - either single dict or list of dicts - if len(latents) == 1: - latents = latents[0]["samples"] # Single latent dict + # Process latents based on mode + if bucket_mode: + latents = _process_latents_bucket_mode(latents) else: - latent_list = [] - for latent in latents: - latent = latent["samples"] - bs = latent.shape[0] - if bs != 1: - for sub_latent in latent: - latent_list.append(sub_latent[None]) - else: - latent_list.append(latent) - latents = latent_list + latents = _process_latents_standard_mode(latents) - # Handle conditioning - either single list or list of lists - if len(positive) == 1: - positive = positive[0] # Single conditioning list - else: - # Multiple conditioning lists - flatten - flat_positive = [] - for cond in positive: - if isinstance(cond, list): - flat_positive.extend(cond) - else: - flat_positive.append(cond) - positive = flat_positive + # Process conditioning + positive = _process_conditioning(positive) + # Setup model and dtype mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - # latents here can be list of different size latent or one large batch - if isinstance(latents, list): - all_shapes = set() - latents = [t.to(dtype) for t in latents] - for latent in latents: - all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") - if len(all_shapes) > 1: - multi_res = True - else: - multi_res = False - latents = torch.cat(latents, dim=0) - num_images = len(latents) - elif isinstance(latents, torch.Tensor): - latents = latents.to(dtype) - num_images = latents.shape[0] - else: - logging.error(f"Invalid latents type: {type(latents)}") + # Prepare latents and compute counts + latents, num_images, multi_res = _prepare_latents_and_count( + latents, dtype, bucket_mode + ) - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") - if len(positive) == 1 and num_images > 1: - positive = positive * num_images - elif len(positive) != num_images: - raise ValueError( - f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." - ) + # Validate and expand conditioning + positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) with torch.inference_mode(False): - lora_sd = {} - generator = torch.Generator() - generator.manual_seed(seed) + # Setup models for training + mp.model.requires_grad_(False) # Load existing LoRA weights if provided - existing_weights = {} - existing_steps = 0 - if existing_lora != "[None]": - lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) - # Extract steps from filename like "trained_lora_10_steps_20250225_203716" - existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) - if lora_path: - existing_weights = comfy.utils.load_torch_file(lora_path) + existing_weights, existing_steps = _load_existing_lora(existing_lora) - all_weight_adapters = [] - for n, m in mp.model.named_modules(): - if hasattr(m, "weight_function"): - if m.weight is not None: - key = "{}.weight".format(n) - shape = m.weight.shape - if len(shape) >= 2: - alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get(f"{key}.dora_scale", None) - for adapter_cls in adapters: - existing_adapter = adapter_cls.load( - n, existing_weights, alpha, dora_scale - ) - if existing_adapter is not None: - break - else: - existing_adapter = None - adapter_cls = adapter_maps[algorithm] + # Setup LoRA adapters + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) - if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to( - lora_dtype - ) - else: - # Use LoRA with alpha=1.0 by default - train_adapter = adapter_cls.create_train( - m.weight, rank=rank, alpha=1.0 - ).to(lora_dtype) - for name, parameter in train_adapter.named_parameters(): - lora_sd[f"{n}.{name}"] = parameter + # Create optimizer and loss function + optimizer = _create_optimizer( + optimizer_name, lora_sd.values(), learning_rate + ) + criterion = _create_loss_function(loss_function_name) - mp.add_weight_wrapper(key, train_adapter) - all_weight_adapters.append(train_adapter) - else: - diff = torch.nn.Parameter( - torch.zeros( - m.weight.shape, dtype=lora_dtype, requires_grad=True - ) - ) - diff_module = BiasDiff(diff) - mp.add_weight_wrapper(key, BiasDiff(diff)) - all_weight_adapters.append(diff_module) - lora_sd["{}.diff".format(n)] = diff - if hasattr(m, "bias") and m.bias is not None: - key = "{}.bias".format(n) - bias = torch.nn.Parameter( - torch.zeros( - m.bias.shape, dtype=lora_dtype, requires_grad=True - ) - ) - bias_module = BiasDiff(bias) - lora_sd["{}.diff_b".format(n)] = bias - mp.add_weight_wrapper(key, BiasDiff(bias)) - all_weight_adapters.append(bias_module) - - if optimizer == "Adam": - optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) - elif optimizer == "AdamW": - optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) - elif optimizer == "SGD": - optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) - elif optimizer == "RMSprop": - optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) - - # Setup loss function based on selection - if loss_function == "MSE": - criterion = torch.nn.MSELoss() - elif loss_function == "L1": - criterion = torch.nn.L1Loss() - elif loss_function == "Huber": - criterion = torch.nn.HuberLoss() - elif loss_function == "SmoothL1": - criterion = torch.nn.SmoothL1Loss() - - # setup models + # Setup gradient checkpointing if gradient_checkpointing: for m in find_all_highest_child_module_with_forward( mp.model.diffusion_model ): patch(m) - mp.model.requires_grad_(False) + + torch.cuda.empty_cache() + # With force_full_load=False we should be able to have offloading + # But for offloading in training we need custom AutoGrad hooks for fwd/bwd comfy.model_management.load_models_gpu( [mp], memory_required=1e20, force_full_load=True ) + torch.cuda.empty_cache() - # Setup sampler and guider like in test script + # Setup loss tracking loss_map = {"loss": []} def loss_callback(loss): loss_map["loss"].append(loss) - train_sampler = TrainSampler( - criterion, - optimizer, - loss_callback=loss_callback, - batch_size=batch_size, - grad_acc=grad_accumulation_steps, - total_steps=steps * grad_accumulation_steps, - seed=seed, - training_dtype=dtype, - real_dataset=latents if multi_res else None, - ) - guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) - guider.set_conds(positive) # Set conditioning from input + # Create sampler + if bucket_mode: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + bucket_latents=latents, + ) + else: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + real_dataset=latents if multi_res else None, + ) - # Training loop + # Setup guider + guider = TrainGuider(mp) + guider.set_conds(positive) + + # Run training loop try: - # Generate dummy sigmas and noise - sigmas = torch.tensor(range(num_images)) - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) - if multi_res: - # use first latent as dummy latent if multi_res - latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) - guider.sample( - noise.generate_noise({"samples": latents}), - latents, + _run_training_loop( + guider, train_sampler, - sigmas, - seed=noise.seed, + latents, + num_images, + seed, + bucket_mode, + multi_res, ) finally: for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer + # Finalize adapters for adapter in all_weight_adapters: adapter.requires_grad_(False) @@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode): return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader(io.ComfyNode): +class LoraModelLoader(io.ComfyNode):# @classmethod def define_schema(cls): return io.Schema( diff --git a/comfyui_version.py b/comfyui_version.py index 2f083edaf..b45309198 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.4.0" +__version__ = "0.5.1" diff --git a/main.py b/main.py index 0d02a087b..0e07a95da 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,38 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +if os.name == "nt": + os.environ['MIMALLOC_PURGE_DELAY'] = '0' + +if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' + if args.default_device is not None: + default_dev = args.default_device + devices = list(range(32)) + devices.remove(default_dev) + devices.insert(0, default_dev) + devices = ','.join(map(str, devices)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) + os.environ['HIP_VISIBLE_DEVICES'] = str(devices) + + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) + logging.info("Set cuda device to: {}".format(args.cuda_device)) + + if args.oneapi_device_selector is not None: + os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector + logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) + + if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + + import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + def handle_comfyui_manager_unavailable(): if not args.windows_standalone_build: @@ -137,40 +169,6 @@ import shutil import threading import gc - -if os.name == "nt": - os.environ['MIMALLOC_PURGE_DELAY'] = '0' - -if __name__ == "__main__": - os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' - if args.default_device is not None: - default_dev = args.default_device - devices = list(range(32)) - devices.remove(default_dev) - devices.insert(0, default_dev) - devices = ','.join(map(str, devices)) - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) - os.environ['HIP_VISIBLE_DEVICES'] = str(devices) - - if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) - logging.info("Set cuda device to: {}".format(args.cuda_device)) - - if args.oneapi_device_selector is not None: - os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector - logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) - - if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" - - import cuda_malloc - if "rocm" in cuda_malloc.get_torch_version_noimport(): - os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD - - if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") diff --git a/pyproject.toml b/pyproject.toml index e4d3d616a..3a6960811 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.4.0" +version = "0.5.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 117260515..9b9e61683 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.34.8 +comfyui-frontend-package==1.34.9 comfyui-workflow-templates==0.7.59 comfyui-embedded-docs==0.3.1 torch