Merge branch 'master' into dr-support-pip-cm (commit 8744ebb4a1)
comfy/ldm/flux/model.py
@@ -195,20 +195,50 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
 
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        bs, c, h_orig, w_orig = x.shape
+        patch_size = self.patch_size
+
+        h_len = ((h_orig + (patch_size // 2)) // patch_size)
+        w_len = ((w_orig + (patch_size // 2)) // patch_size)
+        img, img_ids = self.process_img(x)
+        img_tokens = img.shape[1]
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            for ref in ref_latents:
+                h_offset = 0
+                w_offset = 0
+                if ref.shape[-2] + h > ref.shape[-1] + w:
+                    w_offset = w
+                else:
+                    h_offset = h
+
+                kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
+                img = torch.cat([img, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+                h = max(h, ref.shape[-2] + h_offset)
+                w = max(w, ref.shape[-1] + w_offset)
+
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+        out = out[:, :img_tokens]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
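Note: the refactored process_img packs latents into patch tokens and builds a 3-axis position id per token: axis 0 distinguishes the main image (index=0) from reference latents (index=1), while axes 1 and 2 carry the (h, w) grid position, shifted by h_offset/w_offset so reference tokens sit beside rather than on top of the main grid. A minimal standalone sketch of just the id scheme (make_ids is a hypothetical helper, not part of this commit):

import torch

def make_ids(h_len, w_len, index=0, h_offset=0, w_offset=0):
    # one (index, h, w) triple per token, mirroring the linspace logic above
    ids = torch.zeros((h_len, w_len, 3))
    ids[:, :, 0] = index  # 0 = main image, 1 = reference latent
    ids[:, :, 1] = torch.arange(h_offset, h_offset + h_len).unsqueeze(1)
    ids[:, :, 2] = torch.arange(w_offset, w_offset + w_len).unsqueeze(0)
    return ids.reshape(-1, 3)

main = make_ids(2, 2)                      # (0,0,0) (0,0,1) (0,1,0) (0,1,1)
ref = make_ids(2, 2, index=1, w_offset=2)  # (1,0,2) (1,0,3) (1,1,2) (1,1,3)
print(torch.cat([main, ref]))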
comfy/model_base.py
@@ -816,6 +816,7 @@ class PixArt(BaseModel):
 class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
         super().__init__(model_config, model_type, device=device, unet_model=unet_model)
+        self.memory_usage_factor_conds = ("kontext",)
 
     def concat_cond(self, **kwargs):
         try:
@@ -876,8 +877,23 @@ class Flux(BaseModel):
         guidance = kwargs.get("guidance", 3.5)
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
         return out
 
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
+
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
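Note: extra_conds_shapes only needs to approximate the shape of the extra kontext tokens so memory use can be estimated before sampling; [1, 16, n // 16] folds every reference latent into one token-count figure. A quick arithmetic check, assuming a single 1024x1024 reference image, i.e. a 16-channel, 8x-downscaled Flux latent of shape [1, 16, 128, 128]:

import math
import torch

ref_latents = [torch.zeros(1, 16, 128, 128)]
total = sum(map(lambda a: math.prod(a.size()), ref_latents))  # 262144 elements
print([1, 16, total // 16])  # [1, 16, 16384]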
comfy/model_management.py
@@ -1290,6 +1290,13 @@ def supports_fp8_compute(device=None):
 
     return True
 
+def extended_fp16_support():
+    # TODO: check why some models work with fp16 on newer torch versions but not on older
+    if torch_version_numeric < (2, 7):
+        return False
+
+    return True
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
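Note: torch_version_numeric is assumed to be a (major, minor) tuple parsed from torch.__version__ elsewhere in model_management.py; Python's tuple comparison then gates fp16 support on torch >= 2.7. A sketch of that comparison:

import torch

parts = torch.__version__.split(".")
torch_version_numeric = (int(parts[0]), int(parts[1]))
print(torch_version_numeric < (2, 7))  # True on torch < 2.7 -> fp16 gated off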
comfy/supported_models.py
@@ -1197,11 +1197,16 @@ class Omnigen2(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Flux
 
-    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
 
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
 
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        if comfy.model_management.extended_fp16_support():
+            self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Omnigen2(self, device=device)
         return out
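Note: supported_inference_dtypes is ordered by preference, so the new __init__ re-adds torch.float16 at the front only when extended_fp16_support() passes. A sketch of why the order matters (pick_dtype is a hypothetical helper, not ComfyUI's actual selection logic):

import torch

def pick_dtype(supported, device_allows):
    # first preference the device can satisfy wins
    return next(d for d in supported if d in device_allows)

supported = [torch.float16, torch.bfloat16, torch.float32]
print(pick_dtype(supported, {torch.float16, torch.float32}))  # torch.float16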
comfy_extras/nodes_flux.py
@@ -1,4 +1,5 @@
 import node_helpers
+import comfy.utils
 
 class CLIPTextEncodeFlux:
     @classmethod
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
         return (c, )
 
+
+PREFERED_KONTEXT_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
+
+
+class FluxKontextImageScale:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"image": ("IMAGE", ),
+                             },
+                }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "scale"
+
+    CATEGORY = "advanced/conditioning/flux"
+    DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
+
+    def scale(self, image):
+        width = image.shape[2]
+        height = image.shape[1]
+        aspect_ratio = width / height
+        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
+        image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
+        return (image, )
+
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
     "FluxGuidance": FluxGuidance,
     "FluxDisableGuidance": FluxDisableGuidance,
+    "FluxKontextImageScale": FluxKontextImageScale,
 }
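Note: FluxKontextImageScale.scale picks the preferred resolution whose aspect ratio is closest to the input's, then lanczos-resizes to it. A worked example of the min() selection for a 1920x1080 input (aspect ratio ~1.778), using a subset of the table above:

PREFERED_KONTEXT_RESOLUTIONS = [(1328, 800), (1392, 752), (1456, 720)]

aspect_ratio = 1920 / 1080
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
print(width, height)  # 1392 752 -> closest aspect ratio (1.851) wins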
comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.41"
+__version__ = "0.3.42"
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.41"
+version = "0.3.42"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
requirements.txt
@@ -1,6 +1,6 @@
-comfyui-frontend-package==1.22.2
-comfyui-workflow-templates==0.1.29
-comfyui-embedded-docs==0.2.2
+comfyui-frontend-package==1.23.4
+comfyui-workflow-templates==0.1.30
+comfyui-embedded-docs==0.2.3
 comfyui_manager
 torch
 torchsde
@@ -8,7 +8,7 @@ torchvision
 torchaudio
 numpy>=1.25.0
 einops
-transformers>=4.28.1
+transformers>=4.37.2
 tokenizers>=0.13.3
 sentencepiece
 safetensors>=0.4.2