mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-18 22:42:35 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
602ca59147
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@ -7,6 +7,8 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
send-webhook:
|
send-webhook:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
- name: Send release webhook
|
- name: Send release webhook
|
||||||
env:
|
env:
|
||||||
@ -106,3 +108,37 @@ jobs:
|
|||||||
--fail --silent --show-error
|
--fail --silent --show-error
|
||||||
|
|
||||||
echo "✅ Release webhook sent successfully"
|
echo "✅ Release webhook sent successfully"
|
||||||
|
|
||||||
|
- name: Send repository dispatch to desktop
|
||||||
|
env:
|
||||||
|
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||||
|
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||||
|
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||||
|
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
PAYLOAD="$(jq -n \
|
||||||
|
--arg release_tag "$RELEASE_TAG" \
|
||||||
|
--arg release_url "$RELEASE_URL" \
|
||||||
|
'{
|
||||||
|
event_type: "comfyui_release_published",
|
||||||
|
client_payload: {
|
||||||
|
release_tag: $release_tag,
|
||||||
|
release_url: $release_url
|
||||||
|
}
|
||||||
|
}')"
|
||||||
|
|
||||||
|
curl -fsSL \
|
||||||
|
-X POST \
|
||||||
|
-H "Accept: application/vnd.github+json" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||||
|
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||||
|
-d "$PAYLOAD"
|
||||||
|
|
||||||
|
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
import math
|
import math
|
||||||
import time
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torchsde
|
import torchsde
|
||||||
from tqdm.auto import trange as trange_, tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import deis
|
from . import deis
|
||||||
@ -15,34 +14,7 @@ import comfy.model_patcher
|
|||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
|
from comfy.utils import model_trange as trange
|
||||||
|
|
||||||
def trange(*args, **kwargs):
|
|
||||||
if comfy.memory_management.aimdo_allocator is None:
|
|
||||||
return trange_(*args, **kwargs)
|
|
||||||
|
|
||||||
pbar = trange_(*args, **kwargs, smoothing=1.0)
|
|
||||||
pbar._i = 0
|
|
||||||
pbar.set_postfix_str(" Model Initializing ... ")
|
|
||||||
|
|
||||||
_update = pbar.update
|
|
||||||
|
|
||||||
def warmup_update(n=1):
|
|
||||||
pbar._i += 1
|
|
||||||
if pbar._i == 1:
|
|
||||||
pbar.i1_time = time.time()
|
|
||||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
|
||||||
elif pbar._i == 2:
|
|
||||||
#bring forward the effective start time based the the diff between first and second iteration
|
|
||||||
#to attempt to remove load overhead from the final step rate estimate.
|
|
||||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
|
||||||
pbar.set_postfix_str("")
|
|
||||||
|
|
||||||
_update(n)
|
|
||||||
|
|
||||||
pbar.update = warmup_update
|
|
||||||
return pbar
|
|
||||||
|
|
||||||
|
|
||||||
def append_zero(x):
    """Return ``x`` with a single zero appended along dimension 0.

    The zero is created with ``x.new_zeros`` so it shares the dtype and
    device of ``x``.
    """
    return torch.cat([x, x.new_zeros([1])])
|
||||||
|
|||||||
@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||||
|
|
||||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||||
if text_ids is not None:
|
if text_ids is not None:
|
||||||
return self.llm_adapter(text_embeds, text_ids)
|
out = self.llm_adapter(text_embeds, text_ids)
|
||||||
|
if t5xxl_weights is not None:
|
||||||
|
out = out * t5xxl_weights
|
||||||
|
|
||||||
|
if out.shape[1] < 512:
|
||||||
|
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||||
|
return out
|
||||||
else:
|
else:
|
||||||
return text_embeds
|
return text_embeds
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, **kwargs):
|
||||||
|
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||||
|
if t5xxl_ids is not None:
|
||||||
|
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||||
|
return super().forward(x, timesteps, context, **kwargs)
|
||||||
|
|||||||
@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||||
|
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
|
||||||
|
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||||
|
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||||
|
|
||||||
|
return x_out.reshape(*x.shape).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||||
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope(xq, xk, freqs_cis)
|
||||||
|
else:
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
def apply_rope1(x, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope1(x, freqs_cis)
|
||||||
|
else:
|
||||||
|
return q_apply_rope1(x, freqs_cis)
|
||||||
except:
|
except:
|
||||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
apply_rope = _apply_rope
|
||||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
apply_rope1 = _apply_rope1
|
||||||
|
|
||||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
|
||||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
|
||||||
|
|
||||||
return x_out.reshape(*x.shape).type_as(x)
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
|
||||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
|
||||||
|
|||||||
@ -1160,12 +1160,16 @@ class Anima(BaseModel):
|
|||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
if t5xxl_ids is not None:
|
if t5xxl_ids is not None:
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
|
||||||
if t5xxl_weights is not None:
|
if t5xxl_weights is not None:
|
||||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||||
|
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
if torch.is_inference_mode_enabled(): # if not we are training
|
||||||
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
||||||
|
else:
|
||||||
|
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||||
|
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||||
|
|
||||||
if cross_attn.shape[1] < 512:
|
|
||||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Training Related State
|
||||||
|
in_training = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
float8_types = []
|
float8_types = []
|
||||||
try:
|
try:
|
||||||
@ -1208,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
|||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||||
if signature is not None:
|
if signature is not None:
|
||||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
|
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||||
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
v_tensor = weight._v_tensor
|
||||||
|
else:
|
||||||
|
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||||
|
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||||
|
weight._v_tensor = v_tensor
|
||||||
weight._v_signature = signature
|
weight._v_signature = signature
|
||||||
#Send it over
|
#Send it over
|
||||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||||
|
|||||||
@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
setattr(m, param_key + "_function", weight_function)
|
setattr(m, param_key + "_function", weight_function)
|
||||||
geometry = weight
|
geometry = weight
|
||||||
if not isinstance(weight, QuantizedTensor):
|
if not isinstance(weight, QuantizedTensor):
|
||||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
|
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||||
weight._model_dtype = model_dtype
|
weight._model_dtype = model_dtype
|
||||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||||
return comfy.memory_management.vram_aligned_size(geometry)
|
return comfy.memory_management.vram_aligned_size(geometry)
|
||||||
@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
if vbar is not None and not hasattr(m, "_v"):
|
if vbar is not None and not hasattr(m, "_v"):
|
||||||
m._v = vbar.alloc(v_weight_size)
|
m._v = vbar.alloc(v_weight_size)
|
||||||
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
|
|
||||||
allocated_size += v_weight_size
|
allocated_size += v_weight_size
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -1552,12 +1551,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
weight.seed_key = key
|
weight.seed_key = key
|
||||||
set_dirty(weight, dirty)
|
set_dirty(weight, dirty)
|
||||||
geometry = weight
|
geometry = weight
|
||||||
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
|
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||||
weight_size = geometry.numel() * geometry.element_size()
|
weight_size = geometry.numel() * geometry.element_size()
|
||||||
if vbar is not None and not hasattr(weight, "_v"):
|
if vbar is not None and not hasattr(weight, "_v"):
|
||||||
weight._v = vbar.alloc(weight_size)
|
weight._v = vbar.alloc(weight_size)
|
||||||
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
|
|
||||||
weight._model_dtype = model_dtype
|
weight._model_dtype = model_dtype
|
||||||
allocated_size += weight_size
|
allocated_size += weight_size
|
||||||
vbar.set_watermark_limit(allocated_size)
|
vbar.set_watermark_limit(allocated_size)
|
||||||
|
|||||||
21
comfy/ops.py
21
comfy/ops.py
@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
|||||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
xfer_dest = None
|
xfer_dest = None
|
||||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
|
||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||||
if signature is not None:
|
|
||||||
xfer_dest = s._v_tensor
|
|
||||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||||
|
if signature is not None:
|
||||||
|
if resident:
|
||||||
|
weight = s._v_weight
|
||||||
|
bias = s._v_bias
|
||||||
|
else:
|
||||||
|
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||||
|
|
||||||
if not resident:
|
if not resident:
|
||||||
|
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||||
cast_dest = None
|
cast_dest = None
|
||||||
|
|
||||||
xfer_source = [ s.weight, s.bias ]
|
xfer_source = [ s.weight, s.bias ]
|
||||||
@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
post_cast.copy_(pre_cast)
|
post_cast.copy_(pre_cast)
|
||||||
xfer_dest = cast_dest
|
xfer_dest = cast_dest
|
||||||
|
|
||||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||||
weight = params[0]
|
weight = params[0]
|
||||||
bias = params[1]
|
bias = params[1]
|
||||||
|
if signature is not None:
|
||||||
|
s._v_weight = weight
|
||||||
|
s._v_bias = bias
|
||||||
|
s._v_signature=signature
|
||||||
|
|
||||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||||
s._v_signature=signature
|
|
||||||
|
|
||||||
#FIXME: weird offload return protocol
|
#FIXME: weird offload return protocol
|
||||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||||
|
|||||||
@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
|||||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||||
return memory_required, minimum_memory_required
|
return memory_required, minimum_memory_required
|
||||||
|
|
||||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||||
_prepare_sampling,
|
_prepare_sampling,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||||
)
|
)
|
||||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
real_model: BaseModel = None
|
real_model: BaseModel = None
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
memory_required = 1e20
|
||||||
|
minimum_memory_required = None
|
||||||
|
else:
|
||||||
|
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||||
|
memory_required += inference_memory
|
||||||
|
minimum_memory_required += inference_memory
|
||||||
|
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
|
||||||
return real_model, conds, models
|
return real_model, conds, models
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import comfy.text_encoders.llama
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
from tqdm.auto import trange
|
|
||||||
import yaml
|
import yaml
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
@ -52,7 +51,7 @@ def sample_manual_loop_no_classes(
|
|||||||
|
|
||||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||||
|
|
||||||
for step in trange(max_new_tokens, desc="LM sampling"):
|
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
|
||||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||||
past_key_values = outputs[2]
|
past_key_values = outputs[2]
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from PIL import Image
|
|||||||
import logging
|
import logging
|
||||||
import itertools
|
import itertools
|
||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
|
from tqdm.auto import trange
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
import json
|
import json
|
||||||
@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||||
|
|
||||||
|
def model_trange(*args, **kwargs):
    """Create a ``trange`` progress bar that accounts for model warm-up.

    When ``comfy.memory_management.aimdo_allocator`` is ``None`` this is a
    plain ``trange(*args, **kwargs)``. Otherwise the returned bar has its
    ``update`` method wrapped so that:

    * before the first update the postfix reads " Model Initializing ... ";
    * on the first update the time is recorded and the postfix switches to
      " Model Initialization complete! ";
    * on the second update the bar's ``start_t`` is moved forward and the
      postfix is cleared.

    Args:
        *args: Positional arguments forwarded to ``trange``.
        **kwargs: Keyword arguments forwarded to ``trange``. On the warm-up
            path ``smoothing=1.0`` is passed as well, so callers must not
            supply ``smoothing`` themselves there.

    Returns:
        The progress bar object produced by ``trange``.
    """
    if comfy.memory_management.aimdo_allocator is None:
        return trange(*args, **kwargs)

    # Function-scope import keeps this helper working even if the module
    # does not import ``time`` at top level.
    import time

    pbar = trange(*args, **kwargs, smoothing=1.0)
    pbar._i = 0
    pbar.set_postfix_str(" Model Initializing ... ")

    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the diff between
            # the first and second iteration, to attempt to remove load
            # overhead from the final step rate estimate.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")

        # Always delegate to the original update so the bar still advances.
        _update(n)

    pbar.update = warmup_update
    return pbar
|
||||||
|
|
||||||
PROGRESS_BAR_ENABLED = True
|
PROGRESS_BAR_ENABLED = True
|
||||||
def set_progress_bar_enabled(enabled):
|
def set_progress_bar_enabled(enabled):
|
||||||
global PROGRESS_BAR_ENABLED
|
global PROGRESS_BAR_ENABLED
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from typing import Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||||
from comfy.patcher_extension import PatcherInjection
|
from comfy.patcher_extension import PatcherInjection
|
||||||
|
|
||||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
|||||||
)
|
)
|
||||||
return # Already injected
|
return # Already injected
|
||||||
|
|
||||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
# Move adapter weights to compute device (GPU)
|
||||||
device = None
|
# Use get_torch_device() instead of module.weight.device because
|
||||||
|
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
# Get dtype from module weight if available
|
||||||
dtype = None
|
dtype = None
|
||||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||||
device = self.module.weight.device
|
|
||||||
dtype = self.module.weight.dtype
|
dtype = self.module.weight.dtype
|
||||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
|
||||||
device = self.module.W_q.device
|
|
||||||
dtype = self.module.W_q.dtype
|
|
||||||
|
|
||||||
if device is not None:
|
# Only use dtype if it's a standard float type, not quantized
|
||||||
self._move_adapter_weights_to_device(device, dtype)
|
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||||
|
dtype = None
|
||||||
|
|
||||||
|
self._move_adapter_weights_to_device(device, dtype)
|
||||||
|
|
||||||
self.original_forward = self.module.forward
|
self.original_forward = self.module.forward
|
||||||
self.module.forward = self._bypass_forward
|
self.module.forward = self._bypass_forward
|
||||||
|
|||||||
@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
|
|||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_EUR_TO_USD = 1.19
|
||||||
|
|
||||||
|
|
||||||
|
def _tier_price_eur(megapixels: float) -> float:
|
||||||
|
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
|
||||||
|
if megapixels <= 1.3:
|
||||||
|
return 0.143
|
||||||
|
if megapixels <= 3.0:
|
||||||
|
return 0.286
|
||||||
|
if megapixels <= 6.4:
|
||||||
|
return 0.429
|
||||||
|
return 1.716
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
    """Calculate total Magnific upscale price in USD for given input dimensions and scale factor.

    The upscale is priced as ``int(log2(scale))`` successive steps. Each step
    is charged by ``_tier_price_eur`` on the current pixel count, after which
    the pixel count is multiplied by 4 for the next step.

    Args:
        width: Input image width in pixels.
        height: Input image height in pixels.
        scale: Total scale factor; must be positive (``math.log2`` raises
            ``ValueError`` otherwise). Any fractional part of ``log2(scale)``
            is truncated.

    Returns:
        The summed step prices converted with ``_EUR_TO_USD`` and rounded to
        two decimal places.
    """
    num_steps = int(math.log2(scale))
    total_eur = 0.0
    pixels = width * height
    for _ in range(num_steps):
        total_eur += _tier_price_eur(pixels / 1_000_000)
        # The pixel count grows fourfold before the next step is priced.
        pixels *= 4
    return round(total_eur * _EUR_TO_USD, 2)
|
||||||
|
|
||||||
|
|
||||||
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
$ad := widgets.auto_downscale;
|
||||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
$mins := $ad
|
||||||
|
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
|
||||||
|
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||||
|
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
|
||||||
|
{
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||||
|
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||||
|
"format": { "approximate": true }
|
||||||
|
}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
|||||||
f"Use a smaller input image or lower scale factor."
|
f"Use a smaller input image or lower scale factor."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
final_height, final_width = get_image_dimensions(image)
|
||||||
|
actual_scale = int(scale_factor.rstrip("x"))
|
||||||
|
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
|
||||||
|
|
||||||
initial_res = await sync_op(
|
initial_res = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
||||||
@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
|||||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
||||||
response_model=TaskResponse,
|
response_model=TaskResponse,
|
||||||
status_extractor=lambda x: x.status,
|
status_extractor=lambda x: x.status,
|
||||||
|
price_extractor=lambda _: price_usd,
|
||||||
poll_interval=10.0,
|
poll_interval=10.0,
|
||||||
max_poll_attempts=480,
|
max_poll_attempts=480,
|
||||||
)
|
)
|
||||||
@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
|||||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
|
||||||
|
{
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||||
|
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||||
|
"format": { "approximate": true }
|
||||||
|
}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
|||||||
f"Use a smaller input image or lower scale factor."
|
f"Use a smaller input image or lower scale factor."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
final_height, final_width = get_image_dimensions(image)
|
||||||
|
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
|
||||||
|
|
||||||
initial_res = await sync_op(
|
initial_res = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
||||||
@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
|||||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
||||||
response_model=TaskResponse,
|
response_model=TaskResponse,
|
||||||
status_extractor=lambda x: x.status,
|
status_extractor=lambda x: x.status,
|
||||||
|
price_extractor=lambda _: price_usd,
|
||||||
poll_interval=10.0,
|
poll_interval=10.0,
|
||||||
max_poll_attempts=480,
|
max_poll_attempts=480,
|
||||||
)
|
)
|
||||||
@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
|
|||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
# MagnificImageUpscalerCreativeNode,
|
MagnificImageUpscalerCreativeNode,
|
||||||
# MagnificImageUpscalerPreciseV2Node,
|
MagnificImageUpscalerPreciseV2Node,
|
||||||
MagnificImageStyleTransferNode,
|
MagnificImageStyleTransferNode,
|
||||||
MagnificImageRelightNode,
|
MagnificImageRelightNode,
|
||||||
MagnificImageSkinEnhancerNode,
|
MagnificImageSkinEnhancerNode,
|
||||||
|
|||||||
@ -143,9 +143,9 @@ async def poll_op(
|
|||||||
poll_interval: float = 5.0,
|
poll_interval: float = 5.0,
|
||||||
max_poll_attempts: int = 160,
|
max_poll_attempts: int = 160,
|
||||||
timeout_per_poll: float = 120.0,
|
timeout_per_poll: float = 120.0,
|
||||||
max_retries_per_poll: int = 3,
|
max_retries_per_poll: int = 10,
|
||||||
retry_delay_per_poll: float = 1.0,
|
retry_delay_per_poll: float = 1.0,
|
||||||
retry_backoff_per_poll: float = 2.0,
|
retry_backoff_per_poll: float = 1.4,
|
||||||
estimated_duration: int | None = None,
|
estimated_duration: int | None = None,
|
||||||
cancel_endpoint: ApiEndpoint | None = None,
|
cancel_endpoint: ApiEndpoint | None = None,
|
||||||
cancel_timeout: float = 10.0,
|
cancel_timeout: float = 10.0,
|
||||||
@ -240,9 +240,9 @@ async def poll_op_raw(
|
|||||||
poll_interval: float = 5.0,
|
poll_interval: float = 5.0,
|
||||||
max_poll_attempts: int = 160,
|
max_poll_attempts: int = 160,
|
||||||
timeout_per_poll: float = 120.0,
|
timeout_per_poll: float = 120.0,
|
||||||
max_retries_per_poll: int = 3,
|
max_retries_per_poll: int = 10,
|
||||||
retry_delay_per_poll: float = 1.0,
|
retry_delay_per_poll: float = 1.0,
|
||||||
retry_backoff_per_poll: float = 2.0,
|
retry_backoff_per_poll: float = 1.4,
|
||||||
estimated_duration: int | None = None,
|
estimated_duration: int | None = None,
|
||||||
cancel_endpoint: ApiEndpoint | None = None,
|
cancel_endpoint: ApiEndpoint | None = None,
|
||||||
cancel_timeout: float = 10.0,
|
cancel_timeout: float = 10.0,
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
"""
|
"""
|
||||||
CFGGuider with modifications for training specific logic
|
CFGGuider with modifications for training specific logic
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, offloading=False, **kwargs):
    """Initialise the guider and remember whether weight offloading is wanted.

    Args:
        *args: Positional arguments forwarded unchanged to the parent class.
        offloading: Keyword-only flag stored on the instance as
            ``self.offloading``.
        **kwargs: Keyword arguments forwarded unchanged to the parent class.
    """
    super().__init__(*args, **kwargs)
    self.offloading = offloading
|
||||||
|
|
||||||
def outer_sample(
|
def outer_sample(
|
||||||
self,
|
self,
|
||||||
noise,
|
noise,
|
||||||
@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
noise.shape,
|
noise.shape,
|
||||||
self.conds,
|
self.conds,
|
||||||
self.model_options,
|
self.model_options,
|
||||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
force_full_load=not self.offloading,
|
||||||
|
force_offload=self.offloading,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def patch(m):
|
def find_modules_at_depth(
    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
    """
    Collect the modules sitting at a given nesting level, for gradient checkpointing.

    Levels are counted over non-container modules only: ``nn.ModuleList``,
    ``nn.Sequential`` and ``nn.ModuleDict`` are walked through without
    consuming a level. ``model`` itself is counted, so when ``model`` is a
    regular module ``depth=1`` returns ``[model]``, ``depth=2`` returns its
    nearest non-container descendants, and so on. The search does not descend
    below a module once it has been selected.

    Args:
        model: The module to search.
        depth: Target level (1 = ``model`` itself when it is not a container).
        result: Accumulator used by the recursion; callers normally omit it.
        current_depth: Number of non-container modules already passed on the
            path to ``model``; callers normally omit it.
        name: Dotted name of ``model`` used in the debug log (default "root").

    Returns:
        List of modules at the target depth, in traversal order. Empty when no
        module reaches that depth.
    """
    if result is None:
        result = []
    name = name or "root"

    # Container modules only hold other modules; they are skipped as a level.
    is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
    has_forward = hasattr(model, "forward") and not is_container

    if has_forward:
        current_depth += 1
        if current_depth == depth:
            result.append(model)
            logging.debug(
                "Found module at depth %s: %s (%s)",
                depth,
                name,
                model.__class__.__name__,
            )
            # A selected module is checkpointed as a whole: stop descending.
            return result

    # Not at the target level yet: keep looking in the children.
    for next_name, child in model.named_children():
        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")

    return result
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadCheckpointFunction(torch.autograd.Function):
    """
    Gradient checkpointing that works with weight offloading.

    Forward: no_grad -> compute -> weights can be freed
    Backward: enable_grad -> recompute -> backward -> weights can be freed

    For single input, single output modules (Linear, Conv*).

    Only the input tensor is saved between the two passes. Gradients for any
    parameters used inside ``forward_fn`` are produced by the recomputation in
    ``backward``.

    NOTE: ``x`` is the only tensor argument autograd sees, so ``backward`` is
    only ever invoked when ``x`` requires grad. If it does not, the output is
    detached and parameters used by ``forward_fn`` receive no gradient;
    callers must route such inputs elsewhere.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, forward_fn):
        """Run ``forward_fn(x)`` without recording a graph; keep only ``x``."""
        ctx.save_for_backward(x)
        ctx.forward_fn = forward_fn
        with torch.no_grad():
            return forward_fn(x)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Recompute ``forward_fn`` with grad enabled and backpropagate ``grad_out``.

        Returns the gradient for ``x`` and ``None`` for the non-tensor
        ``forward_fn`` argument.
        """
        x, = ctx.saved_tensors
        forward_fn = ctx.forward_fn

        # Clear context early
        ctx.forward_fn = None

        with torch.enable_grad():
            # Fresh leaf so the recomputed graph is independent of the outer one.
            x_detached = x.detach().requires_grad_(True)
            y = forward_fn(x_detached)
            # Accumulates into x_detached.grad and into any parameters that
            # forward_fn touches.
            y.backward(grad_out)
            grad_x = x_detached.grad

        # Explicit cleanup
        del y, x_detached, forward_fn

        return grad_x, None
|
||||||
|
|
||||||
|
|
||||||
|
def patch(m, offloading=False):
    """Replace ``m.forward`` with a gradient-checkpointed wrapper.

    The original bound ``forward`` is kept on the module as ``m.org_forward``.
    Objects without a ``forward`` attribute are left untouched.

    Args:
        m: Module whose ``forward`` should be checkpointed.
        offloading: When true, ``nn.Linear`` / ``nn.Conv1d`` / ``nn.Conv2d`` /
            ``nn.Conv3d`` modules use ``OffloadCheckpointFunction`` for calls
            it can handle (exactly one positional tensor that requires grad).
            Every other call, and every other module type, uses
            ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False``.
    """
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    # Standard (non-reentrant) checkpoint; handles arbitrary args/kwargs.
    def standard_checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

    # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):

        def checkpointing_fwd(*args, **kwargs):
            single_tensor = (
                len(args) == 1 and not kwargs and isinstance(args[0], torch.Tensor)
            )
            # A custom autograd Function whose only tensor input does not
            # require grad returns a detached output and never runs backward,
            # which would silently drop the gradients of this layer's
            # trainable weights. Use the standard checkpoint for that case.
            if single_tensor and args[0].requires_grad:
                return OffloadCheckpointFunction.apply(args[0], org_forward)
            return standard_checkpointing_fwd(*args, **kwargs)

    # Branch 2: Others -> standard checkpoint
    else:
        checkpointing_fwd = standard_checkpointing_fwd

    m.org_forward = org_forward
    m.forward = checkpointing_fwd
|
||||||
@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default=True,
|
default=True,
|
||||||
tooltip="Use gradient checkpointing for training.",
|
tooltip="Use gradient checkpointing for training.",
|
||||||
),
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"checkpoint_depth",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=5,
|
||||||
|
tooltip="Depth level for gradient checkpointing.",
|
||||||
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"offloading",
|
||||||
|
default=False,
|
||||||
|
tooltip="Depth level for gradient checkpointing.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"existing_lora",
|
"existing_lora",
|
||||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||||
@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype,
|
lora_dtype,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
|
checkpoint_depth,
|
||||||
|
offloading,
|
||||||
existing_lora,
|
existing_lora,
|
||||||
bucket_mode,
|
bucket_mode,
|
||||||
bypass_mode,
|
bypass_mode,
|
||||||
@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
|
offloading = offloading[0]
|
||||||
|
checkpoint_depth = checkpoint_depth[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Setup gradient checkpointing
|
# Setup gradient checkpointing
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
for m in find_all_highest_child_module_with_forward(
|
modules_to_patch = find_modules_at_depth(
|
||||||
mp.model.diffusion_model
|
mp.model.diffusion_model, depth=checkpoint_depth
|
||||||
):
|
)
|
||||||
patch(m)
|
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
|
||||||
|
for m in modules_to_patch:
|
||||||
|
patch(m, offloading=offloading)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# With force_full_load=False we should be able to have offloading
|
# With force_full_load=False we should be able to have offloading
|
||||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[mp], memory_required=1e20, force_full_load=True
|
[mp], memory_required=1e20, force_full_load=not offloading
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
guider = TrainGuider(mp)
|
guider = TrainGuider(mp, offloading=offloading)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(positive)
|
||||||
|
|
||||||
# Inject bypass hooks if bypass mode is enabled
|
# Inject bypass hooks if bypass mode is enabled
|
||||||
@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Run training loop
|
# Run training loop
|
||||||
try:
|
try:
|
||||||
|
comfy.model_management.in_training = True
|
||||||
_run_training_loop(
|
_run_training_loop(
|
||||||
guider,
|
guider,
|
||||||
train_sampler,
|
train_sampler,
|
||||||
@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
multi_res,
|
multi_res,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
comfy.model_management.in_training = False
|
||||||
# Eject bypass hooks if they were injected
|
# Eject bypass hooks if they were injected
|
||||||
if bypass_injections is not None:
|
if bypass_injections is not None:
|
||||||
for injection in bypass_injections:
|
for injection in bypass_injections:
|
||||||
@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
unpatch(m)
|
unpatch(m)
|
||||||
del train_sampler, optimizer
|
del train_sampler, optimizer
|
||||||
|
|
||||||
# Finalize adapters
|
for param in lora_sd:
|
||||||
|
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
|
||||||
|
|
||||||
for adapter in all_weight_adapters:
|
for adapter in all_weight_adapters:
|
||||||
adapter.requires_grad_(False)
|
adapter.requires_grad_(False)
|
||||||
|
del adapter
|
||||||
for param in lora_sd:
|
del all_weight_adapters
|
||||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
|
||||||
|
|
||||||
# mp in train node is highly specialized for training
|
# mp in train node is highly specialized for training
|
||||||
# use it in inference will result in bad behavior so we don't return it
|
# use it in inference will result in bad behavior so we don't return it
|
||||||
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||||
|
|
||||||
|
|
||||||
class LoraModelLoader(io.ComfyNode):#
|
class LoraModelLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
|
|||||||
max=100.0,
|
max=100.0,
|
||||||
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"bypass",
|
||||||
|
default=False,
|
||||||
|
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(
|
io.Model.Output(
|
||||||
@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
def execute(cls, model, lora, strength_model, bypass=False):
    """Apply ``lora`` to ``model`` and return the resulting model.

    Args:
        model: Model to apply the LoRA to.
        lora: LoRA weights to apply.
        strength_model: Strength of the modification; ``0`` returns ``model``
            unchanged without calling any loader.
        bypass: When true, load through
            ``comfy.sd.load_bypass_lora_for_models`` instead of
            ``comfy.sd.load_lora_for_models``.

    Returns:
        ``io.NodeOutput`` wrapping the first element of the loader's result
        (or the untouched ``model`` when ``strength_model`` is 0).
    """
    if strength_model == 0:
        return io.NodeOutput(model)

    if bypass:
        load_lora = comfy.sd.load_bypass_lora_for_models
    else:
        load_lora = comfy.sd.load_lora_for_models

    # Both loaders get the same arguments; the second and fifth are fixed to
    # None and 0 (presumably the CLIP model and its strength - TODO confirm).
    model_lora, _ = load_lora(model, None, lora, strength_model, 0)
    return io.NodeOutput(model_lora)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user