Compare commits

...

5 Commits

Author SHA1 Message Date
kabachuha
c3286ff17a
Merge a7e8008fdc into 6592bffc60 2025-12-14 09:31:17 +01:00
chaObserv
6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
comfyanonymous
971cefe7d4
Fix pytorch warnings. (#11314)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-13 18:45:23 -05:00
kabachuha
a7e8008fdc hunyuan 1.5 flf: remove spaces from empty lines 2025-12-11 20:49:45 +03:00
kabachuha
1e4f506f03 Initial HYV1.5 manual FLF implementation 2025-12-06 17:25:47 +03:00
5 changed files with 122 additions and 5 deletions

View File

@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()

View File

@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):

View File

@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata

View File

@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()]
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,

View File

@ -124,6 +124,87 @@ class HunyuanVideo15ImageToVideo(io.ComfyNode):
return io.NodeOutput(positive, negative, out_latent)
class HunyuanVideo15FirstLastFrameToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15FirstLastFrameToVideo",
category="conditioning/video_models",
is_experimental=True,
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16],
device=comfy.model_management.intermediate_device())
concat_latent_image = torch.zeros((batch_size, 32, latent.shape[2], latent.shape[3], latent.shape[4]),
device=comfy.model_management.intermediate_device())
mask = torch.ones((1, 1, latent.shape[2], latent.shape[3], latent.shape[4]),
device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded_start = vae.encode(start_image[:, :, :, :3])
concat_latent_image[:, :, :encoded_start.shape[2], :, :] = encoded_start
start_frames_in_latent = ((start_image.shape[0] - 1) // 4) + 1
mask[:, :, :start_frames_in_latent] = 0.0
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded_end = vae.encode(end_image[:, :, :, :3])
end_frames_in_latent = ((end_image.shape[0] - 1) // 4) + 1
concat_latent_image[:, :, -end_frames_in_latent:, :, :] = encoded_end[:, :, -end_frames_in_latent:, :, :]
mask[:, :, -end_frames_in_latent:] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
clip_vision_output = None
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image
if clip_vision_end_image is not None:
if clip_vision_output is not None:
pass # Use only one embedding for now
else:
clip_vision_output = clip_vision_end_image
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class HunyuanVideo15SuperResolution(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -406,6 +487,7 @@ class HunyuanExtension(ComfyExtension):
EmptyHunyuanLatentVideo,
EmptyHunyuanVideo15Latent,
HunyuanVideo15ImageToVideo,
HunyuanVideo15FirstLastFrameToVideo,
HunyuanVideo15SuperResolution,
HunyuanVideo15LatentUpscaleWithModel,
LatentUpscaleModelLoader,