Compare commits

...

3 Commits

Author SHA1 Message Date
Alexander Piskun
12398306c4
Merge 2549498817 into 6592bffc60 2025-12-14 13:50:06 +08:00
chaObserv
6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
bigcat88
2549498817
api-nodes: use new custom endpoint for Nano Banana 2025-12-13 20:00:56 +02:00
3 changed files with 53 additions and 12 deletions

View File

@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()

View File

@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
if part_type == "text" and part.text:
parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4))
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini-image/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini-image/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()]
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,