mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Compare commits
8 Commits
01aacf89e4
...
229cdd1fa3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
229cdd1fa3 | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
d6cd1c03e3 | ||
|
|
7883076f5c | ||
|
|
483ba1e98b | ||
|
|
648814b751 | ||
|
|
4bdb0dddb7 |
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||||
"""
|
"""
|
||||||
|
if solver_type not in {"phi_1", "phi_2"}:
|
||||||
|
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||||
|
|
||||||
# Step 2
|
# Step 2
|
||||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
if solver_type == "phi_1":
|
||||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||||
|
elif solver_type == "phi_2":
|
||||||
|
b2 = ei_h_phi_2(-h_eta) / r
|
||||||
|
b1 = ei_h_phi_1(-h_eta) - b2
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||||
|
|
||||||
if inject_noise:
|
if inject_noise:
|
||||||
segment_factor = (r - 1) * h * eta
|
segment_factor = (r - 1) * h * eta
|
||||||
sde_noise = sde_noise * segment_factor.exp()
|
sde_noise = sde_noise * segment_factor.exp()
|
||||||
|
|||||||
@ -30,6 +30,13 @@ except ImportError as e:
|
|||||||
raise e
|
raise e
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
# Optional dependency: the SageAttention3 kernel is only used when the
# `sageattn3` package can be imported; otherwise the flag stays False.
try:
    from sageattn3 import sageattn3_blackwell
except ImportError:
    SAGE_ATTENTION3_IS_AVAILABLE = False
else:
    SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||||
|
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
out = out.reshape(b, -1, heads * dim_head)
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Run attention through the SageAttention3 kernel, deferring to ``attention_pytorch`` when it does not apply.

    Args:
        q, k, v: Query/key/value tensors. With ``skip_reshape=True`` ``q`` is
            unpacked as 4-D ``(B, H, L, D)``; otherwise it is unpacked as 3-D
            ``(B, N, heads * dim_head)`` and split into heads here.
        heads: Number of attention heads.
        mask: Optional attention mask. Any mask routes the call to
            ``attention_pytorch``; it is never passed to the SageAttention3 kernel.
        attn_precision: Forwarded unchanged to ``attention_pytorch`` on fallback.
        skip_reshape: Whether the inputs are already split into heads.
        skip_output_reshape: When False the kernel output is permuted and
            flattened to ``(B, L, H * D)``; when True it is returned as produced
            by the kernel (presumably ``(B, H, L, D)`` - TODO confirm).
        **kwargs: Forwarded unchanged to ``attention_pytorch`` on fallback.

    Returns:
        The attention output, from SageAttention3 or from ``attention_pytorch``.

    The PyTorch path is taken when: the device is not CUDA, the dtype is not
    fp16/bf16, a mask is given, the head count does not match / divide the
    input, ``dim_head >= 256``, the query length is ``<= 1024``, or the
    SageAttention3 kernel raises.
    """
    def pytorch_fallback():
        # Always hand over the caller's original tensors together with the
        # caller's own `skip_reshape`, so the declared layout matches the data.
        # The previous exception path hard-coded skip_reshape=False, which
        # mis-described 4-D (B, H, L, D) inputs when the caller had passed
        # skip_reshape=True.
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    unsupported_input = (
        q.device.type != "cuda"
        or q.dtype not in (torch.float16, torch.bfloat16)
        or mask is not None
    )
    if unsupported_input:
        return pytorch_fallback()

    if skip_reshape:
        B, H, N, dim_head = q.shape
        if H != heads:
            return pytorch_fallback()
    else:
        B, N, inner_dim = q.shape
        if inner_dim % heads != 0:
            return pytorch_fallback()
        dim_head = inner_dim // heads

    # Large heads and short query sequences are left to PyTorch attention.
    if dim_head >= 256 or N <= 1024:
        return pytorch_fallback()

    if skip_reshape:
        q_s, k_s, v_s = q, k, v
    else:
        # (B, N, heads * dim_head) -> (B, heads, N, dim_head), made contiguous.
        q_s, k_s, v_s = map(
            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
            (q, k, v),
        )
    B, H, L, D = q_s.shape

    # The fallback is triggered through a flag and executed after the `except`
    # block has ended (presumably so the exception object is released before
    # the fallback allocates - TODO confirm).
    exception_fallback = False
    try:
        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
    except Exception as e:
        exception_fallback = True
        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)

    if exception_fallback:
        if not skip_reshape:
            # Drop the contiguous per-head copies before running the fallback.
            del q_s, k_s, v_s
        return pytorch_fallback()

    if not skip_output_reshape:
        # (B, H, L, D) -> (B, L, H * D)
        out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)

    return out
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||||
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
|
|||||||
# register core-supported attention functions
|
# register core-supported attention functions
|
||||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("sage", attention_sage)
|
register_attention_function("sage", attention_sage)
|
||||||
|
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||||
|
register_attention_function("sage3", attention3_sage)
|
||||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("flash", attention_flash)
|
register_attention_function("flash", attention_flash)
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
|
|||||||
@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
quant_conf = {"format": self.quant_format}
|
quant_conf = {"format": self.quant_format}
|
||||||
if self._full_precision_mm:
|
if self._full_precision_mm:
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
quant_conf["full_precision_matrix_mult"] = True
|
||||||
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
|
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
|
|||||||
@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
|||||||
if quant_metadata is not None:
|
if quant_metadata is not None:
|
||||||
layers = quant_metadata["layers"]
|
layers = quant_metadata["layers"]
|
||||||
for k, v in layers.items():
|
for k, v in layers.items():
|
||||||
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
|
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||||
|
|
||||||
return state_dict, metadata
|
return state_dict, metadata
|
||||||
|
|||||||
@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
|
|||||||
get_sampler = execute
|
get_sampler = execute
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerSEEDS2(io.ComfyNode):
    """Node that builds a "seeds_2" sampler from a solver type and three float options."""

    @classmethod
    def define_schema(cls):
        """Return the node schema: one combo input, three float inputs, one sampler output."""
        solver_type_input = io.Combo.Input("solver_type", options=["phi_1", "phi_2"])
        eta_input = io.Float.Input(
            "eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False,
            tooltip="Stochastic strength",
        )
        s_noise_input = io.Float.Input(
            "s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False,
            tooltip="SDE noise multiplier",
        )
        r_input = io.Float.Input(
            "r", default=0.5, min=0.01, max=1.0, step=0.01, round=False,
            tooltip="Relative step size for the intermediate stage (c2 node)",
        )
        return io.Schema(
            node_id="SamplerSEEDS2",
            category="sampling/custom_sampling/samplers",
            inputs=[solver_type_input, eta_input, s_noise_input, r_input],
            outputs=[io.Sampler.Output()],
        )

    @classmethod
    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
        """Create the "seeds_2" sampler with the given options and wrap it in a NodeOutput."""
        # The option keys are passed as the second argument to
        # comfy.samplers.ksampler alongside the sampler name.
        sampler_options = {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type}
        return io.NodeOutput(comfy.samplers.ksampler("seeds_2", sampler_options))
|
||||||
|
|
||||||
|
|
||||||
class Noise_EmptyNoise:
|
class Noise_EmptyNoise:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.seed = 0
|
self.seed = 0
|
||||||
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
SamplerDPMAdaptative,
|
SamplerDPMAdaptative,
|
||||||
SamplerER_SDE,
|
SamplerER_SDE,
|
||||||
SamplerSASolver,
|
SamplerSASolver,
|
||||||
|
SamplerSEEDS2,
|
||||||
SplitSigmas,
|
SplitSigmas,
|
||||||
SplitSigmasDenoise,
|
SplitSigmasDenoise,
|
||||||
FlipSigmas,
|
FlipSigmas,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user