Merge upstream/master, keep local README.md

This commit is contained in:
GitHub Actions 2026-02-12 00:47:23 +00:00
commit 602ca59147
15 changed files with 365 additions and 108 deletions

View File

@ -7,6 +7,8 @@ on:
jobs: jobs:
send-webhook: send-webhook:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps: steps:
- name: Send release webhook - name: Send release webhook
env: env:
@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error --fail --silent --show-error
echo "✅ Release webhook sent successfully" echo "✅ Release webhook sent successfully"
- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"

View File

@ -1,12 +1,11 @@
import math import math
import time
from functools import partial from functools import partial
from scipy import integrate from scipy import integrate
import torch import torch
from torch import nn from torch import nn
import torchsde import torchsde
from tqdm.auto import trange as trange_, tqdm from tqdm.auto import tqdm
from . import utils from . import utils
from . import deis from . import deis
@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling import comfy.model_sampling
import comfy.memory_management import comfy.memory_management
from comfy.utils import model_trange as trange
# Progress-bar factory built on tqdm's trange (referenced here as trange_).
# If comfy.memory_management.aimdo_allocator is None, the plain bar is returned.
# Otherwise the bar's update method is replaced by a wrapper that shows warm-up
# status text for the first two updates and then rewrites pbar.start_t.
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)
# smoothing is always passed explicitly here, so a caller-supplied smoothing= would be a duplicate keyword.
pbar = trange_(*args, **kwargs, smoothing=1.0)
# _i counts how many times the wrapped update has been called.
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
# Keep the original bound method so the wrapper can delegate to it.
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
# Remember when the first update happened; used on the second update.
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
def append_zero(x): def append_zero(x):
return torch.cat([x, x.new_zeros([1])]) return torch.cat([x, x.new_zeros([1])])

View File

@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids): def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
if text_ids is not None: if text_ids is not None:
return self.llm_adapter(text_embeds, text_ids) out = self.llm_adapter(text_embeds, text_ids)
if t5xxl_weights is not None:
out = out * t5xxl_weights
if out.shape[1] < 512:
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
return out
else: else:
return text_embeds return text_embeds
def forward(self, x, timesteps, context, **kwargs):
    """Run the parent forward pass, preprocessing `context` when token ids are given.

    `t5xxl_ids` is always removed from `kwargs`. `t5xxl_weights` is only
    removed (and forwarded to `preprocess_text_embeds`) when `t5xxl_ids`
    is present; otherwise it is left in `kwargs` untouched.
    """
    token_ids = kwargs.pop("t5xxl_ids", None)
    if token_ids is None:
        # No ids supplied: context is used as-is.
        return super().forward(x, timesteps, context, **kwargs)

    token_weights = kwargs.pop("t5xxl_weights", None)
    context = self.preprocess_text_embeds(context, token_ids, t5xxl_weights=token_weights)
    return super().forward(x, timesteps, context, **kwargs)

View File

@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device) return out.to(dtype=torch.float32, device=pos.device)
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
    """Apply the rotation stored in `freqs_cis` to `x` using plain torch ops.

    The last dimension of `x` is regrouped into pairs, each pair is combined
    with the two trailing-axis slices of `freqs_cis`, and the result is
    returned with the original shape and dtype of `x`. The arithmetic is
    done in the dtype of `freqs_cis`.
    """
    original_shape = x.shape
    # View the last axis as (..., n_pairs, 1, 2) so it broadcasts against freqs_cis.
    pairs = x.to(dtype=freqs_cis.dtype).reshape(*original_shape[:-1], -1, 1, 2)
    rotated = freqs_cis[..., 0] * pairs[..., 0]
    # Accumulate the second term in place to avoid an extra temporary.
    rotated.addcmul_(freqs_cis[..., 1], pairs[..., 1])
    return rotated.reshape(*original_shape).type_as(x)
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Rotate the query and key tensors with the same `freqs_cis`.

    Delegates to `apply_rope1` for each tensor (query first, then key) and
    returns the two results as a `(query, key)` tuple.
    """
    rotated_q = apply_rope1(xq, freqs_cis)
    rotated_k = apply_rope1(xk, freqs_cis)
    return rotated_q, rotated_k
try: try:
import comfy.quant_ops import comfy.quant_ops
apply_rope = comfy.quant_ops.ck.apply_rope q_apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1 q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
except: except:
logging.warning("No comfy kitchen, using old apply_rope functions.") logging.warning("No comfy kitchen, using old apply_rope functions.")
def apply_rope1(x: Tensor, freqs_cis: Tensor): apply_rope = _apply_rope
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) apply_rope1 = _apply_rope1
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

View File

@ -1160,12 +1160,16 @@ class Anima(BaseModel):
device = kwargs["device"] device = kwargs["device"]
if cross_attn is not None: if cross_attn is not None:
if t5xxl_ids is not None: if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None: if t5xxl_weights is not None:
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out return out

View File

@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
total_vram = 0 total_vram = 0
# Training Related State
in_training = False
def get_supported_float8_types(): def get_supported_float8_types():
float8_types = [] float8_types = []
try: try:
@ -1208,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None: if signature is not None:
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0] if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature weight._v_signature = signature
#Send it over #Send it over
v_tensor.copy_(weight, non_blocking=non_blocking) v_tensor.copy_(weight, non_blocking=non_blocking)

View File

@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function) setattr(m, param_key + "_function", weight_function)
geometry = weight geometry = weight
if not isinstance(weight, QuantizedTensor): if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry) return comfy.memory_management.vram_aligned_size(geometry)
@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"): if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size) m._v = vbar.alloc(v_weight_size)
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size allocated_size += v_weight_size
else: else:
@ -1552,12 +1551,11 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key weight.seed_key = key
set_dirty(weight, dirty) set_dirty(weight, dirty)
geometry = weight geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size() weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"): if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size) weight._v = vbar.alloc(weight_size)
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype weight._model_dtype = model_dtype
allocated_size += weight_size allocated_size += weight_size
vbar.set_watermark_limit(allocated_size) vbar.set_watermark_limit(allocated_size)

View File

@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None offload_stream = None
xfer_dest = None xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v) signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = s._v_tensor
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident: if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None cast_dest = None
xfer_source = [ s.weight, s.bias ] xfer_source = [ s.weight, s.bias ]
@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast) post_cast.copy_(pre_cast)
xfer_dest = cast_dest xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0] weight = params[0]
bias = params[1] bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight): def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight) weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None: if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol #FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None) return weight, bias, (offload_stream, device if signature is not None else None, None)

View File

@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor( executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling, _prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
) )
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options) models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update? models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) memory_required = 1e20
minimum_memory_required = None
else:
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model real_model = model.model
return real_model, conds, models return real_model, conds, models

View File

@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip from comfy import sd1_clip
import torch import torch
import math import math
from tqdm.auto import trange
import yaml import yaml
import comfy.utils import comfy.utils
@ -52,7 +51,7 @@ def sample_manual_loop_no_classes(
progress_bar = comfy.utils.ProgressBar(max_new_tokens) progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in trange(max_new_tokens, desc="LM sampling"): for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1] next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2] past_key_values = outputs[2]

View File

@ -27,6 +27,7 @@ from PIL import Image
import logging import logging
import itertools import itertools
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
import json import json
@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
    """Create a tqdm `trange` bar, with warm-up handling when the aimdo allocator is set.

    If `comfy.memory_management.aimdo_allocator` is None the arguments are
    passed straight through to `trange`. Otherwise the bar is created with
    `smoothing=1.0`, shows status text during the first two updates, and on
    the second update moves its `start_t` so the first iteration's extra
    time is not counted in the rate estimate.
    """
    allocator_active = comfy.memory_management.aimdo_allocator is not None
    if not allocator_active:
        return trange(*args, **kwargs)

    pbar = trange(*args, **kwargs, smoothing=1.0)
    pbar._i = 0  # number of update() calls seen so far
    pbar.set_postfix_str(" Model Initializing ... ")
    plain_update = pbar.update

    def update_with_warmup(n=1):
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the diff between
            # the first and second iteration, to attempt to remove load
            # overhead from the final step rate estimate.
            second_step_duration = time.time() - pbar.i1_time
            pbar.start_t = pbar.i1_time - second_step_duration
            pbar.set_postfix_str("")
        plain_update(n)

    pbar.update = update_with_warmup
    return pbar
PROGRESS_BAR_ENABLED = True PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled): def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED global PROGRESS_BAR_ENABLED

View File

@ -21,6 +21,7 @@ from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection from comfy.patcher_extension import PatcherInjection
@ -181,18 +182,21 @@ class BypassForwardHook:
) )
return # Already injected return # Already injected
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward # Move adapter weights to compute device (GPU)
device = None # Use get_torch_device() instead of module.weight.device because
# with offloading, module weights may be on CPU while compute happens on GPU
device = comfy.model_management.get_torch_device()
# Get dtype from module weight if available
dtype = None dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None: if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype
if device is not None: # Only use dtype if it's a standard float type, not quantized
self._move_adapter_weights_to_device(device, dtype) if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = None
self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward self.original_forward = self.module.forward
self.module.forward = self._bypass_forward self.module.forward = self._bypass_forward

View File

@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions, validate_image_dimensions,
) )
_EUR_TO_USD = 1.19
def _tier_price_eur(megapixels: float) -> float:
    """Price in EUR for a single Magnific upscaling step based on input megapixels."""
    # (inclusive upper bound in megapixels, price in EUR), checked in ascending order.
    tiers = (
        (1.3, 0.143),
        (3.0, 0.286),
        (6.4, 0.429),
    )
    for upper_bound, price in tiers:
        if megapixels <= upper_bound:
            return price
    # Anything above the last bound falls into the top tier.
    return 1.716
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
    """Calculate total Magnific upscale price in USD for given input dimensions and scale factor.

    The scale is split into `int(log2(scale))` steps. Each step is priced
    from the pixel count entering that step, and the pixel count is
    multiplied by 4 after every step. The EUR total is converted with
    `_EUR_TO_USD` and rounded to 2 decimals.
    """
    step_count = int(math.log2(scale))
    input_pixels = width * height
    eur_sum = 0.0
    for step in range(step_count):
        # Pixel count going into this step: the input grown 4x per earlier step.
        step_pixels = input_pixels * 4 ** step
        eur_sum += _tier_price_eur(step_pixels / 1_000_000)
    return round(eur_sum * _EUR_TO_USD, 2)
class MagnificImageUpscalerCreativeNode(IO.ComfyNode): class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod @classmethod
@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr=""" expr="""
( (
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657; $ad := widgets.auto_downscale;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max} $mins := $ad
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
) )
""", """,
), ),
@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor." f"Use a smaller input image or lower scale factor."
) )
final_height, final_width = get_image_dimensions(image)
actual_scale = int(scale_factor.rstrip("x"))
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
initial_res = await sync_op( initial_res = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480, max_poll_attempts=480,
) )
@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr=""" expr="""
( (
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657; $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max} $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
) )
""", """,
), ),
@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor." f"Use a smaller input image or lower scale factor."
) )
final_height, final_width = get_image_dimensions(image)
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
initial_res = await sync_op( initial_res = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480, max_poll_attempts=480,
) )
@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
# MagnificImageUpscalerCreativeNode, MagnificImageUpscalerCreativeNode,
# MagnificImageUpscalerPreciseV2Node, MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode, MagnificImageStyleTransferNode,
MagnificImageRelightNode, MagnificImageRelightNode,
MagnificImageSkinEnhancerNode, MagnificImageSkinEnhancerNode,

View File

@ -143,9 +143,9 @@ async def poll_op(
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 160, max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None, estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
@ -240,9 +240,9 @@ async def poll_op_raw(
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 160, max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None, estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,

View File

@ -4,6 +4,7 @@ import os
import numpy as np import numpy as np
import safetensors import safetensors
import torch import torch
import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from tqdm.auto import trange from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
""" """
CFGGuider with modifications for training specific logic CFGGuider with modifications for training specific logic
""" """
# Takes the base guider's positional/keyword arguments unchanged, plus a
# keyword-only `offloading` flag (default False).
def __init__(self, *args, offloading=False, **kwargs):
super().__init__(*args, **kwargs)
# Stored on the instance; outer_sample reads it to choose force_full_load / force_offload.
self.offloading = offloading
def outer_sample( def outer_sample(
self, self,
noise, noise,
@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape, noise.shape,
self.conds, self.conds,
self.model_options, self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded force_full_load=not self.offloading,
force_offload=self.offloading,
) )
) )
torch.cuda.empty_cache()
device = self.model_patcher.load_device device = self.model_patcher.load_device
if denoise_mask is not None: if denoise_mask is not None:
@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
return result return result
def find_modules_at_depth(
    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
    """
    Collect the modules sitting at a given depth, for gradient checkpointing.

    Container modules (ModuleList, Sequential, ModuleDict) do not count as a
    level: they are walked through without increasing the depth.

    Args:
        model: The model to search
        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
        result: Accumulator for results (appended to in place when given)
        current_depth: Current recursion depth
        name: Current module name for logging
    Returns:
        List of modules at the target depth
    """
    if result is None:
        result = []
    name = name or "root"

    # Only non-container modules with a forward() advance the depth counter.
    counts_as_level = hasattr(model, "forward") and not isinstance(
        model, (nn.ModuleList, nn.Sequential, nn.ModuleDict)
    )
    if counts_as_level:
        current_depth += 1

    if counts_as_level and current_depth == depth:
        # Target level reached: record this module and do not descend further.
        result.append(model)
        logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
        return result

    for next_name, child in model.named_children():
        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
    return result
class OffloadCheckpointFunction(torch.autograd.Function):
    """
    Gradient checkpointing that works with weight offloading.

    Forward runs the wrapped callable under no_grad and keeps only the input
    tensor. Backward re-runs the callable with grad enabled on a detached
    copy of that input and backpropagates through the recomputation.

    For single input, single output modules (Linear, Conv*).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, forward_fn):
        # Keep the callable and the input; no intermediate activations are stored.
        ctx.forward_fn = forward_fn
        ctx.save_for_backward(x)
        with torch.no_grad():
            output = forward_fn(x)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (saved_input,) = ctx.saved_tensors
        recompute = ctx.forward_fn
        # Drop the reference held by the context as early as possible.
        ctx.forward_fn = None

        with torch.enable_grad():
            leaf = saved_input.detach().requires_grad_(True)
            recomputed = recompute(leaf)
            # Also accumulates .grad on any trainable tensors used inside the callable.
            recomputed.backward(grad_out)
            input_grad = leaf.grad

        # Release the recomputation before returning.
        del recomputed, leaf, recompute
        # One gradient per forward() argument: x gets a gradient, forward_fn does not.
        return input_grad, None
def patch(m, offloading=False):
if not hasattr(m, "forward"): if not hasattr(m, "forward"):
return return
org_forward = m.forward org_forward = m.forward
def fwd(args, kwargs): # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
return org_forward(*args, **kwargs) if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
def checkpointing_fwd(x):
return OffloadCheckpointFunction.apply(x, org_forward)
# Branch 2: Others -> standard checkpoint
else:
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs): def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward m.org_forward = org_forward
m.forward = checkpointing_fwd m.forward = checkpointing_fwd
@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
default=True, default=True,
tooltip="Use gradient checkpointing for training.", tooltip="Use gradient checkpointing for training.",
), ),
io.Int.Input(
"checkpoint_depth",
default=1,
min=1,
max=5,
tooltip="Depth level for gradient checkpointing.",
),
# Toggle for weight offloading during training. The previous tooltip was a
# copy of the checkpoint_depth text and described the wrong option.
io.Boolean.Input(
    "offloading",
    default=False,
    tooltip="Offload model weights during training instead of forcing the full model to stay loaded on the GPU.",
),
io.Combo.Input( io.Combo.Input(
"existing_lora", "existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"], options=folder_paths.get_filename_list("loras") + ["[None]"],
@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype, lora_dtype,
algorithm, algorithm,
gradient_checkpointing, gradient_checkpointing,
checkpoint_depth,
offloading,
existing_lora, existing_lora,
bucket_mode, bucket_mode,
bypass_mode, bypass_mode,
@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0] lora_dtype = lora_dtype[0]
algorithm = algorithm[0] algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0] gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0] existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0] bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0] bypass_mode = bypass_mode[0]
@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing # Setup gradient checkpointing
if gradient_checkpointing: if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward( modules_to_patch = find_modules_at_depth(
mp.model.diffusion_model mp.model.diffusion_model, depth=checkpoint_depth
): )
patch(m) logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
for m in modules_to_patch:
patch(m, offloading=offloading)
torch.cuda.empty_cache() torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading # With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu( comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True [mp], memory_required=1e20, force_full_load=not offloading
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
) )
# Setup guider # Setup guider
guider = TrainGuider(mp) guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive) guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled # Inject bypass hooks if bypass mode is enabled
@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop # Run training loop
try: try:
comfy.model_management.in_training = True
_run_training_loop( _run_training_loop(
guider, guider,
train_sampler, train_sampler,
@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
multi_res, multi_res,
) )
finally: finally:
comfy.model_management.in_training = False
# Eject bypass hooks if they were injected # Eject bypass hooks if they were injected
if bypass_injections is not None: if bypass_injections is not None:
for injection in bypass_injections: for injection in bypass_injections:
@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m) unpatch(m)
del train_sampler, optimizer del train_sampler, optimizer
# Finalize adapters for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
for adapter in all_weight_adapters: for adapter in all_weight_adapters:
adapter.requires_grad_(False) adapter.requires_grad_(False)
del adapter
for param in lora_sd: del all_weight_adapters
lora_sd[param] = lora_sd[param].to(lora_dtype)
# mp in train node is highly specialized for training # mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it # use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):# class LoraModelLoader(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
max=100.0, max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.", tooltip="How strongly to modify the diffusion model. This value can be negative.",
), ),
io.Boolean.Input(
"bypass",
default=False,
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
),
], ],
outputs=[ outputs=[
io.Model.Output( io.Model.Output(
@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
) )
@classmethod
def execute(cls, model, lora, strength_model, bypass=False):
    """Apply ``lora`` to ``model`` and return the resulting model.

    Args:
        model: Model to apply the LoRA to.
        lora: LoRA weights to apply.
        strength_model: Strength passed to the loader; when it equals 0 the
            input ``model`` is returned unchanged.
        bypass: When true, load through
            ``comfy.sd.load_bypass_lora_for_models`` instead of
            ``comfy.sd.load_lora_for_models``.

    Returns:
        ``io.NodeOutput`` wrapping the first element returned by the loader
        (or the untouched ``model`` when ``strength_model`` is 0).
    """
    if strength_model == 0:
        return io.NodeOutput(model)

    # Both loaders are called with identical arguments; only the entry point
    # differs.
    if bypass:
        loader = comfy.sd.load_bypass_lora_for_models
    else:
        loader = comfy.sd.load_lora_for_models

    # NOTE(review): the ``None`` / ``0`` arguments presumably stand for "no
    # CLIP model" and "zero CLIP strength" — confirm against comfy.sd.
    model_lora, _ = loader(model, None, lora, strength_model, 0)
    return io.NodeOutput(model_lora)