Merge branch 'comfyanonymous:master' into master

commit 403a081215
Author: patientx, 2024-12-23 10:33:06 +03:00 (committed via GitHub)
GPG signature: key ID B5690EEEBB952194 (no known key found for this signature in the database)
3 changed files with 12 additions and 13 deletions


@@ -378,7 +378,7 @@ class Decoder(nn.Module):
             assert (
                 timestep is not None
             ), "should pass timestep with timestep_conditioning=True"
-            scaled_timestep = timestep * self.timestep_scale_multiplier
+            scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)

         for up_block in self.up_blocks:
             if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
@@ -403,7 +403,7 @@ class Decoder(nn.Module):
             )
             ada_values = self.last_scale_shift_table[
                 None, ..., None, None, None
-            ] + embedded_timestep.reshape(
+            ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
                 batch_size,
                 2,
                 -1,
@@ -697,7 +697,7 @@ class ResnetBlock3D(nn.Module):
             ), "should pass timestep with timestep_conditioning=True"
             ada_values = self.scale_shift_table[
                 None, ..., None, None, None
-            ] + timestep.reshape(
+            ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
                 batch_size,
                 4,
                 -1,
@@ -715,7 +715,7 @@ class ResnetBlock3D(nn.Module):
         if self.inject_noise:
             hidden_states = self._feed_spatial_noise(
-                hidden_states, self.per_channel_scale1
+                hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
             )

         hidden_states = self.norm2(hidden_states)
@@ -731,7 +731,7 @@ class ResnetBlock3D(nn.Module):
         if self.inject_noise:
             hidden_states = self._feed_spatial_noise(
-                hidden_states, self.per_channel_scale2
+                hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
             )

         input_tensor = self.norm3(input_tensor)
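All five hunks in this file make the same fix: tensors owned by the module (the timestep scale multiplier, the scale-shift tables, the per-channel noise scales) are cast to the device and dtype of the activations at the point of use, so the decoder keeps working when weights sit on a different device or in a different precision than the data flowing through it. A minimal sketch of the pattern, using a hypothetical ScaleShiftBlock rather than the actual ComfyUI module:

    import torch
    import torch.nn as nn

    class ScaleShiftBlock(nn.Module):
        # Hypothetical module illustrating the cast-at-use pattern from the hunks above.
        def __init__(self, dim):
            super().__init__()
            self.scale_shift_table = nn.Parameter(torch.zeros(2, dim))  # stored in fp32

        def forward(self, x):
            # The table may live on another device or in another dtype than x
            # (e.g. fp32 offloaded weights vs. fp16 activations on the GPU).
            # Casting here keeps the arithmetic below well-defined.
            table = self.scale_shift_table.to(device=x.device, dtype=x.dtype)
            shift, scale = table.unbind(0)
            return x * (1 + scale) + shift

    block = ScaleShiftBlock(8)                      # parameter is fp32
    x = torch.randn(3, 8, dtype=torch.float16)      # activations are fp16
    print(block(x).dtype)                           # torch.float16, no mismatch error

Casting per call costs a small copy whenever device or dtype differ, but it avoids keeping a second resident copy of the weights around.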


@@ -6,16 +6,15 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-from comfy import model_management
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
 from comfy.ldm.modules.attention import optimized_attention

-if model_management.xformers_enabled():
-    import xformers.ops
-    if int((xformers.__version__).split(".")[2]) >= 28:
-        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
-    else:
-        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+# if model_management.xformers_enabled():
+#     import xformers.ops
+#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
+#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
+#     else:
+#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens

 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
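Two things happen in this hunk: the version check gains a .split("+")[0], so local version labels such as 0.0.28+cu118 no longer crash the int() conversion, and the whole xformers probe is then commented out (taking the now-unused comfy.model_management import with it). A more defensive way to write such an optional-dependency probe is sketched below with the packaging library instead of hand-rolled string splitting; this is an assumption about style, not what this repository does:

    from packaging import version  # assumption: packaging is available

    try:
        import xformers
        import xformers.ops
        # version.parse handles "0.0.28", "0.0.28.post1" and "0.0.28+cu118"
        # alike; .release is the numeric (major, minor, micro) tuple.
        if version.parse(xformers.__version__).release >= (0, 0, 28):
            block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
        else:
            block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
    except ImportError:
        block_diagonal_mask_from_seqlens = None  # caller must check before use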


@@ -534,7 +534,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
                 lowvram_model_memory = 0

         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            lowvram_model_memory = 0.1

         loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
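Here lowvram_model_memory acts as a budget, in bytes, of model weights allowed to stay resident on the GPU, with a value of 0 taken as "load the model fully". Replacing the old 64 MiB budget in the NO_VRAM state with 0.1 keeps the value nonzero, so the partial-load path is still taken, while leaving essentially every weight offloaded. A hypothetical sketch of how such a byte budget gets consumed (not ComfyUI's actual loader):

    import torch.nn as nn

    def partially_load(modules, budget_bytes, device="cuda"):
        # Move modules to the target device until the byte budget runs out;
        # everything else stays offloaded. With budget_bytes=0.1 nothing ever
        # fits, but the lowvram bookkeeping around this call still happens.
        resident = []
        for name, module in modules:
            size = sum(p.numel() * p.element_size() for p in module.parameters())
            if size <= budget_bytes:
                module.to(device)
                budget_bytes -= size
                resident.append(name)
        return resident

    blocks = [(f"block{i}", nn.Linear(64, 64)) for i in range(4)]
    print(partially_load(blocks, 0.1, device="cpu"))  # [] -- nothing fits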