diff --git a/README.md b/README.md
index 62800bb4f..0f39cfce2 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon,
 ## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)

 See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
-
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Image Models
@@ -99,6 +98,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith

 Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)

+## Release Process
+
+ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
+
+1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
+   - Releases a new stable version (e.g., v1.7.0)
+   - Serves as the foundation for the desktop release
+
+2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
+   - Builds a new release using the latest stable core version
+   - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
+
+3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
+   - Weekly frontend updates are merged into the core repository
+   - Features are frozen for the upcoming core release
+   - Development continues for the next release cycle
+
 ## Shortcuts

 | Keybind | Explanation |
@@ -149,8 +165,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you

 If you have trouble extracting it, right click the file -> properties -> unblock

-If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
-
 #### How do I share models between another UI and ComfyUI?

 See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
diff --git a/app/custom_node_manager.py b/app/custom_node_manager.py
index 42b0d75ba..27d85d9ce 100644
--- a/app/custom_node_manager.py
+++ b/app/custom_node_manager.py
@@ -93,16 +93,20 @@ class CustomNodeManager:

     def add_routes(self, routes, webapp, loadedModules):

+        example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
+
         @routes.get("/workflow_templates")
         async def get_workflow_templates(request):
             """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
-            files = [
-                file
-                for folder in folder_paths.get_folder_paths("custom_nodes")
-                for file in glob.glob(
-                    os.path.join(folder, "*/example_workflows/*.json")
-                )
-            ]
+
+            files = []
+
+            for folder in folder_paths.get_folder_paths("custom_nodes"):
+                for folder_name in example_workflow_folder_names:
+                    pattern = os.path.join(folder, f"*/{folder_name}/*.json")
+                    matched_files = glob.glob(pattern)
+                    files.extend(matched_files)
+
             workflow_templates_dict = (
                 {}
             )  # custom_nodes folder name -> example workflow names
@@ -118,15 +122,22 @@ class CustomNodeManager:

         # Serve workflow templates from custom nodes.
         for module_name, module_dir in loadedModules:
-            workflows_dir = os.path.join(module_dir, "example_workflows")
-            if os.path.exists(workflows_dir):
-                webapp.add_routes(
-                    [
-                        web.static(
-                            "/api/workflow_templates/" + module_name, workflows_dir
-                        )
-                    ]
-                )
+            for folder_name in example_workflow_folder_names:
+                workflows_dir = os.path.join(module_dir, folder_name)
+
+                if os.path.exists(workflows_dir):
+                    if folder_name != "example_workflows":
+                        logging.warning(
+                            "WARNING: Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
+                            folder_name, module_name)
+
+                    webapp.add_routes(
+                        [
+                            web.static(
+                                "/api/workflow_templates/" + module_name, workflows_dir
+                            )
+                        ]
+                    )

         @routes.get("/i18n")
         async def get_i18n(request):
diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 4ceeb3468..2ffc9c021 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -48,6 +48,7 @@ class IO(StrEnum):
     FACE_ANALYSIS = "FACE_ANALYSIS"
     BBOX = "BBOX"
     SEGS = "SEGS"
+    VIDEO = "VIDEO"

     ANY = "*"
     """Always matches any type, but at a price.
@@ -273,7 +274,7 @@ class ComfyNodeABC(ABC):

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
-    OUTPUT_IS_LIST: tuple[bool]
+    OUTPUT_IS_LIST: tuple[bool, ...]
     """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.

     Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
@@ -292,7 +293,7 @@ class ComfyNodeABC(ABC):

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
     """
-    RETURN_TYPES: tuple[IO]
+    RETURN_TYPES: tuple[IO, ...]
     """A tuple representing the outputs of this node.

     Usage::

         RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
     """
-    RETURN_NAMES: tuple[str]
+    RETURN_NAMES: tuple[str, ...]
     """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``

     Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
     """
-    OUTPUT_TOOLTIPS: tuple[str]
+    OUTPUT_TOOLTIPS: tuple[str, ...]
     """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
     FUNCTION: str
     """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 6388d3faf..77ef748e8 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1345,28 +1345,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
     return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)

 @torch.no_grad()
-def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
     """Gradient-estimation sampler.

     Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     old_d = None

+    uncond_denoised = None
+    def post_cfg_function(args):
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    if cfg_pp:
+        model_options = extra_args.get("model_options", {}).copy()
+        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        d = to_d(x, sigmas[i], denoised)
+        if cfg_pp:
+            d = to_d(x, sigmas[i], uncond_denoised)
+        else:
+            d = to_d(x, sigmas[i], denoised)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         dt = sigmas[i + 1] - sigmas[i]
         if i == 0:
             # Euler method
-            x = x + d * dt
+            if cfg_pp:
+                x = denoised + d * sigmas[i + 1]
+            else:
+                x = x + d * dt
         else:
             # Gradient estimation
-            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
-            x = x + d_bar * dt
+            if cfg_pp:
+                d_bar = (ge_gamma - 1) * (d - old_d)
+                x = denoised + d * sigmas[i + 1] + d_bar * dt
+            else:
+                d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+                x = x + d_bar * dt
         old_d = d
     return x

+@torch.no_grad()
+def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+    return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
+
 @torch.no_grad()
 def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
     """
diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py
index fcb5a9c51..0305747bf 100644
--- a/comfy/ldm/hidream/model.py
+++ b/comfy/ldm/hidream/model.py
@@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         encoder_hidden_states_llama3=None,
+        image_cond=None,
         control = None,
         transformer_options = {},
     ) -> torch.Tensor:
         bs, c, h, w = x.shape
+        if image_cond is not None:
+            x = torch.cat([x, image_cond], dim=-1)
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         timesteps = t
         pooled_embeds = y
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index b8eec3afb..66bee7480 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -631,6 +631,7 @@ class VaceWanModel(WanModel):
             if ii is not None:
                 c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                 x += c_skip * vace_strength
+        del c_skip

         # head
         x = self.head(x, e)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index b0c6a465b..d2aa4ce7a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1104,4 +1104,7 @@ class HiDream(BaseModel):
         conditioning_llama3 = kwargs.get("conditioning_llama3", None)
         if conditioning_llama3 is not None:
             out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
+        image_cond = kwargs.get("concat_latent_image", None)
+        if image_cond is not None:
+            out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
         return out
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 516b6e2f1..44aff3762 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -946,9 +946,9 @@ if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

-stream_counter = 0
+stream_counters = {}
 def get_offload_stream(device):
-    global stream_counter
+    stream_counter = stream_counters.get(device, 0)
     if NUM_STREAMS <= 1:
         return None

@@ -958,14 +958,16 @@ def get_offload_stream(device):
         stream_counter = (stream_counter + 1) % len(ss)
         if is_device_cuda(device):
             ss[stream_counter].wait_stream(torch.cuda.current_stream())
+        stream_counters[device] = stream_counter
         return s
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
-            ss.append(torch.cuda.Stream(device=device, priority=10))
+            ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
         stream_counter = (stream_counter + 1) % len(ss)
+        stream_counters[device] = stream_counter
         return s
     return None
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index b79af1e92..7e7291476 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end
+        self.zsnr = zsnr

         # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
-        if zsnr:
+        if self.zsnr:
             sigmas = rescale_zero_terminal_snr_sigmas(sigmas)

         self.set_sigmas(sigmas)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 27dfce45a..67ae09a25 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde",
                   "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde",
                   "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
+                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]

 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 5e55035cf..69bcee1f7 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -993,6 +993,10 @@ class WAN21_Vace(WAN21_T2V):
         "model_type": "vace",
     }

+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        self.memory_usage_factor = 1.2 * self.memory_usage_factor
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
         return out
diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py
new file mode 100644
index 000000000..66667946f
--- /dev/null
+++ b/comfy_api/input/__init__.py
@@ -0,0 +1,8 @@
+from .basic_types import ImageInput, AudioInput
+from .video_types import VideoInput
+
+__all__ = [
+    "ImageInput",
+    "AudioInput",
+    "VideoInput",
+]
diff --git a/comfy_api/input/basic_types.py b/comfy_api/input/basic_types.py
new file mode 100644
index 000000000..033fb7e27
--- /dev/null
+++ b/comfy_api/input/basic_types.py
@@ -0,0 +1,20 @@
+import torch
+from typing import TypedDict
+
+ImageInput = torch.Tensor
+"""
+An image in format [B, H, W, C] where B is the batch size, H and W are the image height and width, and C is the number of channels.
+"""
+
+class AudioInput(TypedDict):
+    """
+    TypedDict representing audio input.
+    """
+
+    waveform: torch.Tensor
+    """
+    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, and T is the number of samples.
+    """
+
+    sample_rate: int
+
diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py
new file mode 100644
index 000000000..0676e0e66
--- /dev/null
+++ b/comfy_api/input/video_types.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from typing import Optional
+from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+
+class VideoInput(ABC):
+    """
+    Abstract base class for video input types.
+    """
+
+    @abstractmethod
+    def get_components(self) -> VideoComponents:
+        """
+        Abstract method to get the video components (images, audio, and frame rate).
+
+        Returns:
+            VideoComponents containing images, audio, and frame rate
+        """
+        pass
+
+    @abstractmethod
+    def save_to(
+        self,
+        path: str,
+        format: VideoContainer = VideoContainer.AUTO,
+        codec: VideoCodec = VideoCodec.AUTO,
+        metadata: Optional[dict] = None
+    ):
+        """
+        Abstract method to save the video input to a file.
+        """
+        pass
+
+    # Provide a default implementation, but subclasses can provide optimized versions
+    # where possible.
+    def get_dimensions(self) -> tuple[int, int]:
+        """
+        Returns the dimensions of the video input.
+
+        Returns:
+            Tuple of (width, height)
+        """
+        components = self.get_components()
+        return components.images.shape[2], components.images.shape[1]
+
diff --git a/comfy_api/input_impl/__init__.py b/comfy_api/input_impl/__init__.py
new file mode 100644
index 000000000..02901b8b9
--- /dev/null
+++ b/comfy_api/input_impl/__init__.py
@@ -0,0 +1,7 @@
+from .video_types import VideoFromFile, VideoFromComponents
+
+__all__ = [
+    # Implementations
+    "VideoFromFile",
+    "VideoFromComponents",
+]
diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py
new file mode 100644
index 000000000..12e5783db
--- /dev/null
+++ b/comfy_api/input_impl/video_types.py
@@ -0,0 +1,224 @@
+from __future__ import annotations
+from av.container import InputContainer
+from av.subtitles.stream import SubtitleStream
+from fractions import Fraction
+from typing import Optional
+from comfy_api.input import AudioInput
+import av
+import io
+import json
+import numpy as np
+import torch
+from comfy_api.input import VideoInput
+from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+
+class VideoFromFile(VideoInput):
+    """
+    Class representing video input from a file.
+    """
+
+    def __init__(self, file: str | io.BytesIO):
+        """
+        Initialize the VideoFromFile object from either a path on disk or a BytesIO object
+        containing the file contents.
+        """
+        self.__file = file
+
+    def get_dimensions(self) -> tuple[int, int]:
+        """
+        Returns the dimensions of the video input.
+
+        Returns:
+            Tuple of (width, height)
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)  # Reset the BytesIO object to the beginning
+        with av.open(self.__file, mode='r') as container:
+            for stream in container.streams:
+                if stream.type == 'video':
+                    assert isinstance(stream, av.VideoStream)
+                    return stream.width, stream.height
+        raise ValueError(f"No video stream found in file '{self.__file}'")
+
+    def get_components_internal(self, container: InputContainer) -> VideoComponents:
+        # Get video frames
+        frames = []
+        for frame in container.decode(video=0):
+            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
+            img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
+            frames.append(img)
+
+        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
+
+        # Get frame rate
+        video_stream = next(s for s in container.streams if s.type == 'video')
+        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+
+        # Get audio if available
+        audio = None
+        try:
+            container.seek(0)  # Reset the container to the beginning
+            for stream in container.streams:
+                if stream.type != 'audio':
+                    continue
+                assert isinstance(stream, av.AudioStream)
+                audio_frames = []
+                for packet in container.demux(stream):
+                    for frame in packet.decode():
+                        assert isinstance(frame, av.AudioFrame)
+                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
+                if len(audio_frames) > 0:
+                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
+                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
+                    audio = AudioInput({
+                        "waveform": audio_tensor,
+                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
+                    })
+        except StopIteration:
+            pass  # No audio stream
+
+        metadata = container.metadata
+        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
+
+    def get_components(self) -> VideoComponents:
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)  # Reset the BytesIO object to the beginning
+        with av.open(self.__file, mode='r') as container:
+            return self.get_components_internal(container)
+        raise ValueError(f"No video stream found in file '{self.__file}'")
+
+    def save_to(
+        self,
+        path: str,
+        format: VideoContainer = VideoContainer.AUTO,
+        codec: VideoCodec = VideoCodec.AUTO,
+        metadata: Optional[dict] = None
+    ):
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)  # Reset the BytesIO object to the beginning
+        with av.open(self.__file, mode='r') as container:
+            container_format = container.format.name
+            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
+            reuse_streams = True
+            if format != VideoContainer.AUTO and format not in container_format.split(","):
+                reuse_streams = False
+            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
+                reuse_streams = False
+
+            if not reuse_streams:
+                components = self.get_components_internal(container)
+                video = VideoFromComponents(components)
+                return video.save_to(
+                    path,
+                    format=format,
+                    codec=codec,
+                    metadata=metadata
+                )
+
+            streams = container.streams
+            with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container:
+                # Copy over the original metadata
+                for key, value in container.metadata.items():
+                    if metadata is None or key not in metadata:
+                        output_container.metadata[key] = value
+
+                # Add our new metadata
+                if metadata is not None:
+                    for key, value in metadata.items():
+                        if isinstance(value, str):
+                            output_container.metadata[key] = value
+                        else:
+                            output_container.metadata[key] = json.dumps(value)
+
+                # Add streams to the new container
+                stream_map = {}
+                for stream in streams:
+                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
+                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
+                        stream_map[stream] = out_stream
+
+                # Write packets to the new container
+                for packet in container.demux():
+                    if packet.stream in stream_map and packet.dts is not None:
+                        packet.stream = stream_map[packet.stream]
+                        output_container.mux(packet)
+
+class VideoFromComponents(VideoInput):
+    """
+    Class representing video input from tensors.
+    """
+
+    def __init__(self, components: VideoComponents):
+        self.__components = components
+
+    def get_components(self) -> VideoComponents:
+        return VideoComponents(
+            images=self.__components.images,
+            audio=self.__components.audio,
+            frame_rate=self.__components.frame_rate
+        )
+
+    def save_to(
+        self,
+        path: str,
+        format: VideoContainer = VideoContainer.AUTO,
+        codec: VideoCodec = VideoCodec.AUTO,
+        metadata: Optional[dict] = None
+    ):
+        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
+            raise ValueError("Only MP4 format is supported for now")
+        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
+            raise ValueError("Only H264 codec is supported for now")
+        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+            # Add metadata before writing any streams
+            if metadata is not None:
+                for key, value in metadata.items():
+                    output.metadata[key] = json.dumps(value)
+
+            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
+            # Create a video stream
+            video_stream = output.add_stream('h264', rate=frame_rate)
+            video_stream.width = self.__components.images.shape[2]
+            video_stream.height = self.__components.images.shape[1]
+            video_stream.pix_fmt = 'yuv420p'
+
+            # Create an audio stream
+            audio_sample_rate = 1
+            audio_stream: Optional[av.AudioStream] = None
+            if self.__components.audio:
+                audio_sample_rate = int(self.__components.audio['sample_rate'])
+                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+                audio_stream.sample_rate = audio_sample_rate
+                audio_stream.format = 'fltp'
+
+            # Encode video
+            for i, frame in enumerate(self.__components.images):
+                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
+                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
+                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
+                packet = video_stream.encode(frame)
+                output.mux(packet)
+
+            # Flush video
+            packet = video_stream.encode(None)
+            output.mux(packet)
+
+            if audio_stream and self.__components.audio:
+                # Encode audio
+                samples_per_frame = int(audio_sample_rate / frame_rate)
+                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
+                for i in range(num_frames):
+                    start = i * samples_per_frame
+                    end = start + samples_per_frame
+                    # TODO(Feature) - Add support for stereo audio
+                    chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy()
+                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
+                    audio_frame.sample_rate = audio_sample_rate
+                    audio_frame.pts = i * samples_per_frame
+                    for packet in audio_stream.encode(audio_frame):
+                        output.mux(packet)
+
+                # Flush audio
+                for packet in audio_stream.encode(None):
+                    output.mux(packet)
+
diff --git a/comfy_api/util/__init__.py b/comfy_api/util/__init__.py
new file mode 100644
index 000000000..9019c46db
--- /dev/null
+++ b/comfy_api/util/__init__.py
@@ -0,0 +1,8 @@
+from .video_types import VideoContainer, VideoCodec, VideoComponents
+
+__all__ = [
+    # Utility Types
+    "VideoContainer",
+    "VideoCodec",
+    "VideoComponents",
+]
diff --git a/comfy_api/util/video_types.py b/comfy_api/util/video_types.py
new file mode 100644
index 000000000..d09663db9
--- /dev/null
+++ b/comfy_api/util/video_types.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from enum import Enum
+from fractions import Fraction
+from typing import Optional
+from comfy_api.input import ImageInput, AudioInput
+
+class VideoCodec(str, Enum):
+    AUTO = "auto"
+    H264 = "h264"
+
+    @classmethod
+    def as_input(cls) -> list[str]:
+        """
+        Returns a list of codec names that can be used as node input.
+        """
+        return [member.value for member in cls]
+
+class VideoContainer(str, Enum):
+    AUTO = "auto"
+    MP4 = "mp4"
+
+    @classmethod
+    def as_input(cls) -> list[str]:
+        """
+        Returns a list of container names that can be used as node input.
+        """
+        return [member.value for member in cls]
+
+    @classmethod
+    def get_extension(cls, value) -> str:
+        """
+        Returns the file extension for the container.
+        """
+        if isinstance(value, str):
+            value = cls(value)
+        if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
+            return "mp4"
+        return ""
+
+@dataclass
+class VideoComponents:
+    """
+    Dataclass representing the components of a video.
+    """
+
+    images: ImageInput
+    frame_rate: Fraction
+    audio: Optional[AudioInput] = None
+    metadata: Optional[dict] = None
+
diff --git a/comfy_api_nodes/nodes_api.py b/comfy_api_nodes/nodes_api.py
index 4105ba7e1..a977bb9b7 100644
--- a/comfy_api_nodes/nodes_api.py
+++ b/comfy_api_nodes/nodes_api.py
@@ -1,21 +1,22 @@
+import base64
 import io
+import math
 from inspect import cleandoc
-from comfy.utils import common_upscale
+import numpy as np
+import requests
+import torch
+from PIL import Image
+
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
+from comfy.utils import common_upscale
 from comfy_api_nodes.apis import (
-    OpenAIImageGenerationRequest,
     OpenAIImageEditRequest,
-    OpenAIImageGenerationResponse
+    OpenAIImageGenerationRequest,
+    OpenAIImageGenerationResponse,
 )
 from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation
-import numpy as np
-from PIL import Image
-import requests
-import torch
-import math
-import base64

 def downscale_input(image):
     samples = image.movedim(-1,1)
@@ -331,6 +332,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
                     "default": None,
                     "tooltip": "Optional mask for inpainting (white areas will be replaced)",
                 }),
+                "moderation": (IO.COMBO, {
+                    "options": ["low","auto"],
+                    "default": "low",
+                    "tooltip": "Moderation level",
+                }),
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG"
@@ -343,7 +349,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None):
+    def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"):
         model = "gpt-image-1"
         path = "/proxy/openai/images/generations"
         request_class = OpenAIImageGenerationRequest
@@ -415,6 +421,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
                 n=n,
                 seed=seed,
                 size=size,
+                moderation=moderation,
             ),
             files=files if files else None,
             auth_token=auth_token
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index ff3fe5cdc..e6dc122ca 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -38,6 +38,7 @@ class LTXVImgToVideo:
                              "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
                              }}

     RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
@@ -46,7 +47,7 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"

-    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
@@ -59,7 +60,7 @@ class LTXVImgToVideo:
             dtype=torch.float32,
             device=latent.device,
         )
-        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
+        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength

         return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )

@@ -152,6 +153,15 @@ class LTXVAddGuide:
         return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

     def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
+        _, latent_idx = self.get_latent_index(
+            cond=positive,
+            latent_length=latent_image.shape[2],
+            guide_length=guiding_latent.shape[2],
+            frame_idx=frame_idx,
+            scale_factors=scale_factors,
+        )
+        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
+
         positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
         negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index ccf601158..78d284889 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
             metadata["modelspec.predict_key"] = "epsilon"
         elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
             metadata["modelspec.predict_key"] = "v"
+            extra_keys["v_pred"] = torch.tensor([])
+            if getattr(model_sampling, "zsnr", False):
+                extra_keys["ztsnr"] = torch.tensor([])

     if not args.disable_metadata:
         metadata["prompt"] = prompt_info
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index a9e244ebe..61f7171b2 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -5,9 +5,13 @@ import av
 import torch
 import folder_paths
 import json
+from typing import Optional, Literal
 from fractions import Fraction
-from comfy.comfy_types import FileLocator
-
+from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
+from comfy_api.input import ImageInput, AudioInput, VideoInput
+from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+from comfy_api.input_impl import VideoFromFile, VideoFromComponents
+from comfy.cli_args import args

 class SaveWEBM:
     def __init__(self):
@@ -75,7 +79,163 @@ class SaveWEBM:

         return {"ui": {"images": results, "animated": (True,)}}  # TODO: frontend side
+class SaveVideo(ComfyNodeABC):
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+        self.type: Literal["output"] = "output"
+        self.prefix_append = ""
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "video": (IO.VIDEO, {"tooltip": "The video to save."}),
+                "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
+                "format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
+                "codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
+            },
+            "hidden": {
+                "prompt": "PROMPT",
+                "extra_pnginfo": "EXTRA_PNGINFO"
+            },
+        }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_video"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "image/video"
+    DESCRIPTION = "Saves the input video to your ComfyUI output directory."
+
+    def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
+        filename_prefix += self.prefix_append
+        width, height = video.get_dimensions()
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
+            filename_prefix,
+            self.output_dir,
+            width,
+            height
+        )
+        results: list[FileLocator] = list()
+        saved_metadata = None
+        if not args.disable_metadata:
+            metadata = {}
+            if extra_pnginfo is not None:
+                metadata.update(extra_pnginfo)
+            if prompt is not None:
+                metadata["prompt"] = prompt
+            if len(metadata) > 0:
+                saved_metadata = metadata
+        file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
+        video.save_to(
+            os.path.join(full_output_folder, file),
+            format=format,
+            codec=codec,
+            metadata=saved_metadata
+        )
+
+        results.append({
+            "filename": file,
+            "subfolder": subfolder,
+            "type": self.type
+        })
+        counter += 1
+
+        return { "ui": { "images": results, "animated": (True,) } }
+
+class CreateVideo(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
+                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
+            },
+            "optional": {
+                "audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
+            }
+        }
+
+    RETURN_TYPES = (IO.VIDEO,)
+    FUNCTION = "create_video"
+
+    CATEGORY = "image/video"
+    DESCRIPTION = "Create a video from images."
+
+    def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
+        return (VideoFromComponents(
+            VideoComponents(
+                images=images,
+                audio=audio,
+                frame_rate=Fraction(fps),
+            )
+        ),)
+
+class GetVideoComponents(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
+            }
+        }
+    RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
+    RETURN_NAMES = ("images", "audio", "fps")
+    FUNCTION = "get_components"
+
+    CATEGORY = "image/video"
+    DESCRIPTION = "Extracts all components from a video: frames, audio, and frame rate."
+
+    def get_components(self, video: VideoInput):
+        components = video.get_components()
+
+        return (components.images, components.audio, float(components.frame_rate))
+
+class LoadVideo(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls):
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
+        files = folder_paths.filter_files_content_types(files, ["video"])
+        return {"required":
+                    {"file": (sorted(files), {"video_upload": True})},
+                }
+
+    CATEGORY = "image/video"
+
+    RETURN_TYPES = (IO.VIDEO,)
+    FUNCTION = "load_video"
+    def load_video(self, file):
+        video_path = folder_paths.get_annotated_filepath(file)
+        return (VideoFromFile(video_path),)
+
+    @classmethod
+    def IS_CHANGED(cls, file):
+        video_path = folder_paths.get_annotated_filepath(file)
+        mod_time = os.path.getmtime(video_path)
+        # Instead of hashing the file, we can just use the modification time to avoid
+        # rehashing large files.
+        return mod_time
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, file):
+        if not folder_paths.exists_annotated_filepath(file):
+            return "Invalid video file: {}".format(file)
+
+        return True

 NODE_CLASS_MAPPINGS = {
     "SaveWEBM": SaveWEBM,
+    "SaveVideo": SaveVideo,
+    "CreateVideo": CreateVideo,
+    "GetVideoComponents": GetVideoComponents,
+    "LoadVideo": LoadVideo,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "SaveVideo": "Save Video",
+    "CreateVideo": "Create Video",
+    "GetVideoComponents": "Get Video Components",
+    "LoadVideo": "Load Video",
 }
diff --git a/folder_paths.py b/folder_paths.py
index 9a525e5a1..f0b3fd103 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -4,7 +4,7 @@ import os
 import time
 import mimetypes
 import logging
-from typing import Literal
+from typing import Literal, List
 from collections.abc import Collection

 from comfy.cli_args import args
@@ -141,7 +141,7 @@ def get_directory_by_type(type_name: str) -> str | None:
         return get_input_directory()
     return None

-def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
+def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]:
     """
     Example:
         files = os.listdir(folder_paths.get_input_directory())
diff --git a/requirements.txt b/requirements.txt
index 95532667a..9c64996d3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,5 +23,5 @@ psutil
 kornia>=0.7.1
 spandrel
 soundfile
-av>=14.1.0
+av>=14.2.0
 pydantic~=2.0
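For reviewers of the `comfy/k_diffusion/sampling.py` change: the new `cfg_pp` branch takes the ODE derivative from the post-CFG unconditional prediction and re-anchors each step at `denoised` rather than stepping from `x`. Below is a minimal sketch of a single step of that update rule, lifted out of the sampler loop; `ge_step` and this standalone `to_d` are illustrative stand-ins, not ComfyUI API:

```python
def to_d(x, sigma, denoised):
    # k-diffusion's conversion of a denoised prediction into an ODE derivative
    return (x - denoised) / sigma

def ge_step(x, denoised, uncond_denoised, old_d, sigma, sigma_next, ge_gamma=2.0, cfg_pp=False):
    """One gradient-estimation step, mirroring the update rule in the diff above."""
    d = to_d(x, sigma, uncond_denoised if cfg_pp else denoised)
    dt = sigma_next - sigma
    if old_d is None:
        # First step: plain Euler (the CFG++ variant restarts from the denoised sample)
        x = denoised + d * sigma_next if cfg_pp else x + d * dt
    elif cfg_pp:
        d_bar = (ge_gamma - 1) * (d - old_d)
        x = denoised + d * sigma_next + d_bar * dt
    else:
        d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
        x = x + d_bar * dt
    return x, d  # the caller carries d forward as old_d for the next step
```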
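The `comfy/model_management.py` hunk fixes cross-device interference: the old global `stream_counter` advanced on every call regardless of device, so offload-stream round-robin on one GPU could skip streams on another. The core of the new pattern, reduced to a few lines with illustrative names (not the actual ComfyUI functions):

```python
streams = {}          # device -> pool of offload streams
stream_counters = {}  # device -> index of the next stream to hand out

def next_stream(device, make_stream, num_streams=2):
    """Round-robin over a per-device stream pool; only this device's counter advances."""
    pool = streams.setdefault(device, [make_stream(device) for _ in range(num_streams)])
    i = stream_counters.get(device, 0)
    stream_counters[device] = (i + 1) % len(pool)
    return pool[i]
```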
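The new `comfy_api` input types pin down tensor layouts: `ImageInput` is `[B, H, W, C]` and `AudioInput["waveform"]` is `[B, C, T]`. A small sketch of assembling a `VideoComponents` by hand, assuming the modules added in this diff are importable (the shapes are arbitrary examples):

```python
import torch
from fractions import Fraction
from comfy_api.input import AudioInput
from comfy_api.util import VideoComponents

images = torch.zeros(24, 256, 256, 3)  # [B, H, W, C]: 24 RGB frames at 256x256
audio: AudioInput = {
    "waveform": torch.zeros(1, 1, 48000),  # [B, C, T]: one second of mono audio
    "sample_rate": 48000,
}
clip = VideoComponents(images=images, audio=audio, frame_rate=Fraction(24))
```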
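Putting the implementation classes together, a script or custom node could round-trip a clip roughly like this (file paths are hypothetical; per the `save_to` logic above, `VideoFromFile` remuxes the original streams when container and codec already match, and re-encodes otherwise):

```python
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
from comfy_api.util import VideoContainer, VideoCodec

video = VideoFromFile("input/clip.mp4")  # also accepts an io.BytesIO of the file contents
width, height = video.get_dimensions()   # probes the container without decoding frames

components = video.get_components()      # decoded frames, audio (if present), frame rate

# Force a re-encode through tensors; VideoFromComponents currently writes mp4/h264 only.
VideoFromComponents(components).save_to(
    "output/clip_reencoded.mp4",
    format=VideoContainer.MP4,
    codec=VideoCodec.H264,
    metadata={"source": "clip.mp4"},  # values are stored as JSON metadata tags
)
```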
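Finally, the `nodes_model_merging.py` change makes saved checkpoints self-describing: v-prediction models now carry an empty `v_pred` tensor as a marker key, plus `ztsnr` when the model sampling uses zero-terminal-SNR sigmas. Assuming the markers land in the saved state dict alongside the weights, a downstream loader might detect them like this (hypothetical path; a sketch, not ComfyUI's loading code):

```python
from safetensors.torch import load_file

state_dict = load_file("checkpoints/my_merge.safetensors")  # hypothetical checkpoint
is_v_prediction = "v_pred" in state_dict  # empty tensor used purely as a flag
is_ztsnr = "ztsnr" in state_dict
```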