Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-21 03:40:16 +08:00)

Commit 0aeb958ea5: Merge branch 'comfyanonymous:master' into master
@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
     return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
 
 @torch.no_grad()
-def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
     """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     old_d = None
 
+    uncond_denoised = None
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    if cfg_pp:
+        model_options = extra_args.get("model_options", {}).copy()
+        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        d = to_d(x, sigmas[i], denoised)
+        if cfg_pp:
+            d = to_d(x, sigmas[i], uncond_denoised)
+        else:
+            d = to_d(x, sigmas[i], denoised)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         dt = sigmas[i + 1] - sigmas[i]
         if i == 0:
             # Euler method
-            x = x + d * dt
+            if cfg_pp:
+                x = denoised + d * sigmas[i + 1]
+            else:
+                x = x + d * dt
         else:
             # Gradient estimation
-            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
-            x = x + d_bar * dt
+            if cfg_pp:
+                d_bar = (ge_gamma - 1) * (d - old_d)
+                x = denoised + d * sigmas[i + 1] + d_bar * dt
+            else:
+                d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+                x = x + d_bar * dt
         old_d = d
     return x
 
+@torch.no_grad()
+def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+    return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
+
 @torch.no_grad()
 def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
     """
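Note: with k-diffusion's to_d(x, sigma, denoised) = (x - denoised) / sigma, the CFG++ form `denoised + d * sigmas[i + 1]` is algebraically the same step as `x + d * dt` when d is built from the same denoised; the cfg_pp branch differs only in building d from uncond_denoised while anchoring on the cond denoised. A minimal check of that identity:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4)
denoised = torch.randn(4)
sigma, sigma_next = 10.0, 7.0

d = (x - denoised) / sigma                 # to_d(x, sigma, denoised)
euler_step = x + d * (sigma_next - sigma)  # x + d * dt
cfgpp_form = denoised + d * sigma_next     # form used in the cfg_pp branch

assert torch.allclose(euler_step, cfgpp_form)
```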
@@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         encoder_hidden_states_llama3=None,
+        image_cond=None,
         control = None,
         transformer_options = {},
     ) -> torch.Tensor:
         bs, c, h, w = x.shape
+        if image_cond is not None:
+            x = torch.cat([x, image_cond], dim=-1)
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         timesteps = t
         pooled_embeds = y
@@ -1104,4 +1104,7 @@ class HiDream(BaseModel):
         conditioning_llama3 = kwargs.get("conditioning_llama3", None)
         if conditioning_llama3 is not None:
             out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
+        image_cond = kwargs.get("concat_latent_image", None)
+        if image_cond is not None:
+            out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
         return out
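Note: in the transformer hunk, image_cond is concatenated on dim=-1 (width), so the conditioning latent rides side-by-side with the noisy latent instead of being stacked on channels. A shape sketch (the 16-channel 64x64 latent is an illustrative assumption, not taken from the diff):

```python
import torch

x = torch.randn(1, 16, 64, 64)           # noisy latent (illustrative shape)
image_cond = torch.randn(1, 16, 64, 64)  # VAE-encoded reference latent
x = torch.cat([x, image_cond], dim=-1)   # widths add: 64 -> 128
print(x.shape)                           # torch.Size([1, 16, 64, 128])
```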
@@ -963,7 +963,7 @@ def get_offload_stream(device):
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
-            ss.append(torch.cuda.Stream(device=device, priority=10))
+            ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
         stream_counter = (stream_counter + 1) % len(ss)
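Note: CUDA stream priorities follow "lower value = higher priority", and values outside the device's supported range are clamped, so the old priority=10 effectively requested the lowest available priority; priority=0 is the default. A small probe, guarded since it needs a CUDA device:

```python
import torch

if torch.cuda.is_available():
    s_old = torch.cuda.Stream(priority=10)  # out of range; clamped by CUDA
    s_new = torch.cuda.Stream(priority=0)   # default priority
    print(s_old.priority, s_new.priority)
```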
@@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end
+        self.zsnr = zsnr
 
         # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
 
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
-        if zsnr:
+        if self.zsnr:
             sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
 
         self.set_sigmas(sigmas)
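Note: rescale_zero_terminal_snr_sigmas applies the zero-terminal-SNR fix from Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed". A hedged sketch of that rescale in alpha-bar space (illustrative, not ComfyUI's exact implementation):

```python
import torch

def rescale_zero_terminal_snr(alphas_cumprod: torch.Tensor) -> torch.Tensor:
    a = alphas_cumprod.sqrt()
    a_0, a_T = a[0].clone(), a[-1].clone()
    a = (a - a_T) * a_0 / (a_0 - a_T)  # last step -> zero SNR, first unchanged
    return a ** 2

# With sigma = ((1 - abar) / abar) ** 0.5, the terminal sigma becomes infinite,
# which is why zsnr schedules are paired with v-prediction.
```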
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
+                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
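Note: adding the name to KSAMPLER_NAMES is what makes the new sampler selectable. A minimal usage sketch (assumes a ComfyUI checkout on the import path; sampler_object is the existing lookup helper in comfy.samplers):

```python
import comfy.samplers

assert "gradient_estimation_cfg_pp" in comfy.samplers.KSAMPLER_NAMES
sampler = comfy.samplers.sampler_object("gradient_estimation_cfg_pp")
```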
@@ -1,21 +1,22 @@
+import base64
 import io
+import math
 from inspect import cleandoc
 
-from comfy.utils import common_upscale
+import numpy as np
+import requests
+import torch
+from PIL import Image
+
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
+from comfy.utils import common_upscale
 from comfy_api_nodes.apis import (
-    OpenAIImageGenerationRequest,
     OpenAIImageEditRequest,
-    OpenAIImageGenerationResponse
+    OpenAIImageGenerationRequest,
+    OpenAIImageGenerationResponse,
 )
 from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation
 
-import numpy as np
-from PIL import Image
-import requests
-import torch
-import math
-import base64
 
 def downscale_input(image):
     samples = image.movedim(-1,1)
@@ -331,6 +332,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
                     "default": None,
                     "tooltip": "Optional mask for inpainting (white areas will be replaced)",
                 }),
+                "moderation": (IO.COMBO, {
+                    "options": ["low","auto"],
+                    "default": "low",
+                    "tooltip": "Moderation level",
+                }),
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG"
@@ -343,7 +349,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True
 
-    def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None):
+    def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"):
         model = "gpt-image-1"
         path = "/proxy/openai/images/generations"
         request_class = OpenAIImageGenerationRequest
@@ -415,6 +421,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
                 n=n,
                 seed=seed,
                 size=size,
+                moderation=moderation,
             ),
             files=files if files else None,
             auth_token=auth_token
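Note: the three hunks above thread the moderation option end to end: a combo input on the node, a keyword argument on api_call, and a field on the generation request. A hedged sketch of the resulting request body (field names per OpenAI's image API, where gpt-image-1 accepts "low" or "auto" for moderation; the prompt value is illustrative):

```python
payload = {
    "model": "gpt-image-1",
    "prompt": "a watercolor fox",
    "n": 1,
    "size": "1024x1024",
    "moderation": "low",  # the node's new default; "auto" is stricter
}
```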
@@ -38,6 +38,7 @@ class LTXVImgToVideo:
                 "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                 "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                 "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
             }}
 
     RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
@@ -46,7 +47,7 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"
 
-    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
@@ -59,7 +60,7 @@ class LTXVImgToVideo:
             dtype=torch.float32,
             device=latent.device,
         )
-        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
+        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
 
         return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
 
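Note: noise_mask follows ComfyUI's inpainting convention (0.0 = keep the latent, 1.0 = free to re-noise), so writing 1.0 - strength means strength=1.0 pins the encoded image frames exactly, matching the old hard-coded 0, while lower strengths let the sampler partially repaint them. A toy illustration (shapes are assumptions):

```python
import torch

latent_frames, image_frames, strength = 13, 1, 0.75
mask = torch.ones(1, 1, latent_frames)      # 1.0 = fully regenerated
mask[:, :, :image_frames] = 1.0 - strength  # 0.25: mostly keep, a little noise
```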
@@ -152,6 +153,15 @@ class LTXVAddGuide:
         return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
 
     def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
+        _, latent_idx = self.get_latent_index(
+            cond=positive,
+            latent_length=latent_image.shape[2],
+            guide_length=guiding_latent.shape[2],
+            frame_idx=frame_idx,
+            scale_factors=scale_factors,
+        )
+        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
+
         positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
         negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
 
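Note: the added block resolves the pixel-space frame_idx to a latent index and flags the guided span in the noise mask. In isolation (all values illustrative; latent_idx stands in for get_latent_index's result):

```python
import torch

noise_mask = torch.zeros(1, 1, 13)  # (batch, channel, latent frames)
latent_idx, guide_len = 4, 2
noise_mask[:, :, latent_idx:latent_idx + guide_len] = 1.0
print(noise_mask[0, 0].tolist())    # latent frames 4-5 marked for the guide
```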
@@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
             metadata["modelspec.predict_key"] = "epsilon"
         elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
             metadata["modelspec.predict_key"] = "v"
+            extra_keys["v_pred"] = torch.tensor([])
+            if getattr(model_sampling, "zsnr", False):
+                extra_keys["ztsnr"] = torch.tensor([])
 
     if not args.disable_metadata:
         metadata["prompt"] = prompt_info
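Note: v_pred and ztsnr are written as empty-tensor marker keys in the checkpoint state dict, so loaders can infer the prediction mode without relying on metadata. A hedged sketch of the detection side (the loader logic here is illustrative, not ComfyUI's exact code):

```python
import torch

def detect_prediction_type(state_dict: dict) -> str:
    if "v_pred" in state_dict:
        return "v_prediction + zsnr" if "ztsnr" in state_dict else "v_prediction"
    return "epsilon"

sd = {"v_pred": torch.tensor([]), "ztsnr": torch.tensor([])}
print(detect_prediction_type(sd))  # v_prediction + zsnr
```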