mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
Compare commits
6 Commits
84bb611351
...
ab14717963
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab14717963 | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
de408d5f2c | ||
|
|
90ba86a95f | ||
|
|
417dbcb32f |
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
|
||||
@ -26,6 +26,7 @@ import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -338,8 +339,11 @@ try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
torch.backends.cudnn.enabled = os.environ.get("TORCH_AMD_CUDNN_ENABLED", "0").strip().lower() not in {
|
||||
"0", "off", "false", "disable", "disabled", "no"}
|
||||
if not torch.backends.cudnn.enabled:
|
||||
logging.info(
|
||||
"ComfyUI has set torch.backends.cudnn.enabled to False for better AMD performance. Set environment var TORCH_AMD_CUDNN_ENABLED=1 to enable it again.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
|
||||
@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
|
||||
@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
if quant_metadata is not None:
|
||||
layers = quant_metadata["layers"]
|
||||
for k, v in layers.items():
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
|
||||
get_sampler = execute
|
||||
|
||||
|
||||
class SamplerSEEDS2(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
||||
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
||||
],
|
||||
outputs=[io.Sampler.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
||||
sampler_name = "seeds_2"
|
||||
sampler = comfy.samplers.ksampler(
|
||||
sampler_name,
|
||||
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
||||
)
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
|
||||
SamplerDPMAdaptative,
|
||||
SamplerER_SDE,
|
||||
SamplerSASolver,
|
||||
SamplerSEEDS2,
|
||||
SplitSigmas,
|
||||
SplitSigmasDenoise,
|
||||
FlipSigmas,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user