修复 ControlLora 和 T2IAdapter 中未使用的 model_options 参数

1. ControlLora.__init__: 添加了 self.model_options = model_options 来存储参数
2. ControlLora.copy: 确保复制时传递 model_options 参数
3. T2IAdapter.copy: 确保复制时传递 device 参数
4. load_t2i_adapter: 从 model_options 中提取 device 参数并传递给 T2IAdapter

这些修改确保了 model_options 参数被正确使用和传递,提高了代码的一致性和可配置性。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
huangqihang06 2026-04-02 11:06:53 +08:00
parent 0c63b4f6e3
commit 050ed68d0d

View File

@ -385,10 +385,11 @@ class ControlLoraOps:
return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
def __init__(self, control_weights, global_average_pooling=False, model_options={}):
ControlBase.__init__(self)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.model_options = model_options
self.extra_conds += ["y"]
def pre_run(self, model, percent_to_timestep_function):
@ -426,7 +427,7 @@ class ControlLora(ControlNet):
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling, model_options=self.model_options)
self.copy_to(c)
return c
@ -902,13 +903,14 @@ class T2IAdapter(ControlBase):
return self.control_merge(control_input, control_prev, x_noisy.dtype)
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm, device=self.device)
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
def load_t2i_adapter(t2i_data, model_options={}):
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
device = model_options.get("device", None)
if 'adapter' in t2i_data:
t2i_data = t2i_data['adapter']
@ -955,4 +957,4 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
if len(unexpected) > 0:
logging.debug("t2i unexpected {}".format(unexpected))
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm, device=device)