diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ba670b16d..504a085ce 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -385,10 +385,11 @@ class ControlLoraOps: return x class ControlLora(ControlNet): - def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options + def __init__(self, control_weights, global_average_pooling=False, model_options={}): ControlBase.__init__(self) self.control_weights = control_weights self.global_average_pooling = global_average_pooling + self.model_options = model_options self.extra_conds += ["y"] def pre_run(self, model, percent_to_timestep_function): @@ -426,7 +427,7 @@ class ControlLora(ControlNet): comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): - c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) + c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling, model_options=self.model_options) self.copy_to(c) return c @@ -902,13 +903,14 @@ class T2IAdapter(ControlBase): return self.control_merge(control_input, control_prev, x_noisy.dtype) def copy(self): - c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) + c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm, device=self.device) self.copy_to(c) return c -def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options +def load_t2i_adapter(t2i_data, model_options={}): compression_ratio = 8 upscale_algorithm = 'nearest-exact' + device = model_options.get("device", None) if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] @@ -955,4 +957,4 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options if len(unexpected) > 0: logging.debug("t2i unexpected {}".format(unexpected)) - return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm) + return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm, device=device)