diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml
index 6fceb7560..737e4c488 100644
--- a/.github/workflows/release-webhook.yml
+++ b/.github/workflows/release-webhook.yml
@@ -7,6 +7,8 @@ on:
 jobs:
   send-webhook:
     runs-on: ubuntu-latest
+    env:
+      DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
     steps:
       - name: Send release webhook
         env:
@@ -106,3 +108,37 @@ jobs:
             --fail --silent --show-error

           echo "✅ Release webhook sent successfully"
+
+      - name: Send repository dispatch to desktop
+        env:
+          DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
+          RELEASE_TAG: ${{ github.event.release.tag_name }}
+          RELEASE_URL: ${{ github.event.release.html_url }}
+        run: |
+          set -euo pipefail
+
+          if [ -z "${DISPATCH_TOKEN:-}" ]; then
+            echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
+            exit 1
+          fi
+
+          PAYLOAD="$(jq -n \
+            --arg release_tag "$RELEASE_TAG" \
+            --arg release_url "$RELEASE_URL" \
+            '{
+              event_type: "comfyui_release_published",
+              client_payload: {
+                release_tag: $release_tag,
+                release_url: $release_url
+              }
+            }')"
+
+          curl -fsSL \
+            -X POST \
+            -H "Accept: application/vnd.github+json" \
+            -H "Content-Type: application/json" \
+            -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+            https://api.github.com/repos/Comfy-Org/desktop/dispatches \
+            -d "$PAYLOAD"
+
+          echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
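For reference, the new step is a plain `repository_dispatch` POST to the GitHub API; a minimal Python sketch of the same request (token and values are placeholders, `requests` stands in for the workflow's `curl`):

```python
# Sketch of the repository_dispatch call the workflow step performs.
import requests

def send_release_dispatch(token: str, release_tag: str, release_url: str) -> None:
    resp = requests.post(
        "https://api.github.com/repos/Comfy-Org/desktop/dispatches",
        headers={
            "Accept": "application/vnd.github+json",
            "Authorization": f"Bearer {token}",
        },
        json={
            "event_type": "comfyui_release_published",
            "client_payload": {"release_tag": release_tag, "release_url": release_url},
        },
        timeout=30,
    )
    resp.raise_for_status()  # GitHub answers 204 No Content on success
```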
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index c0c51d51a..6978eb717 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1,12 +1,11 @@
 import math
-import time
 from functools import partial

 from scipy import integrate
 import torch
 from torch import nn
 import torchsde
-from tqdm.auto import trange as trange_, tqdm
+from tqdm.auto import tqdm

 from . import utils
 from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
 import comfy.model_sampling
 import comfy.memory_management

-
-def trange(*args, **kwargs):
-    if comfy.memory_management.aimdo_allocator is None:
-        return trange_(*args, **kwargs)
-
-    pbar = trange_(*args, **kwargs, smoothing=1.0)
-    pbar._i = 0
-    pbar.set_postfix_str(" Model Initializing ... ")
-
-    _update = pbar.update
-
-    def warmup_update(n=1):
-        pbar._i += 1
-        if pbar._i == 1:
-            pbar.i1_time = time.time()
-            pbar.set_postfix_str(" Model Initialization complete! ")
-        elif pbar._i == 2:
-            #bring forward the effective start time based the the diff between first and second iteration
-            #to attempt to remove load overhead from the final step rate estimate.
-            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
-            pbar.set_postfix_str("")
-
-        _update(n)
-
-    pbar.update = warmup_update
-    return pbar
-
+from comfy.utils import model_trange as trange

 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py
index 2e6ed58fa..6fb51c4a4 100644
--- a/comfy/ldm/anima/model.py
+++ b/comfy/ldm/anima/model.py
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
         super().__init__(*args, **kwargs)
         self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

-    def preprocess_text_embeds(self, text_embeds, text_ids):
+    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
         if text_ids is not None:
-            return self.llm_adapter(text_embeds, text_ids)
+            out = self.llm_adapter(text_embeds, text_ids)
+            if t5xxl_weights is not None:
+                out = out * t5xxl_weights
+
+            if out.shape[1] < 512:
+                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+            return out
         else:
             return text_embeds
+
+    def forward(self, x, timesteps, context, **kwargs):
+        t5xxl_ids = kwargs.pop("t5xxl_ids", None)
+        if t5xxl_ids is not None:
+            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+        return super().forward(x, timesteps, context, **kwargs)
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index f9597de5b..5e764bb46 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     return out.to(dtype=torch.float32, device=pos.device)


+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+
+    x_out = freqs_cis[..., 0] * x_[..., 0]
+    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
+    return x_out.reshape(*x.shape).type_as(x)
+
+
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    return _apply_rope1(xq, freqs_cis), _apply_rope1(xk, freqs_cis)
+
+
 try:
     import comfy.quant_ops
-    apply_rope = comfy.quant_ops.ck.apply_rope
-    apply_rope1 = comfy.quant_ops.ck.apply_rope1
+    q_apply_rope = comfy.quant_ops.ck.apply_rope
+    q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+
+    def apply_rope(xq, xk, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope(xq, xk, freqs_cis)
+        else:
+            return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
+    def apply_rope1(x, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope1(x, freqs_cis)
+        else:
+            return q_apply_rope1(x, freqs_cis)
 except:
     logging.warning("No comfy kitchen, using old apply_rope functions.")
-    def apply_rope1(x: Tensor, freqs_cis: Tensor):
-        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-
-        x_out = freqs_cis[..., 0] * x_[..., 0]
-        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-
-        return x_out.reshape(*x.shape).type_as(x)
-
-    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+    apply_rope = _apply_rope
+    apply_rope1 = _apply_rope1
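The reference path `_apply_rope1` treats `freqs_cis` as per-pair 2x2 rotation matrices: column 0 weights x0, column 1 weights x1. A minimal standalone sanity check of that arithmetic against complex rotation (shapes reduced to a single pair; names are illustrative, not the flux API):

```python
import math
import torch

def ref_apply_rope1(x, freqs_cis):
    # mirrors _apply_rope1 above: the 2x2 matrix columns weight the (x0, x1) pair
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    x_out = freqs_cis[..., 0] * x_[..., 0]
    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
    return x_out.reshape(*x.shape).type_as(x)

c, s = math.cos(0.3), math.sin(0.3)
freqs_cis = torch.tensor([[c, -s], [s, c]])  # one rotation matrix
x = torch.tensor([1.0, 2.0])                 # one (x0, x1) pair

out = ref_apply_rope1(x, freqs_cis)
expected = torch.tensor([1.0 * c - 2.0 * s, 1.0 * s + 2.0 * c])  # (x0 + i*x1) * (c + i*s)
assert torch.allclose(out, expected)
```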
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 858789b30..4a74cb1ce 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
         device = kwargs["device"]
         if cross_attn is not None:
             if t5xxl_ids is not None:
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
                 if t5xxl_weights is not None:
-                    cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+                    t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+                t5xxl_ids = t5xxl_ids.unsqueeze(0)
+
+                if torch.is_inference_mode_enabled():  # if not, we are training
+                    cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
+                else:
+                    out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
+                    out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

-            if cross_attn.shape[1] < 512:
-                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 6018c1ab6..38c3e482b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU

 total_vram = 0

+
+# Training Related State
+in_training = False
+
+
 def get_supported_float8_types():
     float8_types = []
     try:
@@ -1208,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
         signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
         if signature is not None:
-            v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
-            if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+            if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+                v_tensor = weight._v_tensor
+            else:
+                raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+                v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+                weight._v_tensor = v_tensor
                 weight._v_signature = signature
                 #Send it over
                 v_tensor.copy_(weight, non_blocking=non_blocking)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 19c9031ea..f278fccac 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
                 setattr(m, param_key + "_function", weight_function)
             geometry = weight
             if not isinstance(weight, QuantizedTensor):
-                model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+                model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
                 weight._model_dtype = model_dtype
                 geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
             return comfy.memory_management.vram_aligned_size(geometry)
@@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):

             if vbar is not None and not hasattr(m, "_v"):
                 m._v = vbar.alloc(v_weight_size)
-                m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)

             allocated_size += v_weight_size
         else:
@@ -1552,12 +1551,11 @@ class ModelPatcherDynamic(ModelPatcher):
                 weight.seed_key = key
                 set_dirty(weight, dirty)
                 geometry = weight
-                model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+                model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
                 geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
                 weight_size = geometry.numel() * geometry.element_size()
                 if vbar is not None and not hasattr(weight, "_v"):
                     weight._v = vbar.alloc(weight_size)
-                    weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
                     weight._model_dtype = model_dtype
                 allocated_size += weight_size
         vbar.set_watermark_limit(allocated_size)
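The `cast_to` change above replaces an unconditional re-interpretation with a signature-guarded cache: the interpreted view is rebuilt (and the weights re-sent) only when the vbar signature no longer matches. The pattern, reduced to a generic runnable sketch (helper names here are illustrative, not the comfy_aimdo API):

```python
from typing import Any, Callable

def cached_view(cache: dict, signature: Any, rebuild: Callable[[], Any]) -> Any:
    """Reuse cache["view"] while the signature matches; rebuild and re-cache otherwise."""
    if cache.get("signature") == signature and "view" in cache:
        return cache["view"]                       # backing allocation unchanged: cheap path
    view = rebuild()                               # e.g. aimdo_to_tensor + interpret_gathered_like
    cache["signature"], cache["view"] = signature, view
    return view                                    # caller then copies weights into the fresh view
```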
diff --git a/comfy/ops.py b/comfy/ops.py
index 33803b223..688937e43 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
 def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     offload_stream = None
     xfer_dest = None
-    cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])

     signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
-    if signature is not None:
-        xfer_dest = s._v_tensor

     resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+    if signature is not None:
+        if resident:
+            weight = s._v_weight
+            bias = s._v_bias
+        else:
+            xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)

     if not resident:
+        cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
         cast_dest = None
         xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
                 post_cast.copy_(pre_cast)
             xfer_dest = cast_dest

-    params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
-    weight = params[0]
-    bias = params[1]
+        params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+        weight = params[0]
+        bias = params[1]
+        if signature is not None:
+            s._v_weight = weight
+            s._v_bias = bias
+            s._v_signature = signature

     def post_cast(s, param_key, x, dtype, resident, update_weight):
         lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
     if s.bias is not None:
         bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
-    s._v_signature=signature

     #FIXME: weird offload return protocol
     return weight, bias, (offload_stream, device if signature is not None else None, None)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 9134e6d71..1f75f2ba7 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required

-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)

-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
-    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+    if force_offload:  # In training with offloading enabled, force prepare_sampling to trigger a partial load
+        memory_required = 1e20
+        minimum_memory_required = None
+    else:
+        memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+        memory_required += inference_memory
+        minimum_memory_required += inference_memory
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
     real_model = model.model
     return real_model, conds, models
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 73697b3c1..b8198a820 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -3,7 +3,6 @@ import comfy.text_encoders.llama
 from comfy import sd1_clip
 import torch
 import math
-from tqdm.auto import trange
 import yaml
 import comfy.utils

@@ -52,7 +51,7 @@ def sample_manual_loop_no_classes(

     progress_bar = comfy.utils.ProgressBar(max_new_tokens)

-    for step in trange(max_new_tokens, desc="LM sampling"):
+    for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
         outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
         next_token_logits = model.transformer.logits(outputs[0])[:, -1]
         past_key_values = outputs[2]
diff --git a/comfy/utils.py b/comfy/utils.py
index edd80cebe..e0a94e2e1 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -27,6 +27,8 @@ from PIL import Image
 import logging
 import itertools
 from torch.nn.functional import interpolate
+import time
+from tqdm.auto import trange
 from einops import rearrange
 from comfy.cli_args import args, enables_dynamic_vram
 import json
@@ -1155,6 +1157,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):

 def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)

+def model_trange(*args, **kwargs):
+    if comfy.memory_management.aimdo_allocator is None:
+        return trange(*args, **kwargs)
+
+    pbar = trange(*args, **kwargs, smoothing=1.0)
+    pbar._i = 0
+    pbar.set_postfix_str(" Model Initializing ... ")
+
+    _update = pbar.update
+
+    def warmup_update(n=1):
+        pbar._i += 1
+        if pbar._i == 1:
+            pbar.i1_time = time.time()
+            pbar.set_postfix_str(" Model Initialization complete! ")
+        elif pbar._i == 2:
+            # Bring forward the effective start time based on the diff between the first and second
+            # iterations, to remove model load overhead from the final step rate estimate.
+            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
+            pbar.set_postfix_str("")
+
+        _update(n)
+
+    pbar.update = warmup_update
+    return pbar
+
 PROGRESS_BAR_ENABLED = True
 def set_progress_bar_enabled(enabled):
     global PROGRESS_BAR_ENABLED
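`model_trange` is a drop-in replacement for `tqdm.trange`: when the aimdo allocator is active, it shifts the bar's `start_t` back by the first-to-second iteration gap (`start_t = i1_time - (t2 - i1_time)`), which amounts to assuming the first step would have taken as long as the second had the model already been resident, so load time stops inflating the it/s estimate. Hypothetical usage (the sleep stands in for a denoising step):

```python
# Sketch: model_trange behaves like trange; the warmup correction only
# activates when comfy.memory_management.aimdo_allocator is set.
import time
from comfy.utils import model_trange

for step in model_trange(20, desc="sampling"):
    time.sleep(0.05)  # stand-in for one model step
```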
diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py
index d4aaf98ca..b9d5ec7d9 100644
--- a/comfy/weight_adapter/bypass.py
+++ b/comfy/weight_adapter/bypass.py
@@ -21,6 +21,7 @@ from typing import Optional, Union

 import torch
 import torch.nn as nn

+import comfy.model_management
 from .base import WeightAdapterBase, WeightAdapterTrainBase
 from comfy.patcher_extension import PatcherInjection

@@ -181,18 +182,21 @@ class BypassForwardHook:
             )
             return  # Already injected

-        # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
-        device = None
+        # Move adapter weights to the compute device (GPU).
+        # Use get_torch_device() instead of module.weight.device because,
+        # with offloading, module weights may be on CPU while compute happens on GPU.
+        device = comfy.model_management.get_torch_device()
+
+        # Get dtype from the module weight if available
         dtype = None
         if hasattr(self.module, "weight") and self.module.weight is not None:
-            device = self.module.weight.device
             dtype = self.module.weight.dtype
-        elif hasattr(self.module, "W_q"):  # Quantized layers might use different attr
-            device = self.module.W_q.device
-            dtype = self.module.W_q.dtype
-        if device is not None:
-            self._move_adapter_weights_to_device(device, dtype)
+        # Only use dtype if it's a standard float type, not quantized
+        if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            dtype = None
+
+        self._move_adapter_weights_to_device(device, dtype)

         self.original_forward = self.module.forward
         self.module.forward = self._bypass_forward
{"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515} + : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; + $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187}; + { + "type": "range_usd", + "min_usd": $lookup($mins, widgets.scale_factor), + "max_usd": $lookup($maxs, widgets.scale_factor), + "format": { "approximate": true } + } ) """, ), @@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): f"Use a smaller input image or lower scale factor." ) + final_height, final_width = get_image_dimensions(image) + actual_scale = int(scale_factor.rstrip("x")) + price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale) + initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), @@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), response_model=TaskResponse, status_extractor=lambda x: x.status, + price_extractor=lambda _: price_usd, poll_interval=10.0, max_poll_attempts=480, ) @@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), expr=""" ( - $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; - {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; + $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06}; + { + "type": "range_usd", + "min_usd": $lookup($mins, widgets.scale_factor), + "max_usd": $lookup($maxs, widgets.scale_factor), + "format": { "approximate": true } + } ) """, ), @@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): f"Use a smaller input image or lower scale factor." 
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 8a1259506..391748e7a 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -143,9 +143,9 @@ async def poll_op(
     poll_interval: float = 5.0,
     max_poll_attempts: int = 160,
     timeout_per_poll: float = 120.0,
-    max_retries_per_poll: int = 3,
+    max_retries_per_poll: int = 10,
     retry_delay_per_poll: float = 1.0,
-    retry_backoff_per_poll: float = 2.0,
+    retry_backoff_per_poll: float = 1.4,
     estimated_duration: int | None = None,
     cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
@@ -240,9 +240,9 @@ async def poll_op_raw(
     poll_interval: float = 5.0,
     max_poll_attempts: int = 160,
     timeout_per_poll: float = 120.0,
-    max_retries_per_poll: int = 3,
+    max_retries_per_poll: int = 10,
     retry_delay_per_poll: float = 1.0,
-    retry_backoff_per_poll: float = 2.0,
+    retry_backoff_per_poll: float = 1.4,
     estimated_duration: int | None = None,
     cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
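The new poll retry defaults trade a few steep sleeps for many gentler ones. Assuming the delay before retry k is `retry_delay_per_poll * retry_backoff_per_poll ** k` (the exact formula is not shown in this hunk), the worst-case cumulative delay per poll grows from about 7 s to about 70 s while individual waits stay short:

```python
old = sum(1.0 * 2.0 ** k for k in range(3))    # 1 + 2 + 4 = 7.0 s over 3 attempts
new = sum(1.0 * 1.4 ** k for k in range(10))   # ≈ 69.8 s over 10 attempts
print(f"worst-case backoff: old {old:.1f}s, new {new:.1f}s")
```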
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 024a89391..630eedc9f 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -4,6 +4,7 @@ import os
 import numpy as np
 import safetensors
 import torch
+import torch.nn as nn
 import torch.utils.checkpoint
 from tqdm.auto import trange
 from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
     """
     CFGGuider with modifications for training specific logic
     """
+
+    def __init__(self, *args, offloading=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.offloading = offloading
+
     def outer_sample(
         self,
         noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                 noise.shape,
                 self.conds,
                 self.model_options,
-                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
+                force_full_load=not self.offloading,
+                force_offload=self.offloading,
             )
         )
+        torch.cuda.empty_cache()
         device = self.model_patcher.load_device

         if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
     return result


-def patch(m):
+def find_modules_at_depth(
+    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+    """
+    Find modules at a specific depth level for gradient checkpointing.
+
+    Args:
+        model: The model to search
+        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+        result: Accumulator for results
+        current_depth: Current recursion depth
+        name: Current module name for logging
+
+    Returns:
+        List of modules at the target depth
+    """
+    if result is None:
+        result = []
+        name = name or "root"
+
+    # Skip container modules (they don't have a meaningful forward)
+    is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+    has_forward = hasattr(model, "forward") and not is_container
+
+    if has_forward:
+        current_depth += 1
+        if current_depth == depth:
+            result.append(model)
+            logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+            return result
+
+    # Recurse into children
+    for next_name, child in model.named_children():
+        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+
+    return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+    """
+    Gradient checkpointing that works with weight offloading.
+
+    Forward: no_grad -> compute -> weights can be freed
+    Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+    For single input, single output modules (Linear, Conv*).
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, forward_fn):
+        ctx.save_for_backward(x)
+        ctx.forward_fn = forward_fn
+        with torch.no_grad():
+            return forward_fn(x)
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        x, = ctx.saved_tensors
+        forward_fn = ctx.forward_fn
+
+        # Clear the context early
+        ctx.forward_fn = None
+
+        with torch.enable_grad():
+            x_detached = x.detach().requires_grad_(True)
+            y = forward_fn(x_detached)
+            y.backward(grad_out)
+            grad_x = x_detached.grad
+
+        # Explicit cleanup
+        del y, x_detached, forward_fn
+
+        return grad_x, None
+
+
+def patch(m, offloading=False):
     if not hasattr(m, "forward"):
         return
     org_forward = m.forward

-    def fwd(args, kwargs):
-        return org_forward(*args, **kwargs)
+    # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input, single output)
+    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        def checkpointing_fwd(x):
+            return OffloadCheckpointFunction.apply(x, org_forward)
+    # Branch 2: everything else -> standard checkpoint
+    else:
+        def fwd(args, kwargs):
+            return org_forward(*args, **kwargs)

-    def checkpointing_fwd(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+        def checkpointing_fwd(*args, **kwargs):
+            return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

     m.org_forward = org_forward
     m.forward = checkpointing_fwd
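A quick way to convince yourself `OffloadCheckpointFunction` is gradient-correct: compare it against plain autograd on a small Linear. A sketch assuming the class above is in scope (it checks the input gradient only; parameter gradients accumulate during the recompute the same way):

```python
import torch
import torch.nn as nn

lin = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

y = OffloadCheckpointFunction.apply(x, lin)   # forward runs under no_grad
y.sum().backward()                            # backward recomputes, then differentiates
g_ckpt = x.grad.clone()

x.grad = None
lin(x).sum().backward()                       # reference: ordinary autograd
assert torch.allclose(g_ckpt, x.grad)
```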
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
                     default=True,
                     tooltip="Use gradient checkpointing for training.",
                 ),
+                io.Int.Input(
+                    "checkpoint_depth",
+                    default=1,
+                    min=1,
+                    max=5,
+                    tooltip="Depth level for gradient checkpointing.",
+                ),
+                io.Boolean.Input(
+                    "offloading",
+                    default=False,
+                    tooltip="Offload model weights during training to reduce VRAM usage.",
+                ),
                 io.Combo.Input(
                     "existing_lora",
                     options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype,
         algorithm,
         gradient_checkpointing,
+        checkpoint_depth,
+        offloading,
         existing_lora,
         bucket_mode,
         bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = lora_dtype[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
+        offloading = offloading[0]
+        checkpoint_depth = checkpoint_depth[0]
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
         bypass_mode = bypass_mode[0]
@@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
         # Setup gradient checkpointing
         if gradient_checkpointing:
-            for m in find_all_highest_child_module_with_forward(
-                mp.model.diffusion_model
-            ):
-                patch(m)
+            modules_to_patch = find_modules_at_depth(
+                mp.model.diffusion_model, depth=checkpoint_depth
+            )
+            logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+            for m in modules_to_patch:
+                patch(m, offloading=offloading)
         torch.cuda.empty_cache()

         # With force_full_load=False we should be able to have offloading
         # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=True
+            [mp], memory_required=1e20, force_full_load=not offloading
         )
         torch.cuda.empty_cache()
@@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
             )

         # Setup guider
-        guider = TrainGuider(mp)
+        guider = TrainGuider(mp, offloading=offloading)
         guider.set_conds(positive)

         # Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):

         # Run training loop
         try:
+            comfy.model_management.in_training = True
             _run_training_loop(
                 guider,
                 train_sampler,
@@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
                 multi_res,
             )
         finally:
+            comfy.model_management.in_training = False
             # Eject bypass hooks if they were injected
             if bypass_injections is not None:
                 for injection in bypass_injections:
@@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
                 unpatch(m)
         del train_sampler, optimizer

-        # Finalize adapters
+        for param in lora_sd:
+            lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
+
         for adapter in all_weight_adapters:
             adapter.requires_grad_(False)
-
-        for param in lora_sd:
-            lora_sd[param] = lora_sd[param].to(lora_dtype)
+            del adapter
+        del all_weight_adapters

         # mp in train node is highly specialized for training
         # use it in inference will result in bad behavior so we don't return it
         return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)


-class LoraModelLoader(io.ComfyNode):#
+class LoraModelLoader(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
@@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):
                     max=100.0,
                     tooltip="How strongly to modify the diffusion model. This value can be negative.",
                 ),
+                io.Boolean.Input(
+                    "bypass",
+                    default=False,
+                    tooltip="When enabled, applies the LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):
             )

     @classmethod
-    def execute(cls, model, lora, strength_model):
+    def execute(cls, model, lora, strength_model, bypass=False):
         if strength_model == 0:
             return io.NodeOutput(model)

-        model_lora, _ = comfy.sd.load_lora_for_models(
-            model, None, lora, strength_model, 0
-        )
+        if bypass:
+            model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
+        else:
+            model_lora, _ = comfy.sd.load_lora_for_models(
+                model, None, lora, strength_model, 0
+            )

         return io.NodeOutput(model_lora)