mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
Basic implementation of z image fun control union 2.0 (#11304)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
The inpaint part is currently missing and will be implemented later. I think they messed up this model pretty bad. They added some control_noise_refiner blocks but don't actually use them. There is a typo in their code so instead of doing control_noise_refiner -> control_layers it runs the whole control_layers twice. Unfortunately they trained with this typo so the model works but is kind of slow and would probably perform a lot better if they corrected their code and trained it again.
This commit is contained in:
parent
c5a47a1692
commit
da2bfb5b0a
@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
|
|||||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||||
norm_eps: float = 1e-5,
|
norm_eps: float = 1e-5,
|
||||||
qk_norm: bool = True,
|
qk_norm: bool = True,
|
||||||
|
n_control_layers=6,
|
||||||
|
control_in_dim=16,
|
||||||
|
additional_in_dim=0,
|
||||||
|
broken=False,
|
||||||
|
refiner_control=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
|
||||||
self.additional_in_dim = 0
|
self.broken = broken
|
||||||
self.control_in_dim = 16
|
self.additional_in_dim = additional_in_dim
|
||||||
|
self.control_in_dim = control_in_dim
|
||||||
n_refiner_layers = 2
|
n_refiner_layers = 2
|
||||||
self.n_control_layers = 6
|
self.n_control_layers = n_control_layers
|
||||||
self.control_layers = nn.ModuleList(
|
self.control_layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
ZImageControlTransformerBlock(
|
ZImageControlTransformerBlock(
|
||||||
@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
|
|||||||
all_x_embedder = {}
|
all_x_embedder = {}
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
f_patch_size = 1
|
f_patch_size = 1
|
||||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
|
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||||
|
|
||||||
|
self.refiner_control = refiner_control
|
||||||
|
|
||||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||||
self.control_noise_refiner = nn.ModuleList(
|
if self.refiner_control:
|
||||||
[
|
self.control_noise_refiner = nn.ModuleList(
|
||||||
JointTransformerBlock(
|
[
|
||||||
layer_id,
|
ZImageControlTransformerBlock(
|
||||||
dim,
|
layer_id,
|
||||||
n_heads,
|
dim,
|
||||||
n_kv_heads,
|
n_heads,
|
||||||
multiple_of,
|
n_kv_heads,
|
||||||
ffn_dim_multiplier,
|
multiple_of,
|
||||||
norm_eps,
|
ffn_dim_multiplier,
|
||||||
qk_norm,
|
norm_eps,
|
||||||
modulation=True,
|
qk_norm,
|
||||||
z_image_modulation=True,
|
block_id=layer_id,
|
||||||
operation_settings=operation_settings,
|
operation_settings=operation_settings,
|
||||||
)
|
)
|
||||||
for layer_id in range(n_refiner_layers)
|
for layer_id in range(n_refiner_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.control_noise_refiner = nn.ModuleList(
|
||||||
|
[
|
||||||
|
JointTransformerBlock(
|
||||||
|
layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
multiple_of,
|
||||||
|
ffn_dim_multiplier,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
modulation=True,
|
||||||
|
z_image_modulation=True,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_refiner_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
|
|||||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||||
|
|
||||||
x_attn_mask = None
|
x_attn_mask = None
|
||||||
for layer in self.control_noise_refiner:
|
if not self.refiner_control:
|
||||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
for layer in self.control_noise_refiner:
|
||||||
|
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||||
|
|
||||||
return control_context
|
return control_context
|
||||||
|
|
||||||
|
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||||
|
if self.refiner_control:
|
||||||
|
if self.broken:
|
||||||
|
if layer_id == 0:
|
||||||
|
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||||
|
if layer_id > 0:
|
||||||
|
out = None
|
||||||
|
for i in range(1, len(self.control_layers)):
|
||||||
|
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||||
|
if out is None:
|
||||||
|
out = o
|
||||||
|
|
||||||
|
return (out, control_context)
|
||||||
|
else:
|
||||||
|
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||||
|
else:
|
||||||
|
return (None, control_context)
|
||||||
|
|
||||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||||
|
|||||||
@ -536,6 +536,7 @@ class NextDiT(nn.Module):
|
|||||||
bsz = len(x)
|
bsz = len(x)
|
||||||
pH = pW = self.patch_size
|
pH = pW = self.patch_size
|
||||||
device = x[0].device
|
device = x[0].device
|
||||||
|
orig_x = x
|
||||||
|
|
||||||
if self.pad_tokens_multiple is not None:
|
if self.pad_tokens_multiple is not None:
|
||||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||||
@ -572,13 +573,21 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||||
|
|
||||||
|
patches = transformer_options.get("patches", {})
|
||||||
|
|
||||||
# refine context
|
# refine context
|
||||||
for layer in self.context_refiner:
|
for layer in self.context_refiner:
|
||||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||||
|
|
||||||
padded_img_mask = None
|
padded_img_mask = None
|
||||||
for layer in self.noise_refiner:
|
x_input = x
|
||||||
|
for i, layer in enumerate(self.noise_refiner):
|
||||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||||
|
if "noise_refiner" in patches:
|
||||||
|
for p in patches["noise_refiner"]:
|
||||||
|
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||||
|
if "img" in out:
|
||||||
|
x = out["img"]
|
||||||
|
|
||||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||||
mask = None
|
mask = None
|
||||||
@ -622,14 +631,15 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
x_is_tensor = isinstance(x, torch.Tensor)
|
x_is_tensor = isinstance(x, torch.Tensor)
|
||||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||||
freqs_cis = freqs_cis.to(img.device)
|
freqs_cis = freqs_cis.to(img.device)
|
||||||
|
|
||||||
|
img_input = img
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||||
if "double_block" in patches:
|
if "double_block" in patches:
|
||||||
for p in patches["double_block"]:
|
for p in patches["double_block"]:
|
||||||
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||||
if "img" in out:
|
if "img" in out:
|
||||||
img[:, cap_size[0]:] = out["img"]
|
img[:, cap_size[0]:] = out["img"]
|
||||||
if "txt" in out:
|
if "txt" in out:
|
||||||
|
|||||||
@ -454,6 +454,9 @@ class ModelPatcher:
|
|||||||
def set_model_post_input_patch(self, patch):
|
def set_model_post_input_patch(self, patch):
|
||||||
self.set_model_patch(patch, "post_input")
|
self.set_model_patch(patch, "post_input")
|
||||||
|
|
||||||
|
def set_model_noise_refiner_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "noise_refiner")
|
||||||
|
|
||||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||||
rope_options["scale_x"] = scale_x
|
rope_options["scale_x"] = scale_x
|
||||||
|
|||||||
@ -243,7 +243,13 @@ class ModelPatchLoader:
|
|||||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||||
sd = z_image_convert(sd)
|
sd = z_image_convert(sd)
|
||||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
config = {}
|
||||||
|
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
|
||||||
|
config['n_control_layers'] = 15
|
||||||
|
config['additional_in_dim'] = 17
|
||||||
|
config['refiner_control'] = True
|
||||||
|
config['broken'] = True
|
||||||
|
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||||
|
|
||||||
model.load_state_dict(sd)
|
model.load_state_dict(sd)
|
||||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
@ -297,56 +303,86 @@ class DiffSynthCnetPatch:
|
|||||||
return [self.model_patch]
|
return [self.model_patch]
|
||||||
|
|
||||||
class ZImageControlPatch:
|
class ZImageControlPatch:
|
||||||
def __init__(self, model_patch, vae, image, strength):
|
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
|
||||||
self.model_patch = model_patch
|
self.model_patch = model_patch
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.image = image
|
self.image = image
|
||||||
|
self.inpaint_image = inpaint_image
|
||||||
|
self.mask = mask
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.encoded_image = self.encode_latent_cond(image)
|
self.encoded_image = self.encode_latent_cond(image)
|
||||||
self.encoded_image_size = (image.shape[1], image.shape[2])
|
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
|
|
||||||
def encode_latent_cond(self, image):
|
def encode_latent_cond(self, control_image, inpaint_image=None):
|
||||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
|
||||||
return latent_image
|
if self.model_patch.model.additional_in_dim > 0:
|
||||||
|
if self.mask is None:
|
||||||
|
mask_ = torch.zeros_like(latent_image)[:, :1]
|
||||||
|
else:
|
||||||
|
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
|
||||||
|
if inpaint_image is None:
|
||||||
|
inpaint_image = torch.ones_like(control_image) * 0.5
|
||||||
|
|
||||||
|
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
|
||||||
|
|
||||||
|
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
|
||||||
|
else:
|
||||||
|
return latent_image
|
||||||
|
|
||||||
def __call__(self, kwargs):
|
def __call__(self, kwargs):
|
||||||
x = kwargs.get("x")
|
x = kwargs.get("x")
|
||||||
img = kwargs.get("img")
|
img = kwargs.get("img")
|
||||||
|
img_input = kwargs.get("img_input")
|
||||||
txt = kwargs.get("txt")
|
txt = kwargs.get("txt")
|
||||||
pe = kwargs.get("pe")
|
pe = kwargs.get("pe")
|
||||||
vec = kwargs.get("vec")
|
vec = kwargs.get("vec")
|
||||||
block_index = kwargs.get("block_index")
|
block_index = kwargs.get("block_index")
|
||||||
|
block_type = kwargs.get("block_type", "")
|
||||||
spacial_compression = self.vae.spacial_compression_encode()
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||||
|
inpaint_scaled = None
|
||||||
|
if self.inpaint_image is not None:
|
||||||
|
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
|
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
|
||||||
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
cnet_index = (block_index // 5)
|
cnet_blocks = self.model_patch.model.n_control_layers
|
||||||
cnet_index_float = (block_index / 5)
|
div = round(30 / cnet_blocks)
|
||||||
|
|
||||||
|
cnet_index = (block_index // div)
|
||||||
|
cnet_index_float = (block_index / div)
|
||||||
|
|
||||||
kwargs.pop("img") # we do ops in place
|
kwargs.pop("img") # we do ops in place
|
||||||
kwargs.pop("txt")
|
kwargs.pop("txt")
|
||||||
|
|
||||||
cnet_blocks = self.model_patch.model.n_control_layers
|
|
||||||
if cnet_index_float > (cnet_blocks - 1):
|
if cnet_index_float > (cnet_blocks - 1):
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
||||||
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
if block_type == "noise_refiner":
|
||||||
|
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||||
|
else:
|
||||||
|
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||||
|
|
||||||
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
if block_type == "noise_refiner":
|
||||||
next_layer = self.temp_data[0] + 1
|
next_layer = self.temp_data[0] + 1
|
||||||
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||||
|
if self.temp_data[1][0] is not None:
|
||||||
|
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||||
|
else:
|
||||||
|
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
||||||
|
next_layer = self.temp_data[0] + 1
|
||||||
|
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||||
|
|
||||||
if cnet_index_float == self.temp_data[0]:
|
if cnet_index_float == self.temp_data[0]:
|
||||||
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||||
if cnet_blocks == self.temp_data[0] + 1:
|
if cnet_blocks == self.temp_data[0] + 1:
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet:
|
|||||||
mask = 1.0 - mask
|
mask = 1.0 - mask
|
||||||
|
|
||||||
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||||
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
|
patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
|
||||||
|
model_patched.set_model_noise_refiner_patch(patch)
|
||||||
|
model_patched.set_model_double_block_patch(patch)
|
||||||
else:
|
else:
|
||||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||||
return (model_patched,)
|
return (model_patched,)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user