Compare commits

...

12 Commits

Author SHA1 Message Date
Jukka Seppänen
98cad6eca7
Merge 1e76c3b9c9 into 65045730a6 2026-05-09 00:47:21 +03:00
Jukka Seppänen
1e76c3b9c9
Merge branch 'master' into ltxv_self_attn_mask 2026-05-09 00:47:19 +03:00
Alexander Piskun
65045730a6
[Partner Nodes] additionally use a Baidu server to detect internet accessibility (#13803)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-08 13:11:52 -07:00
Matt Miller
87878f354f
Add cloud-runtime FE-facing operations to spec (#13734)
* Add cloud-runtime FE-facing operations to openapi.yaml

Add ~67 cloud-runtime FE-facing path operations to the core OpenAPI spec,
each tagged with x-runtime: [cloud] at the operation level. These operations
are served by the cloud runtime; the local runtime returns 404 for all of
these paths.

Domain groups added:
- Jobs / prompts: /api/job/*, /api/jobs/*/cancel, /api/prompt/*, etc.
- History v2: /api/history_v2, /api/history_v2/{prompt_id}
- Cloud logs: /api/logs
- Asset extensions: /api/assets/download, export, import, etc.
- Custom nodes: /api/experiment/nodes (cloud install/uninstall)
- Hub: /api/hub/profiles, /api/hub/workflows, /api/hub/labels, etc.
- Workflows: /api/workflows CRUD, versioning, fork, publish
- Auth/session: /api/auth/session, /api/auth/token, /.well-known/jwks.json
- Billing: /api/billing/balance, plans, subscribe, topup, etc.
- Workspace: /api/workspace/*, /api/workspaces/*
- User/settings/misc: /api/user, /api/secrets, /api/feedback, etc.

Also adds corresponding cloud-only component schemas (CloudJob, CloudWorkflow,
BillingPlan, Workspace, HubProfile, AuthSession, etc.), all tagged with
x-runtime: [cloud].

Spectral lint passes under the existing ruleset with zero new warnings.

* Add job_id field to Asset schema and deprecate prompt_id (#13736)

- Add job_id as a nullable UUID field to the Asset schema
- Mark prompt_id as deprecated with note pointing to job_id
- No x-runtime tag needed as both runtimes populate the field

* Add hash field to Asset schemas and deprecate asset_hash (#13738)

- Add 'hash' as a nullable string field to Asset and AssetUpdated schemas
- Mark 'asset_hash' as deprecated with a note pointing to 'hash'
- AssetCreated inherits 'hash' via allOf from Asset
- Spectral lint clean (no new warnings)

* Fix method drift on cloud-runtime endpoints

Three PUT operations were added that should be PATCH (cloud serves
PATCH for partial updates):

- /api/workflows/{workflow_id}
- /api/workspaces/{id}
- /api/workspace/members/{userId}

Two POST operations were added that should be GET (cloud serves GET
with query params):

- /api/assets/remote-metadata (url moves to query param)
- /api/files/mask-layers (response shape replaced — operation queries
  related mask layer filenames, not file uploads)

* Add missing cloud-runtime operations and schemas

PR review surfaced operations the cloud runtime serves that weren't
covered by the initial spec push, plus one path family missed entirely.

New methods on existing paths:

- /api/auth/session: add POST (create session cookie) and DELETE (logout)
- /api/secrets/{id}: add GET (read metadata) and PATCH (update)
- /api/hub/profiles: add POST (create profile)
- /api/hub/workflows: add POST (publish to hub)
- /api/hub/workflows/{share_id}: add DELETE (unpublish)
- /api/workspaces/{id}: add DELETE (soft-delete workspace)
- /api/workspace/members/{user_id}/api-keys: add DELETE (bulk revoke)
- /api/workflows/{workflow_id}/versions: add POST (create new version)
- /api/userdata/{file}/publish: add GET (read publish info)

New path family:

- /api/tasks (GET list) and /api/tasks/{task_id} (GET detail) for the
  background task framework

New component schemas (all tagged x-runtime: [cloud]):

CreateSessionResponse, DeleteSessionResponse, UpdateSecretRequest,
BulkRevokeAPIKeysResponse, CreateHubProfileRequest, PublishHubWorkflowRequest,
HubWorkflowDetail, AssetInfo, CreateWorkflowVersionRequest,
WorkflowVersionResponse, WorkflowPublishInfo, TaskEntry, TaskResponse,
TasksListResponse. Existing SecretMeta extended with provider and
last_used_at fields the cloud runtime actually returns.

New tag: task. Spectral lint passes with zero errors.

* Add job_id and prompt_id to AssetUpdated schema

Mirrors the Asset schema's deprecation pattern: prompt_id is marked
deprecated with a description pointing to job_id; job_id is the new
preferred field. PUT /api/assets/{id} responses can now carry both fields
consistent with the other Asset-returning endpoints.

* feat: add width and height fields to Asset schema (#13745)

Add nullable integer fields 'width' and 'height' to the Asset schema
in openapi.yaml. These expose original image dimensions in pixels for
clients that need pre-thumbnail size info. Both fields are null for
non-image assets or assets ingested before dimension extraction.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* Remove /api/job/{job_id} and /api/job/{job_id}/outputs

These two paths are not actually served by the cloud runtime — they
return 404 with a redirect message pointing callers to the canonical
`/api/jobs/{job_id}` (plural). Declaring them with `x-runtime: [cloud]`
and a 200 response schema is incorrect.

`/api/job/{job_id}/status` stays — it is a real cloud-served endpoint.

Also drops the now-orphaned `CloudJob` and `CloudJobOutputs` component
schemas. `CloudJobStatus` is retained.
2026-05-08 12:39:16 -07:00
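As a reading aid for the x-runtime tagging described in the commit above, a minimal sketch (hypothetical helper, not part of the PR; assumes PyYAML and a local openapi.yaml) that lists the operations a local runtime is expected to 404 on:

import yaml  # PyYAML, assumed available

def cloud_only_operations(spec_path="openapi.yaml"):
    # Collect (METHOD, path) pairs whose operation carries x-runtime: [cloud].
    with open(spec_path) as f:
        spec = yaml.safe_load(f)
    http_methods = {"get", "put", "post", "delete", "patch", "head", "options", "trace"}
    ops = []
    for path, item in (spec.get("paths") or {}).items():
        for method, op in item.items():
            if method in http_methods and isinstance(op, dict) and op.get("x-runtime") == ["cloud"]:
                ops.append((method.upper(), path))
    return ops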
Alexis Rolland
c5ecd231a2
fix: Fix bug when mask not on same device (CORE-181) (#13801)
2026-05-08 23:06:29 +08:00
drozbay
9864f5ac86
fix: Stop LTXVImgToVideoInplace from mutating input latents and dropping noise_mask (#13793) 2026-05-08 23:02:17 +08:00
kijai
f559a749e9 Merge remote-tracking branch 'upstream/master' into ltxv_self_attn_mask 2026-05-07 15:02:56 +03:00
kijai
989dea8c40 Allow strength above 1.0 2026-05-06 23:56:21 +03:00
kijai
848880c3d3 Merge remote-tracking branch 'upstream/master' into ltxv_self_attn_mask 2026-05-06 21:45:41 +03:00
kijai
6b97e3f4cb Only fall back to pytorch attention from sage for the guide mask 2026-05-06 21:31:49 +03:00
kijai
f2beaa5802 Reduce peak VRAM by handling self_attn_mask more efficiently 2026-05-06 21:08:15 +03:00
kijai
e6e3e6f628 Alternative self_attn_mask
Drastically lower memory use, with a different effect; for testing
2026-05-06 16:34:01 +03:00
7 changed files with 4851 additions and 104 deletions

View File

@@ -47,7 +47,7 @@ class BackgroundRemovalModel():
         out = self.model(pixel_values=pixel_values)
         out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
-        mask = out.sigmoid()
+        mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
         if mask.ndim == 3:
             mask = mask.unsqueeze(0)
         if mask.shape[1] != 1:
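The bug class behind CORE-181 is a plain cross-device mix; a pure-torch sketch (illustrative, not repo code) of what the one-line fix above prevents:

import torch

a = torch.rand(2, 2)  # cpu
b = torch.rand(2, 2, device="cuda") if torch.cuda.is_available() else torch.rand(2, 2)
# On a CUDA box, `a + b` raises "Expected all tensors to be on the same device".
# Normalizing device (and dtype) once at the model boundary, as the fix does, avoids that:
out = a + b.to(a.device)
print(out.shape)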

View File

@@ -22,26 +22,25 @@ class CompressedTimestep:
     """Store video timestep embeddings in compressed form using per-frame indexing."""

     __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')

-    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
+    def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
         """
-        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
-        patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
+        tensor: [batch, num_tokens, feature_dim] (per-token, default) or
+                [batch, num_frames, feature_dim] (per_frame=True, already compressed).
+        patches_per_frame: spatial patches per frame; pass None to disable compression.
         """
-        self.batch_size, num_tokens, self.feature_dim = tensor.shape
-        # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
-        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+        self.batch_size, n, self.feature_dim = tensor.shape
+        if per_frame:
             self.patches_per_frame = patches_per_frame
-            self.num_frames = num_tokens // patches_per_frame
-            # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
-            # All patches in a frame are identical, so we only keep the first one
-            reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
-            self.data = reshaped[:, :, 0, :].contiguous()  # [batch, frames, feature_dim]
+            self.num_frames = n
+            self.data = tensor
+        elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
+            self.patches_per_frame = patches_per_frame
+            self.num_frames = n // patches_per_frame
+            # All patches in a frame are identical — keep only the first.
+            self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
         else:
             # Not divisible or too small - store directly without compression
             self.patches_per_frame = 1
-            self.num_frames = num_tokens
+            self.num_frames = n
             self.data = tensor

     def expand(self):
@@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
     def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
         """Prepare timestep embeddings."""
         # TODO: some code reuse is needed here.
         grid_mask = kwargs.get("grid_mask", None)
-        if grid_mask is not None:
-            timestep = timestep[:, grid_mask]
-        timestep_scaled = timestep * self.timestep_scale_multiplier
-        v_timestep, v_embedded_timestep = self.adaln_single(
-            timestep_scaled.flatten(),
-            {"resolution": None, "aspect_ratio": None},
-            batch_size=batch_size,
-            hidden_dtype=hidden_dtype,
-        )
         # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
         # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
         orig_shape = kwargs.get("orig_shape")
         has_spatial_mask = kwargs.get("has_spatial_mask", None)
         v_patches_per_frame = None
         if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
             # orig_shape[3] = height, orig_shape[4] = width (in latent space)
             v_patches_per_frame = orig_shape[3] * orig_shape[4]
-        # Reshape to [batch_size, num_tokens, dim] and compress for storage
-        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
-        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
+        # Used by compute_prompt_timestep and the audio cross-attention paths.
+        timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
+        # When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
+        per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
+        if per_frame_path:
+            per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
+            if grid_mask is not None:
+                # All-or-nothing per frame when has_spatial_mask=False.
+                per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
+            ts_input = per_frame * self.timestep_scale_multiplier
+        else:
+            ts_input = timestep_scaled
+        v_timestep, v_embedded_timestep = self.adaln_single(
+            ts_input.flatten(),
+            {"resolution": None, "aspect_ratio": None},
+            batch_size=batch_size,
+            hidden_dtype=hidden_dtype,
+        )
+        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
+        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
         v_prompt_timestep = compute_prompt_timestep(
             self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
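A back-of-the-envelope check of what the per-frame path buys (hypothetical sizes, not from the PR): adaln_single now projects one row per frame instead of one per token, and CompressedTimestep stores that result as-is:

import torch

batch, frames, dim = 1, 121, 4096
patches_per_frame = 32 * 32  # latent height * width

per_token = torch.zeros(batch, frames * patches_per_frame, dim)  # old adaln_single input
per_frame = torch.zeros(batch, frames, dim)                      # new input when per_frame_path is True
print(per_token.numel() // per_frame.numel())                    # 1024: rows projected and stored shrink by patches_per_frame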

View File

@@ -358,6 +358,63 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
     return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output


+class GuideAttentionMask:
+    """Holds the two per-group masks for LTXV guide self-attention.
+
+    _attention_with_guide_mask splits queries into noisy and tracked-guide
+    groups, so the largest mask is (1, 1, tracked_count, T).
+    """
+
+    __slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
+
+    def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
+        device = tracked_weights.device
+        dtype = tracked_weights.dtype
+        finfo = torch.finfo(dtype)
+        pos = tracked_weights > 0
+        log_w = torch.full_like(tracked_weights, finfo.min)
+        log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
+        self.guide_start = guide_start
+        self.tracked_count = tracked_count
+        self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
+        self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
+        self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
+        self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
+
+    def to(self, *args, **kwargs):
+        new = GuideAttentionMask.__new__(GuideAttentionMask)
+        new.guide_start = self.guide_start
+        new.tracked_count = self.tracked_count
+        new.noisy_mask = self.noisy_mask.to(*args, **kwargs)
+        new.tracked_mask = self.tracked_mask.to(*args, **kwargs)
+        return new
+
+
+def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
+    """Apply the guide mask by partitioning Q into noisy and tracked-guide
+    groups, so each group needs only its own sub-mask. Avoids materializing
+    the (1,1,T,T) dense mask.
+    """
+    guide_start = guide_mask.guide_start
+    tracked_end = guide_start + guide_mask.tracked_count
+    out = torch.empty_like(q)
+    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
+        q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
+        attn_precision=attn_precision, transformer_options=transformer_options,
+        low_precision_attention=False,  # sageattn mask support is unreliable
+    )
+    out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
+        q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
+        attn_precision=attn_precision, transformer_options=transformer_options,
+        low_precision_attention=False,
+    )
+    return out
+
+
 class CrossAttention(nn.Module):
     def __init__(
         self,
@@ -412,8 +469,10 @@ class CrossAttention(nn.Module):
         if mask is None:
             out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
+        elif isinstance(mask, GuideAttentionMask):
+            out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
         else:
-            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
+            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)

         # Apply per-head gating if enabled
         if self.to_gate_logits is not None:
@@ -1063,7 +1122,9 @@ class LTXVModel(LTXBaseModel):
             additional_args["resolved_guide_entries"] = resolved_entries
             keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
-            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
+            if keyframe_idxs.shape[2] > 0:  # Guard for the case of no keyframes surviving
+                pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs

             # Total surviving guide tokens (all guides)
             additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@@ -1099,12 +1160,12 @@ class LTXVModel(LTXBaseModel):
         if not resolved_entries:
             return None

-        # Check if any attenuation is actually needed
-        needs_attenuation = any(
-            e["strength"] < 1.0 or e.get("pixel_mask") is not None
+        # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
+        needs_mask = any(
+            e["strength"] != 1.0 or e.get("pixel_mask") is not None
             for e in resolved_entries
         )
-        if not needs_attenuation:
+        if not needs_mask:
             return None

         # Build per-guide-token weights for all tracked guide tokens.
@@ -1159,16 +1220,11 @@ class LTXVModel(LTXBaseModel):
         # Concatenate per-token weights for all tracked guides
         tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)

-        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
-        if (tracked_weights >= 1.0).all():
+        # Skip when every weight is exactly 1.0 (additive bias would be 0).
+        if (tracked_weights == 1.0).all():
             return None

-        # Build the mask: guide tokens are at the end of the sequence.
-        # Tracked guides come first (in order), untracked follow.
-        return self._build_self_attention_mask(
-            total_tokens, num_guide_tokens, total_tracked,
-            tracked_weights, guide_start, device, dtype,
-        )
+        return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)

     @staticmethod
     def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@@ -1234,45 +1290,6 @@ class LTXVModel(LTXBaseModel):
         return rearrange(latent_mask, "b 1 f h w -> b (f h w)")

-    @staticmethod
-    def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
-                                   tracked_weights, guide_start, device, dtype):
-        """Build a log-space additive self-attention bias mask.
-
-        Attenuates attention between noisy tokens and tracked guide tokens.
-        Untracked guide tokens (at the end of the guide portion) keep full attention.
-
-        Args:
-            total_tokens: Total sequence length.
-            num_guide_tokens: Total guide tokens (all guides) at end of sequence.
-            tracked_count: Number of tracked guide tokens (first in the guide portion).
-            tracked_weights: (1, tracked_count) tensor, values in [0, 1].
-            guide_start: Index where guide tokens begin in the sequence.
-            device: Target device.
-            dtype: Target dtype.
-
-        Returns:
-            (1, 1, total_tokens, total_tokens) additive bias mask.
-            0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
-        """
-        finfo = torch.finfo(dtype)
-        mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
-        tracked_end = guide_start + tracked_count
-
-        # Convert weights to log-space bias
-        w = tracked_weights.to(device=device, dtype=dtype)  # (1, tracked_count)
-        log_w = torch.full_like(w, finfo.min)
-        positive_mask = w > 0
-        if positive_mask.any():
-            log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
-
-        # noisy → tracked guides: each noisy row gets the same per-guide weight
-        mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
-        # tracked guides → noisy: each guide row broadcasts its weight across noisy cols
-        mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
-
-        return mask
-
     def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
         """Process transformer blocks for LTXV."""
         patches_replace = transformer_options.get("patches_replace", {})
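Why a log-space additive bias multiplies attention, and why strength may now exceed 1.0 (a standalone sketch, not repo code): softmax(s + log w) scales the guide column's unnormalized weight exp(s) by w, so w < 1 attenuates and w > 1 amplifies:

import math
import torch

scores = torch.tensor([[1.0, 2.0, 0.5]])        # last column = a tracked guide token
w = 0.25                                        # guide strength
bias = torch.tensor([[0.0, 0.0, math.log(w)]])  # the additive mask entry for that column

plain = torch.softmax(scores, dim=-1)
biased = torch.softmax(scores + bias, dim=-1)
# biased[0, -1] == w * exp(0.5) / (exp(1.0) + exp(2.0) + w * exp(0.5))
print(plain[0, -1].item(), biased[0, -1].item())

Splitting Q into noisy and tracked groups, as _attention_with_guide_mask does, keeps the bias tensors at (1, 1, 1, T) and (1, 1, tracked_count, T) instead of the dense (1, 1, T, T) this file used to build.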

View File

@@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
         "api_accessible": False,
     }

     timeout = aiohttp.ClientTimeout(total=5.0)
+    # Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China,
+    # so a Baidu probe is needed to avoid misreporting mainland-China users as offline.
+    internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
     async with aiohttp.ClientSession(timeout=timeout) as session:
-        with contextlib.suppress(ClientError, OSError):
-            async with session.get("https://www.google.com") as resp:
-                results["internet_accessible"] = resp.status < 500
+
+        async def _probe(url: str) -> bool:
+            try:
+                async with session.get(url) as resp:
+                    return resp.status < 500
+            except (ClientError, OSError, asyncio.TimeoutError):
+                return False
+
+        probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
+        try:
+            for fut in asyncio.as_completed(probe_tasks):
+                if await fut:
+                    results["internet_accessible"] = True
+                    break
+        finally:
+            for t in probe_tasks:
+                if not t.done():
+                    t.cancel()
+            await asyncio.gather(*probe_tasks, return_exceptions=True)

         if not results["internet_accessible"]:
             return results
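The probe logic above is a reusable race: return on the first truthy result, then cancel the stragglers. A self-contained sketch with simulated probes (no network, not repo code):

import asyncio

async def first_success(coros) -> bool:
    tasks = [asyncio.create_task(c) for c in coros]
    try:
        for fut in asyncio.as_completed(tasks):
            if await fut:
                return True
        return False
    finally:
        # Cancel whatever is still running and drain cancellations quietly.
        for t in tasks:
            if not t.done():
                t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

async def fake_probe(delay: float, ok: bool) -> bool:
    await asyncio.sleep(delay)
    return ok

print(asyncio.run(first_success([fake_probe(0.2, False), fake_probe(0.1, True)])))  # True after ~0.1s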

View File

@@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
     @classmethod
     def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
         batch_size = max(len(image), len(alpha))
-        alpha = 1.0 - resize_mask(alpha, image.shape[1:])
+        alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
         alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
         image = comfy.utils.repeat_to_batch_size(image, batch_size)
         return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
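The one-argument fix works because Tensor.to(other) matches both the device and the dtype of other in a single call (illustrative sketch):

import torch

image = torch.rand(1, 8, 8, 3, dtype=torch.float16)  # possibly half precision on GPU
alpha = torch.rand(1, 8, 8)                          # float32, possibly elsewhere

alpha = alpha.to(image)                              # now image.dtype on image.device
print(alpha.dtype, alpha.device)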

View File

@@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
         if bypass:
             return (latent,)

-        samples = latent["samples"]
+        samples = latent["samples"].clone()
         _, height_scale_factor, width_scale_factor = (
             vae.downscale_index_formula
         )
-        batch, _, latent_frames, latent_height, latent_width = samples.shape
+        _, _, _, latent_height, latent_width = samples.shape
         width = latent_width * width_scale_factor
         height = latent_height * height_scale_factor
@@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
         samples[:, :, :t.shape[2]] = t

-        conditioning_latent_frames_mask = torch.ones(
-            (batch, 1, latent_frames, 1, 1),
-            dtype=torch.float32,
-            device=samples.device,
-        )
+        conditioning_latent_frames_mask = get_noise_mask(latent)
         conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength

         return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
@@ -223,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
                     "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
-                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
@@ -302,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
         else:
             mask = torch.full(
                 (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
-                1.0 - strength,
+                max(0.0, 1.0 - strength),  # clamp here to amplify only via the attention mask
                 dtype=noise_mask.dtype,
                 device=noise_mask.device,
             )
@@ -322,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
             mask = torch.full(
                 (noise_mask.shape[0], 1, cond_length, 1, 1),
-                1.0 - strength,
+                max(0.0, 1.0 - strength),  # clamp here to amplify only via the attention mask
                 dtype=noise_mask.dtype,
                 device=noise_mask.device,
             )
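A quick check of the clamp semantics (plain Python, not repo code): above strength 1.0 the noise-mask value saturates at 0.0 (fully conditioned), and the extra push comes only from the attention bias, where log w turns positive:

import math

for strength in (0.5, 1.0, 2.0):
    mask_value = max(0.0, 1.0 - strength)  # written into the noise mask
    attn_bias = math.log(strength)         # additive attention bias for weight w = strength
    print(strength, mask_value, round(attn_bias, 3))
# 0.5 -> mask 0.5, bias -0.693 (attenuate); 1.0 -> 0.0, 0.0 (neutral); 2.0 -> 0.0, 0.693 (amplify)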

File diff suppressed because it is too large