mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
Compare commits
5 Commits
d197be933d
...
61416ff6a3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61416ff6a3 | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
5905513e32 | ||
|
|
a8ea6953ec |
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||||
"""
|
"""
|
||||||
|
if solver_type not in {"phi_1", "phi_2"}:
|
||||||
|
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||||
|
|
||||||
# Step 2
|
# Step 2
|
||||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
if solver_type == "phi_1":
|
||||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||||
|
elif solver_type == "phi_2":
|
||||||
|
b2 = ei_h_phi_2(-h_eta) / r
|
||||||
|
b1 = ei_h_phi_1(-h_eta) - b2
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||||
|
|
||||||
if inject_noise:
|
if inject_noise:
|
||||||
segment_factor = (r - 1) * h * eta
|
segment_factor = (r - 1) * h * eta
|
||||||
sde_noise = sde_noise * segment_factor.exp()
|
sde_noise = sde_noise * segment_factor.exp()
|
||||||
|
|||||||
@ -119,6 +119,9 @@ class JointAttention(nn.Module):
|
|||||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
|
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
if output.dtype == torch.float16:
|
||||||
|
output.div_(4)
|
||||||
|
|
||||||
return self.out(output)
|
return self.out(output)
|
||||||
|
|
||||||
|
|
||||||
@ -175,8 +178,12 @@ class FeedForward(nn.Module):
|
|||||||
def _forward_silu_gating(self, x1, x3):
|
def _forward_silu_gating(self, x1, x3):
|
||||||
return clamp_fp16(F.silu(x1) * x3)
|
return clamp_fp16(F.silu(x1) * x3)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, apply_fp16_downscale=False):
|
||||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
x3 = self.w3(x)
|
||||||
|
if x.dtype == torch.float16 and apply_fp16_downscale:
|
||||||
|
x3.div_(32)
|
||||||
|
|
||||||
|
return self.w2(self._forward_silu_gating(self.w1(x), x3))
|
||||||
|
|
||||||
|
|
||||||
class JointTransformerBlock(nn.Module):
|
class JointTransformerBlock(nn.Module):
|
||||||
@ -287,6 +294,7 @@ class JointTransformerBlock(nn.Module):
|
|||||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||||
clamp_fp16(self.feed_forward(
|
clamp_fp16(self.feed_forward(
|
||||||
modulate(self.ffn_norm1(x), scale_mlp),
|
modulate(self.ffn_norm1(x), scale_mlp),
|
||||||
|
apply_fp16_downscale=True,
|
||||||
))
|
))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
quant_conf = {"format": self.quant_format}
|
quant_conf = {"format": self.quant_format}
|
||||||
if self._full_precision_mm:
|
if self._full_precision_mm:
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
quant_conf["full_precision_matrix_mult"] = True
|
||||||
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
|
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
|
|||||||
@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
|||||||
if quant_metadata is not None:
|
if quant_metadata is not None:
|
||||||
layers = quant_metadata["layers"]
|
layers = quant_metadata["layers"]
|
||||||
for k, v in layers.items():
|
for k, v in layers.items():
|
||||||
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
|
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||||
|
|
||||||
return state_dict, metadata
|
return state_dict, metadata
|
||||||
|
|||||||
@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
|
|||||||
get_sampler = execute
|
get_sampler = execute
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerSEEDS2(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SamplerSEEDS2",
|
||||||
|
category="sampling/custom_sampling/samplers",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||||
|
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
|
||||||
|
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
||||||
|
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
||||||
|
],
|
||||||
|
outputs=[io.Sampler.Output()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
||||||
|
sampler_name = "seeds_2"
|
||||||
|
sampler = comfy.samplers.ksampler(
|
||||||
|
sampler_name,
|
||||||
|
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
||||||
|
)
|
||||||
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
|
|
||||||
class Noise_EmptyNoise:
|
class Noise_EmptyNoise:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.seed = 0
|
self.seed = 0
|
||||||
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
SamplerDPMAdaptative,
|
SamplerDPMAdaptative,
|
||||||
SamplerER_SDE,
|
SamplerER_SDE,
|
||||||
SamplerSASolver,
|
SamplerSASolver,
|
||||||
|
SamplerSEEDS2,
|
||||||
SplitSigmas,
|
SplitSigmas,
|
||||||
SplitSigmasDenoise,
|
SplitSigmasDenoise,
|
||||||
FlipSigmas,
|
FlipSigmas,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user