diff --git a/comfy/samplers.py b/comfy/samplers.py
index 9eee25a92..e059374d3 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -17,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
             area = (x_in.shape[2], x_in.shape[3], 0, 0)
             strength = 1.0
+            if 'timestep_start' in cond[1]:
+                timestep_start = cond[1]['timestep_start']
+                if timestep_in > timestep_start:
+                    return None
+            if 'timestep_end' in cond[1]:
+                timestep_end = cond[1]['timestep_end']
+                if timestep_in < timestep_end:
+                    return None
             if 'area' in cond[1]:
                 area = cond[1]['area']
             if 'strength' in cond[1]:
@@ -428,6 +436,35 @@ def create_cond_with_same_area_if_none(conds, c):
     n = c[1].copy()
     conds += [[smallest[0], n]]
 
+def calculate_start_end_timesteps(model, conds):
+    for t in range(len(conds)):
+        x = conds[t]
+
+        timestep_start = None
+        timestep_end = None
+        if 'start_percent' in x[1]:
+            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
+        if 'end_percent' in x[1]:
+            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
+
+        if (timestep_start is not None) or (timestep_end is not None):
+            n = x[1].copy()
+            if (timestep_start is not None):
+                n['timestep_start'] = timestep_start
+            if (timestep_end is not None):
+                n['timestep_end'] = timestep_end
+            conds[t] = [x[0], n]
+
+def pre_run_control(model, conds):
+    for t in range(len(conds)):
+        x = conds[t]
+
+        timestep_start = None
+        timestep_end = None
+        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
+        if 'control' in x[1]:
+            x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
+
 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
     cond_other = []
@@ -571,13 +608,18 @@ class KSampler:
         resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
         resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
 
+        calculate_start_end_timesteps(self.model_wrap, negative)
+        calculate_start_end_timesteps(self.model_wrap, positive)
+
         #make sure each cond area has an opposite one with the same area
         for c in positive:
             create_cond_with_same_area_if_none(negative, c)
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)
 
-        apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+        pre_run_control(self.model_wrap, negative + positive)
+
+        apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
         apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
         if self.model.is_adm():
diff --git a/comfy/sd.py b/comfy/sd.py
index 1f364dd1f..64ab7cc63 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -673,16 +673,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)
 
-class ControlNet:
-    def __init__(self, control_model, global_average_pooling=False, device=None):
-        self.control_model = control_model
+class ControlBase:
+    def __init__(self, device=None):
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
+        self.timestep_percent_range = (1.0, 0.0)
+        self.timestep_range = None
+
         if device is None:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
+        self.cond_hint_original = cond_hint
+        self.strength = strength
+        self.timestep_percent_range = timestep_percent_range
+        return self
+
+    def pre_run(self, model, percent_to_timestep_function):
+        self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.pre_run(model, percent_to_timestep_function)
+
+    def set_previous_controlnet(self, controlnet):
+        self.previous_controlnet = controlnet
+        return self
+
+    def cleanup(self):
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.cleanup()
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+        self.timestep_range = None
+
+    def get_models(self):
+        out = []
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_models()
+        return out
+
+    def copy_to(self, c):
+        c.cond_hint_original = self.cond_hint_original
+        c.strength = self.strength
+        c.timestep_percent_range = self.timestep_percent_range
+
+class ControlNet(ControlBase):
+    def __init__(self, control_model, global_average_pooling=False, device=None):
+        super().__init__(device)
+        self.control_model = control_model
         self.global_average_pooling = global_average_pooling
 
     def get_control(self, x_noisy, t, cond, batched_number):
@@ -690,6 +731,13 @@ class ControlNet:
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return {}
+
         output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
@@ -737,35 +785,17 @@ class ControlNet:
                 out['input'] = control_prev['input']
         return out
 
-    def set_cond_hint(self, cond_hint, strength=1.0):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        return self
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
-
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-
     def copy(self):
         c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
+        self.copy_to(c)
         return c
 
     def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
+        out = super().get_models()
         out.append(self.control_model)
         return out
 
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
@@ -870,24 +900,25 @@ def load_controlnet(ckpt_path, model=None):
     control = ControlNet(control_model, global_average_pooling=global_average_pooling)
     return control
 
-class T2IAdapter:
+class T2IAdapter(ControlBase):
     def __init__(self, t2i_model, channels_in, device=None):
+        super().__init__(device)
         self.t2i_model = t2i_model
         self.channels_in = channels_in
-        self.strength = 1.0
-        if device is None:
-            device = model_management.get_torch_device()
-        self.device = device
-        self.previous_controlnet = None
         self.control_input = None
-        self.cond_hint_original = None
-        self.cond_hint = None
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return {}
+
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -932,33 +963,11 @@ class T2IAdapter:
             out['output'] = control_prev['output']
         return out
 
-    def set_cond_hint(self, cond_hint, strength=1.0):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        return self
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
-
     def copy(self):
         c = T2IAdapter(self.t2i_model, self.channels_in)
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
+        self.copy_to(c)
        return c
 
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-
-    def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
-        return out
 
 def load_t2i_adapter(t2i_data):
     keys = t2i_data.keys()
diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index f9252ea0b..abd182e6e 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -37,12 +37,23 @@ class ImageUpscaleWithModel:
         device = model_management.get_torch_device()
         upscale_model.to(device)
         in_img = image.movedim(-1,-3).to(device)
+        free_memory = model_management.get_free_memory(device)
+
+        tile = 512
+        overlap = 32
+
+        oom = True
+        while oom:
+            try:
+                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+                pbar = comfy.utils.ProgressBar(steps)
+                s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                oom = False
+            except model_management.OOM_EXCEPTION as e:
+                tile //= 2
+                if tile < 128:
+                    raise e
 
-        tile = 128 + 64
-        overlap = 8
-        steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
-        pbar = comfy.utils.ProgressBar(steps)
-        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
         upscale_model.cpu()
         s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
         return (s,)
diff --git a/cuda_malloc.py b/cuda_malloc.py
index faee91a34..250d827da 100644
--- a/cuda_malloc.py
+++ b/cuda_malloc.py
@@ -37,7 +37,7 @@ def get_gpu_names():
         return set()
 
 def cuda_malloc_supported():
-    blacklist = {"GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"}
+    blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"}
     try:
         names = get_gpu_names()
     except:
diff --git a/nodes.py b/nodes.py
index 01a4a8d83..a2f4347f5 100644
--- a/nodes.py
+++ b/nodes.py
@@ -207,6 +207,28 @@ class ConditioningZeroOut:
             c.append(n)
         return (c, )
 
+class ConditioningSetTimestepRange:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "set_range"
+
+    CATEGORY = "advanced/conditioning"
+
+    def set_range(self, conditioning, start, end):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            d['start_percent'] = 1.0 - start
+            d['end_percent'] = 1.0 - end
+            n = [t[0], d]
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -703,9 +725,58 @@ class ControlNetApply:
             if 'control' in t[1]:
                 c_net.set_previous_controlnet(t[1]['control'])
             n[1]['control'] = c_net
+            n[1]['control_apply_to_uncond'] = True
             c.append(n)
         return (c, )
 
+
+class ControlNetApplyAdvanced:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "image": ("IMAGE", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING","CONDITIONING")
+    RETURN_NAMES = ("positive", "negative")
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning"
+
+    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
+        if strength == 0:
+            return (positive, negative)
+
+        control_hint = image.movedim(-1,1)
+        cnets = {}
+
+        out = []
+        for conditioning in [positive, negative]:
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+
+                prev_cnet = d.get('control', None)
+                if prev_cnet in cnets:
+                    c_net = cnets[prev_cnet]
+                else:
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
+                    c_net.set_previous_controlnet(prev_cnet)
+                    cnets[prev_cnet] = c_net
+
+                d['control'] = c_net
+                d['control_apply_to_uncond'] = False
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        return (out[0], out[1])
+
+
 class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -1550,6 +1621,7 @@ NODE_CLASS_MAPPINGS = {
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,
     "ControlNetApply": ControlNetApply,
+    "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
     "StyleModelLoader": StyleModelLoader,
@@ -1567,6 +1639,7 @@
     "SaveLatent": SaveLatent,
 
     "ConditioningZeroOut": ConditioningZeroOut,
+    "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
     "SavePreviewLatent": SavePreviewLatent,
 }
 
@@ -1596,6 +1669,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ConditioningSetArea": "Conditioning (Set Area)",
     "ConditioningSetMask": "Conditioning (Set Mask)",
     "ControlNetApply": "Apply ControlNet",
+    "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
     "SetLatentNoiseMask": "Set Latent Noise Mask",
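
Note on the timestep-range convention (illustrative only, not part of the patch): the new nodes store 1.0 - start and 1.0 - end because timesteps count down during sampling, and get_control() skips a controlnet whenever the current timestep falls outside (timestep_start, timestep_end). The sketch below is standalone and uses a simple linear percent-to-timestep mapping as a stand-in for the sampler's sigma-based conversion (model.sigma_to_t(model.t_to_sigma(...))), purely to show which steps a controlnet limited to the first half of sampling would be active on.

def percent_to_timestep(percent, max_timestep=999.0):
    # Stand-in for the sigma-based conversion in samplers.py: 1.0 maps to
    # the start of sampling (highest timestep), 0.0 to the end.
    return percent * max_timestep

def control_active(t, start_percent, end_percent):
    # Mirrors the check added to get_control(): the controlnet is skipped
    # when t is above timestep_start or below timestep_end.
    timestep_start = percent_to_timestep(1.0 - start_percent)
    timestep_end = percent_to_timestep(1.0 - end_percent)
    return not (t > timestep_start or t < timestep_end)

if __name__ == "__main__":
    # start_percent=0.0, end_percent=0.5 -> active only during the first
    # half of sampling (high timesteps), skipped afterwards.
    for t in (999.0, 750.0, 500.0, 250.0, 0.0):
        print(t, control_active(t, start_percent=0.0, end_percent=0.5))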