Merge branch 'master' into ben/release-webhook-dispatch-desktop

Jedrzej Kosinski, 2026-02-10 21:57:16 -06:00 (committed by GitHub)
commit 45df7c8fa5
12 changed files with 696 additions and 128 deletions

View File

@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
         super().__init__(*args, **kwargs)
         self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

-    def preprocess_text_embeds(self, text_embeds, text_ids):
+    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
         if text_ids is not None:
-            return self.llm_adapter(text_embeds, text_ids)
+            out = self.llm_adapter(text_embeds, text_ids)
+            if t5xxl_weights is not None:
+                out = out * t5xxl_weights
+            if out.shape[1] < 512:
+                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+            return out
         else:
             return text_embeds
+
+    def forward(self, x, timesteps, context, **kwargs):
+        t5xxl_ids = kwargs.pop("t5xxl_ids", None)
+        if t5xxl_ids is not None:
+            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+        return super().forward(x, timesteps, context, **kwargs)
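For orientation, a minimal sketch (not from the commit; shapes are illustrative assumptions) of what the pad-to-512 step above does to the adapter output:

    import torch

    out = torch.randn(1, 300, 4096)  # hypothetical adapter output (batch, tokens, dim)
    if out.shape[1] < 512:
        # F.pad pairs run from the last dim backwards: (dim_left, dim_right, token_left, token_right)
        out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
    assert out.shape == (1, 512, 4096)  # zero-padded on the token axis only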

View File

@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     return out.to(dtype=torch.float32, device=pos.device)

+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+    x_out = freqs_cis[..., 0] * x_[..., 0]
+    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+    return x_out.reshape(*x.shape).type_as(x)
+
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
 try:
     import comfy.quant_ops
-    apply_rope = comfy.quant_ops.ck.apply_rope
-    apply_rope1 = comfy.quant_ops.ck.apply_rope1
+    q_apply_rope = comfy.quant_ops.ck.apply_rope
+    q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+
+    def apply_rope(xq, xk, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope(xq, xk, freqs_cis)
+        else:
+            return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
+    def apply_rope1(x, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope1(x, freqs_cis)
+        else:
+            return q_apply_rope1(x, freqs_cis)
 except:
     logging.warning("No comfy kitchen, using old apply_rope functions.")
-    def apply_rope1(x: Tensor, freqs_cis: Tensor):
-        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-        x_out = freqs_cis[..., 0] * x_[..., 0]
-        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-        return x_out.reshape(*x.shape).type_as(x)
-
-    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+    apply_rope = _apply_rope
+    apply_rope1 = _apply_rope1
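As a sanity check on what `_apply_rope1` computes, a hedged standalone sketch: the assumption (matching flux-style RoPE) is that `freqs_cis[..., r, c]` holds a 2x2 rotation matrix per feature pair, and the last dim of `x` is viewed as pairs:

    import torch

    def apply_rope1_ref(x, freqs_cis):
        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
        x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        return x_out.reshape(*x.shape).type_as(x)

    theta = torch.rand(1, 1, 8)  # hypothetical angles, one per feature pair
    rot = torch.stack([
        torch.stack([torch.cos(theta), -torch.sin(theta)], dim=-1),  # row 0
        torch.stack([torch.sin(theta),  torch.cos(theta)], dim=-1),  # row 1
    ], dim=-2)                    # (1, 1, 8, 2, 2)
    x = torch.randn(1, 1, 16)
    y = apply_rope1_ref(x, rot)   # each (a, b) pair rotated by its angle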

View File

@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
         device = kwargs["device"]

         if cross_attn is not None:
             if t5xxl_ids is not None:
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
                 if t5xxl_weights is not None:
-                    cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+                    t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+                t5xxl_ids = t5xxl_ids.unsqueeze(0)
+                if torch.is_inference_mode_enabled():  # if not, we are training
+                    cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
+                else:
+                    out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
+                    out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
-            if cross_attn.shape[1] < 512:
-                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

View File

@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU

 total_vram = 0

+# Training Related State
+in_training = False
+
+
 def get_supported_float8_types():
     float8_types = []
     try:
View File

@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required

-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)

-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
-    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+    if force_offload:  # In training + offload enabled, we want to force prepare sampling to trigger partial load
+        memory_required = 1e20
+        minimum_memory_required = None
+    else:
+        memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+        memory_required += inference_memory
+        minimum_memory_required += inference_memory
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
     real_model = model.model
     return real_model, conds, models
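Later in this commit, TrainGuider exercises the new parameter; a hedged sketch of that call path (the surrounding setup is assumed):

    # force_offload=True substitutes memory_required=1e20 so that
    # load_models_gpu performs a partial load instead of sizing from estimate_memory
    real_model, conds, models = comfy.sampler_helpers.prepare_sampling(
        model_patcher, noise.shape, conds,
        force_full_load=False, force_offload=True,
    )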

View File

@@ -21,6 +21,7 @@ from typing import Optional, Union

 import torch
 import torch.nn as nn
+import comfy.model_management

 from .base import WeightAdapterBase, WeightAdapterTrainBase
 from comfy.patcher_extension import PatcherInjection
@@ -181,18 +182,21 @@ class BypassForwardHook:
             )
             return  # Already injected

-        # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
-        device = None
+        # Move adapter weights to compute device (GPU)
+        # Use get_torch_device() instead of module.weight.device because
+        # with offloading, module weights may be on CPU while compute happens on GPU
+        device = comfy.model_management.get_torch_device()
+
+        # Get dtype from module weight if available
         dtype = None
         if hasattr(self.module, "weight") and self.module.weight is not None:
-            device = self.module.weight.device
             dtype = self.module.weight.dtype
-        elif hasattr(self.module, "W_q"):  # Quantized layers might use different attr
-            device = self.module.W_q.device
-            dtype = self.module.W_q.dtype
-        if device is not None:
-            self._move_adapter_weights_to_device(device, dtype)
+
+        # Only use dtype if it's a standard float type, not quantized
+        if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            dtype = None
+        self._move_adapter_weights_to_device(device, dtype)

         self.original_forward = self.module.forward
         self.module.forward = self._bypass_forward

View File

@@ -34,6 +34,21 @@ class VideoInput(ABC):
         """
         pass

+    @abstractmethod
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = False,
+    ) -> VideoInput | None:
+        """
+        Create a new VideoInput which is trimmed to have the corresponding start_time and duration.
+
+        Returns:
+            A new VideoInput, or None if the result would have negative duration
+        """
+        pass
+
     def get_stream_source(self) -> Union[str, io.BytesIO]:
         """
         Get a streamable source for the video. This allows processing without
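A hedged usage sketch of the new interface (the file name is hypothetical; VideoFromFile implements as_trimmed later in this commit):

    video = VideoFromFile("input.mp4")
    clip = video.as_trimmed(start_time=2.0, duration=3.0, strict_duration=True)
    if clip is None:
        raise ValueError("source too short for the requested window")
    print(clip.get_duration())  # at most 3.0 seconds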

View File

@@ -6,6 +6,7 @@ from typing import Optional

 from .._input import AudioInput, VideoInput
 import av
 import io
+import itertools
 import json
 import numpy as np
 import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
     formats = container_format.split(",")
     return formats[0]

-
 def get_open_write_kwargs(
     dest: str | io.BytesIO, container_format: str, to_format: str | None
 ) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
     Class representing video input from a file.
     """

-    def __init__(self, file: str | io.BytesIO):
+    def __init__(self, file: str | io.BytesIO, *, start_time: float = 0, duration: float = 0):
         """
         Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
         containing the file contents.
         """
         self.__file = file
+        self.__start_time = start_time
+        self.__duration = duration

     def get_stream_source(self) -> str | io.BytesIO:
         """
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
         Returns:
             Duration in seconds
         """
+        raw_duration = self._get_raw_duration()
+        if self.__start_time < 0:
+            duration_from_start = min(raw_duration, -self.__start_time)
+        else:
+            duration_from_start = raw_duration - self.__start_time
+        if self.__duration:
+            return min(self.__duration, duration_from_start)
+        return duration_from_start
+
+    def _get_raw_duration(self) -> float:
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0)
         with av.open(self.__file, mode="r") as container:
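To make the sign convention concrete, assume _get_raw_duration() returns 10.0: start_time=-4 means "the last 4 seconds", so duration_from_start = min(10, 4) = 4; start_time=2 gives 10 - 2 = 8; and with __duration = 3 either case returns min(3, ·) = 3.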
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
             if video_stream and video_stream.average_rate:
                 frame_count = 0
                 container.seek(0)
-                for packet in container.demux(video_stream):
-                    for _ in packet.decode():
-                        frame_count += 1
+                frame_iterator = (
+                    container.decode(video_stream)
+                    if video_stream.codec.capabilities & 0x100
+                    else container.demux(video_stream)
+                )
+                for packet in frame_iterator:
+                    frame_count += 1
                 if frame_count > 0:
                     return float(frame_count / video_stream.average_rate)
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
         with av.open(self.__file, mode="r") as container:
             video_stream = self._get_first_video_stream(container)

-            # 1. Prefer the frames field if available
-            if video_stream.frames and video_stream.frames > 0:
+            # 1. Prefer the frames field if available and usable
+            if (
+                video_stream.frames
+                and video_stream.frames > 0
+                and not self.__start_time
+                and not self.__duration
+            ):
                 return int(video_stream.frames)

             # 2. Try to estimate from duration and average_rate using only metadata
-            if container.duration is not None and video_stream.average_rate:
-                duration_seconds = float(container.duration / av.time_base)
-                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
-                if estimated_frames > 0:
-                    return estimated_frames
-
             if (
                 getattr(video_stream, "duration", None) is not None
                 and getattr(video_stream, "time_base", None) is not None
                 and video_stream.average_rate
             ):
-                duration_seconds = float(video_stream.duration * video_stream.time_base)
+                raw_duration = float(video_stream.duration * video_stream.time_base)
+                if self.__start_time < 0:
+                    duration_from_start = min(raw_duration, -self.__start_time)
+                else:
+                    duration_from_start = raw_duration - self.__start_time
+                duration_seconds = min(self.__duration, duration_from_start)
                 estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
                 if estimated_frames > 0:
                     return estimated_frames

             # 3. Last resort: decode frames and count them (streaming)
-            frame_count = 0
-            container.seek(0)
-            for packet in container.demux(video_stream):
-                for _ in packet.decode():
-                    frame_count += 1
-
-            if frame_count == 0:
-                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+            if self.__start_time < 0:
+                start_time = max(self._get_raw_duration() + self.__start_time, 0)
+            else:
+                start_time = self.__start_time
+            frame_count = 1
+            start_pts = int(start_time / video_stream.time_base)
+            end_pts = int((start_time + self.__duration) / video_stream.time_base)
+            container.seek(start_pts, stream=video_stream)
+            frame_iterator = (
+                container.decode(video_stream)
+                if video_stream.codec.capabilities & 0x100
+                else container.demux(video_stream)
+            )
+            for frame in frame_iterator:
+                if frame.pts >= start_pts:
+                    break
+            else:
+                raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+            for frame in frame_iterator:
+                if frame.pts >= end_pts:
+                    break
+                frame_count += 1

             return frame_count

     def get_frame_rate(self) -> Fraction:
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
             return container.format.name

     def get_components_internal(self, container: InputContainer) -> VideoComponents:
+        video_stream = self._get_first_video_stream(container)
+        if self.__start_time < 0:
+            start_time = max(self._get_raw_duration() + self.__start_time, 0)
+        else:
+            start_time = self.__start_time
         # Get video frames
         frames = []
-        for frame in container.decode(video=0):
+        start_pts = int(start_time / video_stream.time_base)
+        end_pts = int((start_time + self.__duration) / video_stream.time_base)
+        container.seek(start_pts, stream=video_stream)
+        for frame in container.decode(video_stream):
+            if frame.pts < start_pts:
+                continue
+            if self.__duration and frame.pts >= end_pts:
+                break
             img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
             img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
             frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
         images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

         # Get frame rate
-        video_stream = next(s for s in container.streams if s.type == 'video')
-        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+        frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)

         # Get audio if available
         audio = None
-        try:
-            container.seek(0)  # Reset the container to the beginning
-            for stream in container.streams:
-                if stream.type != 'audio':
-                    continue
-                assert isinstance(stream, av.AudioStream)
-                audio_frames = []
-                for packet in container.demux(stream):
-                    for frame in packet.decode():
-                        assert isinstance(frame, av.AudioFrame)
-                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
-                if len(audio_frames) > 0:
-                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
-                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
-                    audio = AudioInput({
-                        "waveform": audio_tensor,
-                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
-                    })
-        except StopIteration:
-            pass  # No audio stream
+        container.seek(start_pts, stream=video_stream)
+        # Use last stream for consistency
+        if len(container.streams.audio):
+            audio_stream = container.streams.audio[-1]
+            audio_frames = []
+            resample = av.audio.resampler.AudioResampler(format='fltp').resample
+            frames = itertools.chain.from_iterable(
+                map(resample, container.decode(audio_stream))
+            )
+
+            has_first_frame = False
+            for frame in frames:
+                offset_seconds = start_time - frame.pts * audio_stream.time_base
+                to_skip = int(offset_seconds * audio_stream.sample_rate)
+                if to_skip < frame.samples:
+                    has_first_frame = True
+                    break
+            if has_first_frame:
+                audio_frames.append(frame.to_ndarray()[..., to_skip:])
+
+                for frame in frames:
+                    if frame.time > start_time + self.__duration:
+                        break
+                    audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
+
+            if len(audio_frames) > 0:
+                audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
+                if self.__duration:
+                    audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+                audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
+                audio = AudioInput({
+                    "waveform": audio_tensor,
+                    "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+                })

         metadata = container.metadata
         return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
         path: str | io.BytesIO,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
+        metadata: Optional[dict] = None,
     ):
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0)  # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
             reuse_streams = False
         if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
             reuse_streams = False
+        if self.__start_time or self.__duration:
+            reuse_streams = False

         if not reuse_streams:
             components = self.get_components_internal(container)
             video = VideoFromComponents(components)
             return video.save_to(
-                path,
-                format=format,
-                codec=codec,
-                metadata=metadata
+                path, format=format, codec=codec, metadata=metadata
             )

         streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
             output_container.mux(packet)

     def _get_first_video_stream(self, container: InputContainer):
-        video_stream = next((s for s in container.streams if s.type == "video"), None)
-        if video_stream is None:
-            raise ValueError(f"No video stream found in file '{self.__file}'")
-        return video_stream
+        if len(container.streams.video):
+            return container.streams.video[0]
+        raise ValueError(f"No video stream found in file '{self.__file}'")
+
+    def as_trimmed(
+        self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
+    ) -> VideoInput | None:
+        trimmed = VideoFromFile(
+            self.get_stream_source(),
+            start_time=start_time + self.__start_time,
+            duration=duration,
+        )
+        if trimmed.get_duration() < duration and strict_duration:
+            return None
+        return trimmed


 class VideoFromComponents(VideoInput):
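Note the offset accumulation (start_time + self.__start_time): trimming an already-trimmed file composes instead of re-anchoring at zero. A hedged sketch with a hypothetical path:

    clip = VideoFromFile("movie.mp4", start_time=10, duration=60)
    sub = clip.as_trimmed(start_time=5, duration=10)  # reads source seconds 15-25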
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
         return VideoComponents(
             images=self.__components.images,
             audio=self.__components.audio,
-            frame_rate=self.__components.frame_rate
+            frame_rate=self.__components.frame_rate,
         )

     def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
         path: str,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
+        metadata: Optional[dict] = None,
     ):
         if format != VideoContainer.AUTO and format != VideoContainer.MP4:
             raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
         audio_stream: Optional[av.AudioStream] = None
         if self.__components.audio:
             audio_sample_rate = int(self.__components.audio['sample_rate'])
-            audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+            waveform = self.__components.audio['waveform']
+            waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+            layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+            audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)

         # Encode video
         for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
                 output.mux(packet)

             if audio_stream and self.__components.audio:
-                waveform = self.__components.audio['waveform']
-                waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
-                frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+                frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
                 frame.sample_rate = audio_sample_rate
                 frame.pts = 0
                 output.mux(audio_stream.encode(frame))

                 # Flush encoder
                 output.mux(audio_stream.encode(None))
+
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = True,
+    ) -> VideoInput | None:
+        if self.get_duration() < start_time + duration:
+            return None
+        # TODO: Consider tracking duration and trimming at time of save?
+        return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)

View File

@@ -20,10 +20,60 @@ class JobStatus:

 # Media types that can be previewed in the frontend
-PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
+PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})

 # 3D file extensions for preview fallback (no dedicated media_type exists)
-THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
+THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
+
+
+def has_3d_extension(filename: str) -> bool:
+    lower = filename.lower()
+    return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
+
+
+def normalize_output_item(item):
+    """Normalize a single output list item for the jobs API.
+
+    Returns the normalized item, or None to exclude it.
+    String items with 3D extensions become {filename, type, subfolder} dicts.
+    """
+    if item is None:
+        return None
+    if isinstance(item, str):
+        if has_3d_extension(item):
+            return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+        return None
+    if isinstance(item, dict):
+        return item
+    return None
+
+
+def normalize_outputs(outputs: dict) -> dict:
+    """Normalize raw node outputs for the jobs API.
+
+    Transforms string 3D filenames into file output dicts and removes
+    None items. All other items (non-3D strings, dicts, etc.) are
+    preserved as-is.
+    """
+    normalized = {}
+    for node_id, node_outputs in outputs.items():
+        if not isinstance(node_outputs, dict):
+            normalized[node_id] = node_outputs
+            continue
+        normalized_node = {}
+        for media_type, items in node_outputs.items():
+            if media_type == 'animated' or not isinstance(items, list):
+                normalized_node[media_type] = items
+                continue
+            normalized_items = []
+            for item in items:
+                if item is None:
+                    continue
+                norm = normalize_output_item(item)
+                normalized_items.append(norm if norm is not None else item)
+            normalized_node[media_type] = normalized_items
+        normalized[node_id] = normalized_node
+    return normalized
+
+
 def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
     Maintains backwards compatibility with existing logic.

     Priority:
-    1. media_type is 'images', 'video', or 'audio'
+    1. media_type is 'images', 'video', 'audio', or '3d'
     2. format field starts with 'video/' or 'audio/'
-    3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
+    3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
     """
     if media_type in PREVIEWABLE_MEDIA_TYPES:
         return True
@@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
         })

     if include_outputs:
-        job['outputs'] = outputs
+        job['outputs'] = normalize_outputs(outputs)
         job['execution_status'] = status_info
         job['workflow'] = {
             'prompt': prompt,
@@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
             continue

         for item in items:
-            count += 1
-            if not isinstance(item, dict):
+            normalized = normalize_output_item(item)
+            if normalized is None:
                 continue
-            if preview_output is None and is_previewable(media_type, item):
+            count += 1
+            if preview_output is not None:
+                continue
+            if isinstance(normalized, dict) and is_previewable(media_type, normalized):
                 enriched = {
-                    **item,
+                    **normalized,
                     'nodeId': node_id,
-                    'mediaType': media_type
                 }
-                if item.get('type') == 'output':
+                if 'mediaType' not in normalized:
+                    enriched['mediaType'] = media_type
+                if normalized.get('type') == 'output':
                     preview_output = enriched
                 elif fallback_preview is None:
                     fallback_preview = enriched

View File

@@ -4,6 +4,7 @@ import os
 import numpy as np
 import safetensors
 import torch
+import torch.nn as nn
 import torch.utils.checkpoint
 from tqdm.auto import trange
 from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
    """
    CFGGuider with modifications for training specific logic
    """
+    def __init__(self, *args, offloading=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.offloading = offloading
+
    def outer_sample(
        self,
        noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                noise.shape,
                self.conds,
                self.model_options,
-                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
+                force_full_load=not self.offloading,
+                force_offload=self.offloading,
            )
        )
+        torch.cuda.empty_cache()

        device = self.model_patcher.load_device
        if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
     return result

-def patch(m):
+def find_modules_at_depth(
+    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+    """
+    Find modules at a specific depth level for gradient checkpointing.
+
+    Args:
+        model: The model to search
+        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+        result: Accumulator for results
+        current_depth: Current recursion depth
+        name: Current module name for logging
+
+    Returns:
+        List of modules at the target depth
+    """
+    if result is None:
+        result = []
+    name = name or "root"
+
+    # Skip container modules (they don't have meaningful forward)
+    is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+    has_forward = hasattr(model, "forward") and not is_container
+
+    if has_forward:
+        current_depth += 1
+        if current_depth == depth:
+            result.append(model)
+            logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+            return result
+
+    # Recurse into children
+    for next_name, child in model.named_children():
+        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+    return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+    """
+    Gradient checkpointing that works with weight offloading.
+
+    Forward: no_grad -> compute -> weights can be freed
+    Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+    For single input, single output modules (Linear, Conv*).
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, forward_fn):
+        ctx.save_for_backward(x)
+        ctx.forward_fn = forward_fn
+        with torch.no_grad():
+            return forward_fn(x)
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        x, = ctx.saved_tensors
+        forward_fn = ctx.forward_fn
+        # Clear context early
+        ctx.forward_fn = None
+
+        with torch.enable_grad():
+            x_detached = x.detach().requires_grad_(True)
+            y = forward_fn(x_detached)
+            y.backward(grad_out)
+            grad_x = x_detached.grad
+
+        # Explicit cleanup
+        del y, x_detached, forward_fn
+        return grad_x, None
+
+
+def patch(m, offloading=False):
     if not hasattr(m, "forward"):
         return
     org_forward = m.forward

-    def fwd(args, kwargs):
-        return org_forward(*args, **kwargs)
-
-    def checkpointing_fwd(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+    # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
+    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        def checkpointing_fwd(x):
+            return OffloadCheckpointFunction.apply(x, org_forward)
+    # Branch 2: Others -> standard checkpoint
+    else:
+        def fwd(args, kwargs):
+            return org_forward(*args, **kwargs)

+        def checkpointing_fwd(*args, **kwargs):
+            return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

     m.org_forward = org_forward
     m.forward = checkpointing_fwd
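A minimal self-contained check of the custom checkpoint function (module and shapes are hypothetical): the forward pass keeps no activation graph alive, and backward re-runs the module under enable_grad, so parameter gradients still populate.

    import torch
    import torch.nn as nn

    lin = nn.Linear(8, 8)
    x = torch.randn(4, 8, requires_grad=True)
    y = OffloadCheckpointFunction.apply(x, lin.forward)
    y.sum().backward()
    print(x.grad.shape, lin.weight.grad.shape)  # torch.Size([4, 8]) torch.Size([8, 8])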
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
                    default=True,
                    tooltip="Use gradient checkpointing for training.",
                ),
+                io.Int.Input(
+                    "checkpoint_depth",
+                    default=1,
+                    min=1,
+                    max=5,
+                    tooltip="Depth level for gradient checkpointing.",
+                ),
+                io.Boolean.Input(
+                    "offloading",
+                    default=False,
+                    tooltip="Offload model weights during training to reduce VRAM usage.",
+                ),
                io.Combo.Input(
                    "existing_lora",
                    options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
        lora_dtype,
        algorithm,
        gradient_checkpointing,
+        checkpoint_depth,
+        offloading,
        existing_lora,
        bucket_mode,
        bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
        lora_dtype = lora_dtype[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
+        offloading = offloading[0]
+        checkpoint_depth = checkpoint_depth[0]
        existing_lora = existing_lora[0]
        bucket_mode = bucket_mode[0]
        bypass_mode = bypass_mode[0]
@@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
        # Setup gradient checkpointing
        if gradient_checkpointing:
-            for m in find_all_highest_child_module_with_forward(
-                mp.model.diffusion_model
-            ):
-                patch(m)
+            modules_to_patch = find_modules_at_depth(
+                mp.model.diffusion_model, depth=checkpoint_depth
+            )
+            logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+            for m in modules_to_patch:
+                patch(m, offloading=offloading)
        torch.cuda.empty_cache()

        # With force_full_load=False we should be able to have offloading
        # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
        comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=True
+            [mp], memory_required=1e20, force_full_load=not offloading
        )
        torch.cuda.empty_cache()
@@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
        )

        # Setup guider
-        guider = TrainGuider(mp)
+        guider = TrainGuider(mp, offloading=offloading)
        guider.set_conds(positive)

        # Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):
        # Run training loop
        try:
+            comfy.model_management.in_training = True
            _run_training_loop(
                guider,
                train_sampler,
@@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
                multi_res,
            )
        finally:
+            comfy.model_management.in_training = False
            # Eject bypass hooks if they were injected
            if bypass_injections is not None:
                for injection in bypass_injections:
@@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
                    unpatch(m)
        del train_sampler, optimizer

-        # Finalize adapters
+        for param in lora_sd:
+            lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
+
        for adapter in all_weight_adapters:
            adapter.requires_grad_(False)
+            del adapter
+        del all_weight_adapters

-        for param in lora_sd:
-            lora_sd[param] = lora_sd[param].to(lora_dtype)

        # mp in train node is highly specialized for training
        # use it in inference will result in bad behavior so we don't return it
        return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)


 class LoraModelLoader(io.ComfyNode):#
    @classmethod
    def define_schema(cls):
        return io.Schema(
@@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
                    max=100.0,
                    tooltip="How strongly to modify the diffusion model. This value can be negative.",
                ),
+                io.Boolean.Input(
+                    "bypass",
+                    default=False,
+                    tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+                ),
            ],
            outputs=[
                io.Model.Output(
@@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
        )

    @classmethod
-    def execute(cls, model, lora, strength_model):
+    def execute(cls, model, lora, strength_model, bypass=False):
        if strength_model == 0:
            return io.NodeOutput(model)

-        model_lora, _ = comfy.sd.load_lora_for_models(
-            model, None, lora, strength_model, 0
-        )
+        if bypass:
+            model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
+        else:
+            model_lora, _ = comfy.sd.load_lora_for_models(
+                model, None, lora, strength_model, 0
+            )

        return io.NodeOutput(model_lora)

View File

@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
         return True

+class VideoSlice(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Video Slice",
+            display_name="Video Slice",
+            search_aliases=[
+                "trim video duration",
+                "skip first frames",
+                "frame load cap",
+                "start time",
+            ],
+            category="image/video",
+            inputs=[
+                io.Video.Input("video"),
+                io.Float.Input(
+                    "start_time",
+                    default=0.0,
+                    max=1e5,
+                    min=-1e5,
+                    step=0.001,
+                    tooltip="Start time in seconds",
+                ),
+                io.Float.Input(
+                    "duration",
+                    default=0.0,
+                    min=0.0,
+                    step=0.001,
+                    tooltip="Duration in seconds, or 0 for unlimited duration",
+                ),
+                io.Boolean.Input(
+                    "strict_duration",
+                    default=False,
+                    tooltip="If True, when the specified duration is not possible, an error will be raised.",
+                ),
+            ],
+            outputs=[
+                io.Video.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
+        trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
+        if trimmed is not None:
+            return io.NodeOutput(trimmed)
+        raise ValueError(
+            f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
+        )
+
 class VideoExtension(ComfyExtension):
     @override
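Read together with as_trimmed above, the node's semantics (as implied by the diff, not separately documented) are: start_time=2.0 with duration=3.0 keeps seconds 2-5; a negative start_time anchors from the end of the source; duration=0.0 means "to the end"; and with strict_duration=True a source shorter than the requested window raises instead of silently returning a shorter clip.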
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
             CreateVideo,
             GetVideoComponents,
             LoadVideo,
+            VideoSlice,
         ]

 async def comfy_entrypoint() -> VideoExtension:

View File

@@ -5,8 +5,11 @@ from comfy_execution.jobs import (
     is_previewable,
     normalize_queue_item,
     normalize_history_item,
+    normalize_output_item,
+    normalize_outputs,
     get_outputs_summary,
     apply_sorting,
+    has_3d_extension,
 )
@@ -35,8 +38,8 @@ class TestIsPreviewable:
     """Unit tests for is_previewable()"""

     def test_previewable_media_types(self):
-        """Images, video, audio media types should be previewable."""
-        for media_type in ['images', 'video', 'audio']:
+        """Images, video, audio, 3d media types should be previewable."""
+        for media_type in ['images', 'video', 'audio', '3d']:
             assert is_previewable(media_type, {}) is True

     def test_non_previewable_media_types(self):
@@ -46,7 +49,7 @@ class TestIsPreviewable:
     def test_3d_extensions_previewable(self):
         """3D file extensions should be previewable regardless of media_type."""
-        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
             item = {'filename': f'model{ext}'}
             assert is_previewable('files', item) is True
@@ -160,7 +163,7 @@ class TestGetOutputsSummary:
     def test_3d_files_previewable(self):
         """3D file extensions should be previewable."""
-        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
             outputs = {
                 'node1': {
                     'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -192,6 +195,64 @@ class TestGetOutputsSummary:
         assert preview['mediaType'] == 'images'
         assert preview['subfolder'] == 'outputs'

+    def test_string_3d_filename_creates_preview(self):
+        """String items with 3D extensions should synthesize a preview (Preview3D node output).
+        Only the .glb counts; nulls and non-file strings are excluded."""
+        outputs = {
+            'node1': {
+                'result': ['preview3d_abc123.glb', None, None]
+            }
+        }
+        count, preview = get_outputs_summary(outputs)
+        assert count == 1
+        assert preview is not None
+        assert preview['filename'] == 'preview3d_abc123.glb'
+        assert preview['mediaType'] == '3d'
+        assert preview['nodeId'] == 'node1'
+        assert preview['type'] == 'output'
+
+    def test_string_non_3d_filename_no_preview(self):
+        """String items without 3D extensions should not create a preview."""
+        outputs = {
+            'node1': {
+                'result': ['data.json', None]
+            }
+        }
+        count, preview = get_outputs_summary(outputs)
+        assert count == 0
+        assert preview is None
+
+    def test_string_3d_filename_used_as_fallback(self):
+        """String 3D preview should be used when no dict items are previewable."""
+        outputs = {
+            'node1': {
+                'latents': [{'filename': 'latent.safetensors'}],
+            },
+            'node2': {
+                'result': ['model.glb', None]
+            }
+        }
+        count, preview = get_outputs_summary(outputs)
+        assert preview is not None
+        assert preview['filename'] == 'model.glb'
+        assert preview['mediaType'] == '3d'
+
+
+class TestHas3DExtension:
+    """Unit tests for has_3d_extension()"""
+
+    def test_recognized_extensions(self):
+        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
+            assert has_3d_extension(f'model{ext}') is True
+
+    def test_case_insensitive(self):
+        assert has_3d_extension('MODEL.GLB') is True
+        assert has_3d_extension('Scene.GLTF') is True
+
+    def test_non_3d_extensions(self):
+        for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
+            assert has_3d_extension(name) is False
+
+
 class TestApplySorting:
     """Unit tests for apply_sorting()"""
@@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
             'prompt': {'nodes': {'1': {}}},
             'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
         }
+
+    def test_include_outputs_normalizes_3d_strings(self):
+        """Detail view should transform string 3D filenames into file output dicts."""
+        history_item = {
+            'prompt': (
+                5,
+                'prompt-3d',
+                {'nodes': {}},
+                {'create_time': 1234567890},
+                ['node1'],
+            ),
+            'status': {'status_str': 'success', 'completed': True, 'messages': []},
+            'outputs': {
+                'node1': {
+                    'result': ['preview3d_abc123.glb', None, None]
+                }
+            },
+        }
+        job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
+        assert job['outputs_count'] == 1
+        result_items = job['outputs']['node1']['result']
+        assert len(result_items) == 1
+        assert result_items[0] == {
+            'filename': 'preview3d_abc123.glb',
+            'type': 'output',
+            'subfolder': '',
+            'mediaType': '3d',
+        }
+
+    def test_include_outputs_preserves_dict_items(self):
+        """Detail view normalization should pass dict items through unchanged."""
+        history_item = {
+            'prompt': (
+                5,
+                'prompt-img',
+                {'nodes': {}},
+                {'create_time': 1234567890},
+                ['node1'],
+            ),
+            'status': {'status_str': 'success', 'completed': True, 'messages': []},
+            'outputs': {
+                'node1': {
+                    'images': [
+                        {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+                    ]
+                }
+            },
+        }
+        job = normalize_history_item('prompt-img', history_item, include_outputs=True)
+        assert job['outputs_count'] == 1
+        assert job['outputs']['node1']['images'] == [
+            {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+        ]
+
+
+class TestNormalizeOutputItem:
+    """Unit tests for normalize_output_item()"""
+
+    def test_none_returns_none(self):
+        assert normalize_output_item(None) is None
+
+    def test_string_3d_extension_synthesizes_dict(self):
+        result = normalize_output_item('model.glb')
+        assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+
+    def test_string_non_3d_extension_returns_none(self):
+        assert normalize_output_item('data.json') is None
+
+    def test_string_no_extension_returns_none(self):
+        assert normalize_output_item('camera_info_string') is None
+
+    def test_dict_passes_through(self):
+        item = {'filename': 'test.png', 'type': 'output'}
+        assert normalize_output_item(item) is item
+
+    def test_other_types_return_none(self):
+        assert normalize_output_item(42) is None
+        assert normalize_output_item(True) is None
+
+
+class TestNormalizeOutputs:
+    """Unit tests for normalize_outputs()"""
+
+    def test_empty_outputs(self):
+        assert normalize_outputs({}) == {}
+
+    def test_dict_items_pass_through(self):
+        outputs = {
+            'node1': {
+                'images': [{'filename': 'a.png', 'type': 'output'}],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result == outputs
+
+    def test_3d_string_synthesized(self):
+        outputs = {
+            'node1': {
+                'result': ['model.glb', None, None],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result == {
+            'node1': {
+                'result': [
+                    {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
+                ],
+            }
+        }
+
+    def test_animated_key_preserved(self):
+        outputs = {
+            'node1': {
+                'images': [{'filename': 'a.png', 'type': 'output'}],
+                'animated': [True],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result['node1']['animated'] == [True]
+
+    def test_non_dict_node_outputs_preserved(self):
+        outputs = {'node1': 'unexpected_value'}
+        result = normalize_outputs(outputs)
+        assert result == {'node1': 'unexpected_value'}
+
+    def test_none_items_filtered_but_other_types_preserved(self):
+        outputs = {
+            'node1': {
+                'result': ['data.json', None, [1, 2, 3]],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result == {
+            'node1': {
+                'result': ['data.json', [1, 2, 3]],
+            }
+        }