diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index c0f9b6511..860891965 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -449,7 +449,9 @@ def load_controlnet_flux_instantx(sd):
     if union_cnet in new_sd:
         num_union_modes = new_sd[union_cnet].shape[0]
 
-    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
+
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = comfy.latent_formats.Flux()
diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py
index d8b776129..c033dea52 100644
--- a/comfy/ldm/flux/controlnet.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -52,7 +52,7 @@ class MistolineControlnetBlock(nn.Module):
 
 
 class ControlNetFlux(Flux):
-    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
         super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
 
         self.main_model_double = 19
@@ -80,7 +80,12 @@ class ControlNetFlux(Flux):
 
         self.gradient_checkpointing = False
         self.latent_input = latent_input
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if control_latent_channels is None:
+            control_latent_channels = self.in_channels
+        else:
+            control_latent_channels *= 2 * 2 #patch size
+
+        self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
         if not self.latent_input:
             if self.mistoline:
                 self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index b53e10f3f..0b5bf1899 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -99,30 +99,44 @@ class TopologicalSort:
             self.add_strong_link(from_node_id, from_socket, to_node_id)
 
     def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        self.add_node(from_node_id)
-        if to_node_id not in self.blocking[from_node_id]:
-            self.blocking[from_node_id][to_node_id] = {}
-            self.blockCount[to_node_id] += 1
-        self.blocking[from_node_id][to_node_id][from_socket] = True
+        if not self.is_cached(from_node_id):
+            self.add_node(from_node_id)
+            if to_node_id not in self.blocking[from_node_id]:
+                self.blocking[from_node_id][to_node_id] = {}
+                self.blockCount[to_node_id] += 1
+            self.blocking[from_node_id][to_node_id][from_socket] = True
 
-    def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
-        if unique_id in self.pendingNodes:
-            return
-        self.pendingNodes[unique_id] = True
-        self.blockCount[unique_id] = 0
-        self.blocking[unique_id] = {}
+    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
+        node_ids = [node_unique_id]
+        links = []
 
-        inputs = self.dynprompt.get_node(unique_id)["inputs"]
-        for input_name in inputs:
-            value = inputs[input_name]
-            if is_link(value):
-                from_node_id, from_socket = value
-                if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
-                    continue
-                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
-                is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
-                if include_lazy or not is_lazy:
-                    self.add_strong_link(from_node_id, from_socket, unique_id)
+        while len(node_ids) > 0:
+            unique_id = node_ids.pop()
+            if unique_id in self.pendingNodes:
+                continue
+
+            self.pendingNodes[unique_id] = True
+            self.blockCount[unique_id] = 0
+            self.blocking[unique_id] = {}
+
+            inputs = self.dynprompt.get_node(unique_id)["inputs"]
+            for input_name in inputs:
+                value = inputs[input_name]
+                if is_link(value):
+                    from_node_id, from_socket = value
+                    if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
+                        continue
+                    input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+                    is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
+                    if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
+                        node_ids.append(from_node_id)
+                        links.append((from_node_id, from_socket, unique_id))
+
+        for link in links:
+            self.add_strong_link(*link)
+
+    def is_cached(self, node_id):
+        return False
 
     def get_ready_nodes(self):
        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
@@ -146,11 +160,8 @@
         self.output_cache = output_cache
         self.staged_node_id = None
 
-    def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        if self.output_cache.get(from_node_id) is not None:
-            # Nothing to do
-            return
-        super().add_strong_link(from_node_id, from_socket, to_node_id)
+    def is_cached(self, node_id):
+        return self.output_cache.get(node_id) is not None
 
     def stage_node_execution(self):
         assert self.staged_node_id is None