dynamic_vram: Training fixes (#12442)
parent e03fe8b591
commit 8902907d7a
@@ -1561,6 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher):
                 allocated_size += weight_size
         vbar.set_watermark_limit(allocated_size)
+        for m in self.model.modules():
+            move_weight_functions(m, device_to)
 
         logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
 
         self.model.device = device_to
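Both hunks in ModelPatcherDynamic add the same loop: weight-wrapper functions attached to a module must follow the module's weights to the target device, or they end up operating on tensors that live somewhere else. A minimal sketch of what a move_weight_functions-style helper could look like, assuming the wrapper objects expose a move_to(device) method and hang off per-module lists (the real helper in comfy/model_patcher.py may differ):

import torch.nn as nn

def move_weight_functions_sketch(m: nn.Module, device):
    """Hedged sketch: relocate weight-patching callables attached to a
    module so they execute on the same device as the module's weights."""
    if device is None:
        return 0
    moved_bytes = 0
    # 'weight_function' / 'bias_function' list attributes are assumptions
    # for this sketch; the actual attribute names may differ.
    for fn in getattr(m, "weight_function", []) + getattr(m, "bias_function", []):
        if hasattr(fn, "move_to"):
            moved_bytes += fn.move_to(device=device)
    return moved_bytes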
@@ -1601,6 +1603,8 @@ class ModelPatcherDynamic(ModelPatcher):
         if unpatch_weights:
             self.partially_unload_ram(1e32)
             self.partially_unload(None, 1e32)
+            for m in self.model.modules():
+                move_weight_functions(m, device_to)
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
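The unload path above passes 1e32 as the memory budget, i.e. an effectively infinite number of bytes, so every staged weight qualifies for eviction. A self-contained sketch of that convention; partially_unload_sketch and the record fields are illustrative, not ComfyUI API:

# Pass a budget that dwarfs any real model so everything gets evicted.
def partially_unload_sketch(weights, memory_to_free):
    freed = 0
    for w in list(weights):        # iterate over a copy while mutating
        if freed >= memory_to_free:
            break
        freed += w["size"]
        weights.remove(w)          # evict this weight from the device
    return freed

staged = [{"name": "blocks.0.attn.qkv", "size": 50 << 20},
          {"name": "blocks.0.mlp.fc1", "size": 80 << 20}]
partially_unload_sketch(staged, 1e32)   # "infinite" budget -> unloads all
assert staged == []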
@@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
             io.Boolean.Input(
                 "offloading",
                 default=False,
-                tooltip="Depth level for gradient checkpointing.",
+                tooltip="Offload the Model to RAM. Requires Bypass Mode.",
             ),
             io.Combo.Input(
                 "existing_lora",
@@ -1124,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)
 
+        if mp.is_dynamic():
+            if not bypass_mode:
+                logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
+                bypass_mode = True
+            offloading = True
+        elif offloading:
+            if not bypass_mode:
+                logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
+                bypass_mode = True
         # Prepare latents and compute counts
         latents, num_images, multi_res = _prepare_latents_and_count(
             latents, dtype, bucket_mode
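The new guard only ever escalates the flags: a dynamic ModelPatcher forces both bypass mode and offloading, and explicit offloading drags bypass mode along with it. Factored into a standalone function for illustration (the node keeps this logic inline, and the final bypass_mode = True in the elif branch is reconstructed from the surrounding pattern):

import logging

def resolve_training_modes(is_dynamic: bool, bypass_mode: bool, offloading: bool):
    """Sketch of the mode-forcing logic added in this commit.
    Returns the effective (bypass_mode, offloading) pair."""
    if is_dynamic:
        if not bypass_mode:
            logging.info("Training MP is Dynamic - forcing bypass mode. "
                         "Start comfy with --highvram to force weight diff mode")
            bypass_mode = True
        offloading = True
    elif offloading:
        if not bypass_mode:
            logging.info("Training Offload selected - forcing bypass mode. "
                         "Set bypass = True to remove this message")
            bypass_mode = True
    return bypass_mode, offloading

# A dynamic patcher always trains as (bypass_mode=True, offloading=True):
assert resolve_training_modes(True, False, False) == (True, True)
# Offloading without bypass is promoted to bypass mode as well:
assert resolve_training_modes(False, False, True) == (True, True)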