mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 22:30:19 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
2710ea1aa2
@ -449,7 +449,9 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
if union_cnet in new_sd:
|
if union_cnet in new_sd:
|
||||||
num_union_modes = new_sd[union_cnet].shape[0]
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.Flux()
|
latent_format = comfy.latent_formats.Flux()
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class MistolineControlnetBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
self.main_model_double = 19
|
self.main_model_double = 19
|
||||||
@ -80,7 +80,12 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
self.latent_input = latent_input
|
self.latent_input = latent_input
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
else:
|
||||||
|
control_latent_channels *= 2 * 2 #patch size
|
||||||
|
|
||||||
|
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
if not self.latent_input:
|
if not self.latent_input:
|
||||||
if self.mistoline:
|
if self.mistoline:
|
||||||
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
|
|||||||
@ -99,30 +99,44 @@ class TopologicalSort:
|
|||||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
self.add_node(from_node_id)
|
if not self.is_cached(from_node_id):
|
||||||
if to_node_id not in self.blocking[from_node_id]:
|
self.add_node(from_node_id)
|
||||||
self.blocking[from_node_id][to_node_id] = {}
|
if to_node_id not in self.blocking[from_node_id]:
|
||||||
self.blockCount[to_node_id] += 1
|
self.blocking[from_node_id][to_node_id] = {}
|
||||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
self.blockCount[to_node_id] += 1
|
||||||
|
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||||
|
|
||||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||||
if unique_id in self.pendingNodes:
|
node_ids = [node_unique_id]
|
||||||
return
|
links = []
|
||||||
self.pendingNodes[unique_id] = True
|
|
||||||
self.blockCount[unique_id] = 0
|
|
||||||
self.blocking[unique_id] = {}
|
|
||||||
|
|
||||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
while len(node_ids) > 0:
|
||||||
for input_name in inputs:
|
unique_id = node_ids.pop()
|
||||||
value = inputs[input_name]
|
if unique_id in self.pendingNodes:
|
||||||
if is_link(value):
|
continue
|
||||||
from_node_id, from_socket = value
|
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
self.pendingNodes[unique_id] = True
|
||||||
continue
|
self.blockCount[unique_id] = 0
|
||||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
self.blocking[unique_id] = {}
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
|
||||||
if include_lazy or not is_lazy:
|
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
for input_name in inputs:
|
||||||
|
value = inputs[input_name]
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
|
continue
|
||||||
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
|
node_ids.append(from_node_id)
|
||||||
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
def is_cached(self, node_id):
|
||||||
|
return False
|
||||||
|
|
||||||
def get_ready_nodes(self):
|
def get_ready_nodes(self):
|
||||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||||
@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
|
|||||||
self.output_cache = output_cache
|
self.output_cache = output_cache
|
||||||
self.staged_node_id = None
|
self.staged_node_id = None
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def is_cached(self, node_id):
|
||||||
if self.output_cache.get(from_node_id) is not None:
|
return self.output_cache.get(node_id) is not None
|
||||||
# Nothing to do
|
|
||||||
return
|
|
||||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
|
||||||
|
|
||||||
def stage_node_execution(self):
|
def stage_node_execution(self):
|
||||||
assert self.staged_node_id is None
|
assert self.staged_node_id is None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user