mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 20:30:25 +08:00
Fix Pyright warnings
This commit is contained in:
parent
9c1e3f7b98
commit
508d286b8f
@ -1,5 +1,6 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence, Mapping
|
from typing import Sequence, Mapping
|
||||||
|
from comfy.graph import DynamicPrompt
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
@ -10,7 +11,7 @@ class CacheKeySet:
|
|||||||
self.keys = {}
|
self.keys = {}
|
||||||
self.subcache_keys = {}
|
self.subcache_keys = {}
|
||||||
|
|
||||||
def add_keys(node_ids):
|
def add_keys(self, node_ids):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
@ -66,7 +67,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
self.is_changed_cache = is_changed_cache
|
self.is_changed_cache = is_changed_cache
|
||||||
self.add_keys(node_ids)
|
self.add_keys(node_ids)
|
||||||
|
|
||||||
def include_node_id_in_input(self):
|
def include_node_id_in_input(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
def add_keys(self, node_ids):
|
||||||
@ -131,8 +132,9 @@ class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
|
|||||||
class BasicCache:
|
class BasicCache:
|
||||||
def __init__(self, key_class):
|
def __init__(self, key_class):
|
||||||
self.key_class = key_class
|
self.key_class = key_class
|
||||||
self.dynprompt = None
|
self.initialized = False
|
||||||
self.cache_key_set = None
|
self.dynprompt: DynamicPrompt
|
||||||
|
self.cache_key_set: CacheKeySet
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.subcaches = {}
|
self.subcaches = {}
|
||||||
|
|
||||||
@ -140,16 +142,17 @@ class BasicCache:
|
|||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||||
self.is_changed_cache = is_changed_cache
|
self.is_changed_cache = is_changed_cache
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
assert self.cache_key_set is not None
|
assert self.initialized
|
||||||
node_ids = self.cache_key_set.all_node_ids()
|
node_ids = self.cache_key_set.all_node_ids()
|
||||||
for subcache in self.subcaches.values():
|
for subcache in self.subcaches.values():
|
||||||
node_ids = node_ids.union(subcache.all_node_ids())
|
node_ids = node_ids.union(subcache.all_node_ids())
|
||||||
return node_ids
|
return node_ids
|
||||||
|
|
||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
assert self.cache_key_set is not None
|
assert self.initialized
|
||||||
preserve_keys = set(self.cache_key_set.get_used_keys())
|
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||||
to_remove = []
|
to_remove = []
|
||||||
@ -167,12 +170,12 @@ class BasicCache:
|
|||||||
del self.subcaches[key]
|
del self.subcaches[key]
|
||||||
|
|
||||||
def _set_immediate(self, node_id, value):
|
def _set_immediate(self, node_id, value):
|
||||||
assert self.cache_key_set is not None
|
assert self.initialized
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
self.cache[cache_key] = value
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
def _get_immediate(self, node_id):
|
def _get_immediate(self, node_id):
|
||||||
if self.cache_key_set is None:
|
if not self.initialized:
|
||||||
return None
|
return None
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
@ -181,7 +184,6 @@ class BasicCache:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _ensure_subcache(self, node_id, children_ids):
|
def _ensure_subcache(self, node_id, children_ids):
|
||||||
assert self.cache_key_set is not None
|
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
subcache = self.subcaches.get(subcache_key, None)
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
if subcache is None:
|
if subcache is None:
|
||||||
@ -191,7 +193,7 @@ class BasicCache:
|
|||||||
return subcache
|
return subcache
|
||||||
|
|
||||||
def _get_subcache(self, node_id):
|
def _get_subcache(self, node_id):
|
||||||
assert self.cache_key_set is not None
|
assert self.initialized
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
if subcache_key in self.subcaches:
|
if subcache_key in self.subcaches:
|
||||||
return self.subcaches[subcache_key]
|
return self.subcaches[subcache_key]
|
||||||
@ -211,6 +213,7 @@ class HierarchicalCache(BasicCache):
|
|||||||
super().__init__(key_class)
|
super().__init__(key_class)
|
||||||
|
|
||||||
def _get_cache_for(self, node_id):
|
def _get_cache_for(self, node_id):
|
||||||
|
assert self.dynprompt is not None
|
||||||
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
||||||
if parent_id is None:
|
if parent_id is None:
|
||||||
return self
|
return self
|
||||||
|
|||||||
10
execution.py
10
execution.py
@ -84,7 +84,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
|
|||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
input_type, input_category, input_info = get_input_info(class_def, x)
|
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||||
if is_link(input_data) and not input_info.get("rawLink", False):
|
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
@ -94,7 +94,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
|
|||||||
continue
|
continue
|
||||||
obj = cached_output[output_index]
|
obj = cached_output[output_index]
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
elif input_category is not None:
|
else:
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
@ -336,8 +336,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
# Check for conflicts
|
# Check for conflicts
|
||||||
for node_id in new_graph.keys():
|
for node_id in new_graph.keys():
|
||||||
if dynprompt.get_node(node_id) is not None:
|
if dynprompt.get_node(node_id) is not None:
|
||||||
raise Exception("Attempt to add duplicate node %s" % node_id)
|
raise Exception("Attempt to add duplicate node %s. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder." % node_id)
|
||||||
break
|
|
||||||
for node_id, node_info in new_graph.items():
|
for node_id, node_info in new_graph.items():
|
||||||
new_node_ids.append(node_id)
|
new_node_ids.append(node_id)
|
||||||
display_id = node_info.get("override_display_id", unique_id)
|
display_id = node_info.get("override_display_id", unique_id)
|
||||||
@ -518,6 +517,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
|
|
||||||
for x in valid_inputs:
|
for x in valid_inputs:
|
||||||
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
||||||
|
assert extra_info is not None
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
if input_category == "required":
|
if input_category == "required":
|
||||||
error = {
|
error = {
|
||||||
@ -698,8 +698,6 @@ def validate_inputs(prompt, item, validated):
|
|||||||
"details": details,
|
"details": details,
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
"input_config": info,
|
|
||||||
"received_value": val,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
errors.append(error)
|
errors.append(error)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user