ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git)

Commit 4425247ad4: Merge branch 'comfyanonymous:master' into feature/custom_nodes-envvar
@@ -210,7 +210,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
         return img

-    def process_img(self, x, index=0, h_offset=0, w_offset=0):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,10 +222,22 @@ class Flux(nn.Module):
         h_offset = ((h_offset + (patch_size // 2)) // patch_size)
         w_offset = ((w_offset + (patch_size // 2)) // patch_size)

-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        steps_h = h_len
+        steps_w = w_len
+
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+            index += rope_options.get("shift_t", 0.0)
+            h_offset += rope_options.get("shift_y", 0.0)
+            w_offset += rope_options.get("shift_x", 0.0)
+
+        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
         return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -241,7 +253,7 @@ class Flux(nn.Module):

         h_len = ((h_orig + (patch_size // 2)) // patch_size)
         w_len = ((w_orig + (patch_size // 2)) // patch_size)
-        img, img_ids = self.process_img(x)
+        img, img_ids = self.process_img(x, transformer_options=transformer_options)
         img_tokens = img.shape[1]
         if ref_latents is not None:
             h = 0
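The new rope_options dictionary, read from transformer_options, rescales and shifts the RoPE coordinate grid while leaving the token count (steps_h, steps_w) unchanged. Below is a minimal standalone sketch of that id construction, with invented grid sizes and option values; the helper name is hypothetical and not part of the ComfyUI API:

import torch

def make_img_ids(h_len, w_len, index=0, h_offset=0, w_offset=0, rope_options=None):
    # Same construction as process_img above, stripped of the model plumbing:
    # the grid keeps its original number of steps, only the coordinate range
    # is rescaled/shifted when rope_options is supplied.
    steps_h, steps_w = h_len, w_len
    if rope_options is not None:
        h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
        w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
        index += rope_options.get("shift_t", 0.0)
        h_offset += rope_options.get("shift_y", 0.0)
        w_offset += rope_options.get("shift_x", 0.0)
    ids = torch.zeros((steps_h, steps_w, 3))
    ids[:, :, 0] = index
    ids[:, :, 1] = torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h).unsqueeze(1)
    ids[:, :, 2] = torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w).unsqueeze(0)
    return ids

# 4x4 token grid: rows stretched 2x (0, 2, 4, 6), columns shifted by 8 (8..11).
print(make_img_ids(4, 4, rope_options={"scale_y": 2.0, "shift_x": 8.0})[:, :, 1:])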
@@ -503,7 +503,11 @@ class LoadedModel:
             use_more_vram = lowvram_model_memory
             if use_more_vram == 0:
                 use_more_vram = 1e32
-            self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+            if use_more_vram > 0:
+                self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+            else:
+                self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights)

         real_model = self.model.model

         if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
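With this change the VRAM budget passed to the loaded model can be negative: a positive value still means "you may load this much more onto the GPU", while a negative value now asks the patcher to give that many bytes back. A toy sketch of the sign convention, using a hypothetical stand-in object rather than the real LoadedModel/ModelPatcher pair:

class ToyPatcher:
    # Hypothetical stand-in, only to show the sign convention of the branch above.
    offload_device = "cpu"

    def model_use_more_vram(self, extra_bytes, force_patch_weights=False):
        print(f"may load up to {extra_bytes} more bytes onto the GPU")

    def partially_unload(self, device_to, memory_to_free, force_patch_weights=False):
        print(f"freeing {memory_to_free} bytes back to {device_to}")

def apply_vram_budget(patcher, budget_bytes, force_patch_weights=False):
    # Positive budget: keep loading. Negative budget: partially unload instead.
    if budget_bytes > 0:
        patcher.model_use_more_vram(budget_bytes, force_patch_weights=force_patch_weights)
    else:
        patcher.partially_unload(patcher.offload_device, -budget_bytes,
                                 force_patch_weights=force_patch_weights)

apply_vram_budget(ToyPatcher(), 256 * 1024 * 1024)   # room to load 256 MiB more
apply_vram_budget(ToyPatcher(), -512 * 1024 * 1024)  # must free 512 MiB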
@@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
            current_free_mem = get_free_memory(torch_dev) + loaded_memory

            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
-            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
+            lowvram_model_memory = lowvram_model_memory - loaded_memory
+
+            if lowvram_model_memory == 0:
+                lowvram_model_memory = 0.1

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 0.1
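Dropping the max(0.1, ...) clamp lets the computed weight budget go negative when a model already holds more VRAM than the new budget allows, which is exactly the case the new partially_unload branch above handles. Worked numbers, where every value is invented purely for illustration (MIN_WEIGHT_MEMORY_RATIO and the memory figures are assumptions, and minimum_inference_memory is a plain number here instead of the real function):

MIN_WEIGHT_MEMORY_RATIO = 0.4
minimum_inference_memory = 1 * 1024**3   # assumed 1 GiB inference reserve
current_free_mem = 6 * 1024**3           # free VRAM plus the model's already-loaded weights
minimum_memory_required = 2 * 1024**3
loaded_memory = 5 * 1024**3              # weights this model already holds on the GPU

lowvram_model_memory = max(128 * 1024 * 1024,
                           current_free_mem - minimum_memory_required,
                           min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO,
                               current_free_mem - minimum_inference_memory))
lowvram_model_memory = lowvram_model_memory - loaded_memory
print(lowvram_model_memory / (1024 * 1024))  # -1024.0: about 1 GiB now has to be unloaded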
@@ -1129,13 +1136,18 @@ def unpin_memory(tensor):
    if not is_device_cpu(tensor.device):
        return False

-    if not tensor.is_pinned():
-        #NOTE: Cuda does detect when a tensor is already pinned and would
-        #error below, but there are proven cases where this also queues an error
-        #on the GPU async. So dont trust the CUDA API and guard here
-        return False
-
-    ptr = tensor.data_ptr()
+    ptr = tensor.data_ptr()
+    size = tensor.numel() * tensor.element_size()
+
+    size_stored = PINNED_MEMORY.get(ptr, None)
+    if size_stored is None:
+        logging.warning("Tried to unpin tensor not pinned by ComfyUI")
+        return False
+
+    if size != size_stored:
+        logging.warning("Size of pinned tensor changed")
+        return False
+
    if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
        TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
        if len(PINNED_MEMORY) == 0:
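The rewritten guard no longer asks CUDA via tensor.is_pinned() and instead consults ComfyUI's own bookkeeping, a mapping from host pointer to pinned byte size. A minimal sketch of that bookkeeping, assuming the pinning side records each registered pointer and size in PINNED_MEMORY in the same way:

import logging

PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0

def record_pin(ptr, size):
    # Assumed counterpart to pin_memory(): remember what we registered and how big it was.
    global TOTAL_PINNED_MEMORY
    PINNED_MEMORY[ptr] = size
    TOTAL_PINNED_MEMORY += size

def can_unpin(ptr, size):
    # Mirrors the two guards above: only unpin pointers we pinned ourselves,
    # and only if the tensor still spans the same number of bytes.
    size_stored = PINNED_MEMORY.get(ptr, None)
    if size_stored is None:
        logging.warning("Tried to unpin tensor not pinned by ComfyUI")
        return False
    if size != size_stored:
        logging.warning("Size of pinned tensor changed")
        return False
    return True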
@@ -843,7 +843,7 @@ class ModelPatcher:

        self.object_patches_backup.clear()

-    def partially_unload(self, device_to, memory_to_free=0):
+    def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
        with self.use_ejected():
            hooks_unpatched = False
            memory_freed = 0
@@ -887,13 +887,19 @@ class ModelPatcher:
                    module_mem += move_weight_functions(m, device_to)
                    if lowvram_possible:
                        if weight_key in self.patches:
-                            _, set_func, convert_func = get_key_weight(self.model, weight_key)
-                            m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
-                            patch_counter += 1
+                            if force_patch_weights:
+                                self.patch_weight_to_device(weight_key)
+                            else:
+                                _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                                m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
+                                patch_counter += 1
                        if bias_key in self.patches:
-                            _, set_func, convert_func = get_key_weight(self.model, bias_key)
-                            m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
-                            patch_counter += 1
+                            if force_patch_weights:
+                                self.patch_weight_to_device(bias_key)
+                            else:
+                                _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                                m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
+                                patch_counter += 1
                        cast_weight = True

                    if cast_weight:
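The force_patch_weights branch chooses between baking a patch into the weight immediately (patch_weight_to_device) and deferring it as a LowVramPatch that is applied whenever the weight is cast for compute. A toy illustration of that difference using plain tensors rather than ComfyUI's classes:

import torch

# Toy tensors, not ComfyUI's classes: the point is only where the work happens.
weight = torch.ones(4)
patch = lambda w: w * 2.0

# force_patch_weights=True analogue: rewrite the stored weight right away,
# so whatever gets offloaded already contains the patch.
baked = patch(weight.clone())

# force_patch_weights=False analogue: keep the weight untouched and remember
# a weight_function that re-applies the patch each time the weight is
# materialized for compute (what LowVramPatch does in the real code).
weight_functions = [patch]

def materialize(w, fns):
    for f in fns:
        w = f(w)
    return w

assert torch.equal(baked, materialize(weight.clone(), weight_functions))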
@@ -909,6 +915,7 @@ class ModelPatcher:
            self.model.model_lowvram = True
            self.model.lowvram_patch_counter += patch_counter
            self.model.model_loaded_weight_memory -= memory_freed
+            logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
            return memory_freed

    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -110,9 +110,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
                for f in s.bias_function:
                    bias = f(bias)

-    weight = weight.to(dtype=dtype)
-    if weight_has_function:
+    if weight_has_function or weight.dtype != dtype:
        with wf_context:
+            weight = weight.to(dtype=dtype)
            for f in s.weight_function:
                weight = f(weight)
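The dtype cast now lives inside the weight_function path and only runs when it is actually needed, so a weight that already has the requested dtype and has no weight functions is returned without an extra copy. A small self-contained check of that condition, using toy tensors rather than the real cast_bias_weight:

import torch

weight = torch.ones(8, dtype=torch.float16)
dtype = torch.float16
weight_function = []

weight_has_function = len(weight_function) > 0
if weight_has_function or weight.dtype != dtype:
    weight = weight.to(dtype=dtype)
    for f in weight_function:
        weight = f(weight)
# Nothing happened above: same dtype and no functions, so the original tensor is reused as-is.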