mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
Merge 050ed68d0d into b615af1c65
This commit is contained in:
commit
28cec73b42
@ -385,10 +385,11 @@ class ControlLoraOps:
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
def __init__(self, control_weights, global_average_pooling=False, model_options={}):
|
||||||
ControlBase.__init__(self)
|
ControlBase.__init__(self)
|
||||||
self.control_weights = control_weights
|
self.control_weights = control_weights
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
|
self.model_options = model_options
|
||||||
self.extra_conds += ["y"]
|
self.extra_conds += ["y"]
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
@ -426,7 +427,7 @@ class ControlLora(ControlNet):
|
|||||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling, model_options=self.model_options)
|
||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@ -902,13 +903,14 @@ class T2IAdapter(ControlBase):
|
|||||||
return self.control_merge(control_input, control_prev, x_noisy.dtype)
|
return self.control_merge(control_input, control_prev, x_noisy.dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm, device=self.device)
|
||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
def load_t2i_adapter(t2i_data, model_options={}):
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
device = model_options.get("device", None)
|
||||||
|
|
||||||
if 'adapter' in t2i_data:
|
if 'adapter' in t2i_data:
|
||||||
t2i_data = t2i_data['adapter']
|
t2i_data = t2i_data['adapter']
|
||||||
@ -955,4 +957,4 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("t2i unexpected {}".format(unexpected))
|
logging.debug("t2i unexpected {}".format(unexpected))
|
||||||
|
|
||||||
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm, device=device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user