mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
The FluxKontextMultiReferenceLatentMethod node (displayed under the more generic name "Edit Model Reference Method") now allows selecting index_timestep_zero.
250 lines
7.7 KiB
Python
import node_helpers
|
|
import comfy.utils
|
|
from typing_extensions import override
|
|
from comfy_api.latest import ComfyExtension, io
|
|
import comfy.model_management
|
|
import torch
|
|
import math
|
|
import nodes
|
|
|
|
class CLIPTextEncodeFlux(io.ComfyNode):
    """Encode separate clip_l / t5xxl prompts for Flux and attach a guidance value."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeFlux",
            category="advanced/conditioning/flux",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
            ],
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
        # Tokenize both prompt streams; the "t5xxl" token stream from the
        # second tokenization replaces the one produced for clip_l.
        tokens = clip.tokenize(clip_l)
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
        cond = clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance})
        return io.NodeOutput(cond)

    encode = execute # TODO: remove
|
|
|
|
class EmptyFlux2LatentImage(io.ComfyNode):
    """Create a zero-filled Flux 2 latent of the requested pixel size."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="EmptyFlux2LatentImage",
            display_name="Empty Flux 2 Latent",
            category="latent",
            inputs=[
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
        # Flux 2 latents use 128 channels and a 16x spatial downscale.
        shape = [batch_size, 128, height // 16, width // 16]
        samples = torch.zeros(shape, device=comfy.model_management.intermediate_device())
        return io.NodeOutput({"samples": samples})
|
|
|
|
class FluxGuidance(io.ComfyNode):
    """Attach a guidance strength value to Flux conditioning."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FluxGuidance",
            category="advanced/conditioning/flux",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
            ],
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, conditioning, guidance) -> io.NodeOutput:
        # Store the guidance value on every conditioning entry.
        updated = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
        return io.NodeOutput(updated)

    append = execute # TODO: remove
|
|
|
|
|
|
class FluxDisableGuidance(io.ComfyNode):
    """Remove the guidance embed from Flux (and Flux-like) conditioning."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FluxDisableGuidance",
            category="advanced/conditioning/flux",
            description="This node completely disables the guidance embed on Flux and Flux like models",
            inputs=[io.Conditioning.Input("conditioning")],
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, conditioning) -> io.NodeOutput:
        # A guidance of None disables the guidance embed downstream.
        disabled = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
        return io.NodeOutput(disabled)

    append = execute # TODO: remove
|
|
|
|
|
|
# (width, height) resolutions preferred for Flux Kontext, spanning portrait
# through square to landscape aspect ratios at a roughly constant pixel
# budget (~1024x1024).  FluxKontextImageScale picks the entry whose aspect
# ratio is closest to the input image's.
# NOTE(review): "PREFERED" (sic) is kept as-is — renaming would break importers.
PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
|
|
|
|
|
|
class FluxKontextImageScale(io.ComfyNode):
    """Resize an image to the preferred Kontext resolution closest in aspect ratio."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FluxKontextImageScale",
            category="advanced/conditioning/flux",
            description="This node resizes the image to one that is more optimal for flux kontext.",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        # image is indexed (batch, height, width, channels) here.
        in_height = image.shape[1]
        in_width = image.shape[2]
        ratio = in_width / in_height
        # Choose the preferred resolution with the nearest aspect ratio.
        _, out_w, out_h = min(
            (abs(ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS
        )
        # common_upscale expects channels-first; move channels back afterwards.
        scaled = comfy.utils.common_upscale(
            image.movedim(-1, 1), out_w, out_h, "lanczos", "center"
        ).movedim(1, -1)
        return io.NodeOutput(scaled)

    scale = execute # TODO: remove
|
|
|
|
|
|
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
    """Select how reference latents are indexed for Kontext-style edit models.

    Writes a "reference_latents_method" key into the conditioning; the
    chosen method is one of "offset", "index", "uxo" or "index_timestep_zero".
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="FluxKontextMultiReferenceLatentMethod",
            display_name="Edit Model Reference Method",
            category="advanced/conditioning/flux",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Combo.Input(
                    "reference_latents_method",
                    # Fixed: was "uxo/uno", which left the "uso" branch in
                    # execute() unreachable; "uxo/uso" matches both checks.
                    options=["offset", "index", "uxo/uso", "index_timestep_zero"],
                ),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
        # Both the "uxo" and "uso" spellings map to the internal "uxo" method.
        if "uxo" in reference_latents_method or "uso" in reference_latents_method:
            reference_latents_method = "uxo"
        c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
        return io.NodeOutput(c)

    append = execute # TODO: remove
|
|
|
|
|
|
def generalized_time_snr_shift(t, mu: float, sigma: float):
    """Remap timestep(s) ``t`` in (0, 1) with a logistic-style shift.

    Computes exp(mu) / (exp(mu) + (1/t - 1)**sigma); works elementwise on
    torch tensors as well as on plain floats.  With mu == 0 and sigma == 1
    the mapping is the identity.
    """
    shift = math.exp(mu)  # hoisted: mu is a scalar, exp(mu) is loop-invariant
    return shift / (shift + (1 / t - 1) ** sigma)
|
|
|
|
|
|
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Return the empirically fitted schedule shift ``mu``.

    Two linear fits of mu against image_seq_len are used: one taken at 10
    sampling steps and one at 200.  Long sequences (> 4300 tokens) use the
    200-step fit directly; otherwise mu is linearly interpolated in
    num_steps between the two fits.
    """
    # Fit coefficients: (slope, intercept) at 10 and 200 sampling steps.
    slope_10, intercept_10 = 8.73809524e-05, 1.89833333
    slope_200, intercept_200 = 0.00016927, 0.45666666

    mu_at_200 = slope_200 * image_seq_len + intercept_200
    if image_seq_len > 4300:
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + intercept_10

    # Line through (10, mu_at_10) and (200, mu_at_200), evaluated at num_steps.
    step_slope = (mu_at_200 - mu_at_10) / 190.0
    return float(mu_at_200 + step_slope * (num_steps - 200.0))
|
|
|
|
|
|
def get_schedule(num_steps: int, image_seq_len: int) -> torch.Tensor:
    """Build the Flux 2 sigma schedule for a given step count and token count.

    Returns a 1-D tensor of ``num_steps + 1`` values descending from 1 to 0,
    remapped by ``generalized_time_snr_shift`` with an empirically fitted mu.
    (Fixed the return annotation: the previous ``list[float]`` was wrong —
    a ``torch.Tensor`` is produced and consumed directly as sigmas.)
    """
    mu = compute_empirical_mu(image_seq_len, num_steps)
    timesteps = torch.linspace(1, 0, num_steps + 1)
    timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
    return timesteps
|
|
|
|
|
|
class Flux2Scheduler(io.ComfyNode):
    """Produce Flux 2 sampling sigmas from a step count and image resolution."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Flux2Scheduler",
            category="sampling/custom_sampling/schedulers",
            inputs=[
                io.Int.Input("steps", default=20, min=1, max=4096),
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
            ],
            outputs=[io.Sigmas.Output()],
        )

    @classmethod
    def execute(cls, steps, width, height) -> io.NodeOutput:
        # One latent token per 16x16 pixel patch; 16 * 16 == 256.
        token_count = round(width * height / 256)
        return io.NodeOutput(get_schedule(steps, token_count))
|
|
|
|
|
|
class FluxExtension(ComfyExtension):
    """Expose the Flux node classes to the ComfyUI extension loader."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes: list[type[io.ComfyNode]] = [
            CLIPTextEncodeFlux,
            FluxGuidance,
            FluxDisableGuidance,
            FluxKontextImageScale,
            FluxKontextMultiReferenceLatentMethod,
            EmptyFlux2LatentImage,
            Flux2Scheduler,
        ]
        return node_classes
|
|
|
|
|
|
async def comfy_entrypoint() -> FluxExtension:
    """Entry point called by ComfyUI to obtain this file's extension instance."""
    extension = FluxExtension()
    return extension
|