Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-16 00:12:33 +08:00
Rename timestep to sigma where it makes sense
commit 10b791ae4a
parent 97015b6b38
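Two files change: comfy/model_base.py and comfy/samplers.py. BaseModel.apply_model now takes the noise level sigma directly instead of a generic t, converting it to a model timestep only at the point where the UNet needs one, and sampling_function receives the value as sigmas and passes it through to calc_cond_uncond_batch. Two externally visible keys move with the rename: the dict handed to a model_function_wrapper now carries "sigma" instead of "timestep", and the redundant "timestep" entry is dropped from the sampler_cfg_function args.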
comfy/model_base.py

@@ -54,8 +54,7 @@ class BaseModel(torch.nn.Module):
         print("model_type", model_type.name)
         print("adm", self.adm_channels)

-    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
-        sigma = t
+    def apply_model(self, x, sigma, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
             xc = torch.cat([xc] + [c_concat], dim=1)
@@ -70,7 +69,7 @@ class BaseModel(torch.nn.Module):
             dtype = torch.float32

         xc = xc.to(dtype)
-        t = self.model_sampling.timestep(t).float()
+        timestep = self.model_sampling.timestep(sigma).float()
         context = context.to(dtype)
         extra_conds = {}
         for o in kwargs:
@@ -80,7 +79,7 @@ class BaseModel(torch.nn.Module):
                 extra_conds[o] = extra

         with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
-            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+            model_output = self.diffusion_model(xc, timestep, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()

         return self.model_sampling.calculate_denoised(sigma, model_output, x)

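After these hunks, sigma stays a noise level everywhere except the one line that feeds the UNet, where model_sampling.timestep(sigma) converts it to a schedule timestep. As a rough illustration of what such a conversion can look like for a discrete schedule (a minimal standalone sketch, not ComfyUI's actual ModelSampling class; the class name and schedule values are made up), a nearest-neighbor lookup in log-sigma space:

import torch

# Toy stand-in for a discrete model_sampling object: timestep(sigma) returns
# the index of the schedule sigma closest to the requested sigma in log space.
class ToyDiscreteSampling:
    def __init__(self, sigmas):
        self.log_sigmas = sigmas.log()  # shape [num_timesteps]

    def timestep(self, sigma):
        # pairwise log-distance between requested sigmas and schedule entries
        dists = sigma.log().reshape(1, -1) - self.log_sigmas.reshape(-1, 1)
        return dists.abs().argmin(dim=0).view(sigma.shape)

sched = ToyDiscreteSampling(torch.linspace(0.03, 14.6, 1000))
print(sched.timestep(torch.tensor([0.5, 7.0])).float())  # indices, cast like the diff's .float()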
comfy/samplers.py

@@ -11,7 +11,7 @@ import comfy.conds

 #The main sampling function shared by all the samplers
 #Returns denoised
-def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+def sampling_function(model, x, sigmas, uncond, cond, cond_scale, model_options={}, seed=None):
     def get_area_and_mult(conds, x_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -134,7 +134,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option

         return out

-    def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
+    def calc_cond_uncond_batch(model, cond, uncond, x_in, sigmas, model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in) * 1e-37

@@ -146,14 +146,14 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option

         to_run = []
         for x in cond:
-            p = get_area_and_mult(x, x_in, timestep)
+            p = get_area_and_mult(x, x_in, sigmas)
             if p is None:
                 continue

             to_run += [(p, COND)]
         if uncond is not None:
             for x in uncond:
-                p = get_area_and_mult(x, x_in, timestep)
+                p = get_area_and_mult(x, x_in, sigmas)
                 if p is None:
                     continue

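Note that get_area_and_mult itself is untouched in this commit: its third parameter is still named timestep_in, so the renamed sigmas value simply flows in under the old parameter name.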
@@ -199,10 +199,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
             batch_chunks = len(cond_or_uncond)
             input_x = torch.cat(input_x)
             c = cond_cat(c)
-            timestep_ = torch.cat([timestep] * batch_chunks)
+            sigma_ = torch.cat([sigmas] * batch_chunks)

             if control is not None:
-                c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+                c['control'] = control.get_control(input_x, sigma_, c, len(cond_or_uncond))

             transformer_options = {}
             if 'transformer_options' in model_options:
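The renamed sigma_ = torch.cat([sigmas] * batch_chunks) line replicates the current step's sigma once per chunk, so cond and uncond entries batched into one model call share a single tensor. A standalone sketch of the shape arithmetic (the values are made up):

import torch

sigmas = torch.tensor([14.6])                # the current step's sigma, batch of 1
batch_chunks = 2                             # e.g. one cond chunk + one uncond chunk
sigma_ = torch.cat([sigmas] * batch_chunks)  # one entry per chunk
print(sigma_)                                # tensor([14.6000, 14.6000])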
@@ -220,14 +220,14 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
                 transformer_options["patches"] = patches

             transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-            transformer_options["sigmas"] = timestep
+            transformer_options["sigmas"] = sigmas

             c['transformer_options'] = transformer_options

             if 'model_function_wrapper' in model_options:
-                output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
+                output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "sigma": sigma_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
             else:
-                output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
+                output = model.apply_model(input_x, sigma_, **c).chunk(batch_chunks)
             del input_x

             for o in range(batch_chunks):
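This hunk renames an externally visible key: a model_function_wrapper now receives "sigma" in its args dict where it previously received "timestep". A hypothetical wrapper updated for the new key (only the dict keys and the model_options mechanism come from the diff above; the wrapper itself is illustrative):

def my_wrapper(apply_model, args):
    input_x = args["input"]
    sigma_ = args["sigma"]   # was args["timestep"] before this commit
    c = args["c"]
    # pass through unchanged; a real wrapper would modify inputs or outputs here
    return apply_model(input_x, sigma_, **c)

model_options = {"model_function_wrapper": my_wrapper}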
@@ -249,9 +249,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
     if math.isclose(cond_scale, 1.0):
         uncond = None

-    cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
+    cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, sigmas, model_options)
     if "sampler_cfg_function" in model_options:
-        args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
+        args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "input": x, "sigma": sigmas}
         return x - model_options["sampler_cfg_function"](args)
     else:
         return uncond + (cond - uncond) * cond_scale
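sampler_cfg_function callbacks are affected the same way: the args dict loses its redundant "timestep" entry, and "sigma" now carries the sigmas value. A sketch written against the new keys (it reproduces plain CFG, since sampling_function returns x minus the callback's result and the "cond"/"uncond" entries are already x minus the respective denoised outputs):

def my_cfg_function(args):
    cond = args["cond"]              # x - cond denoised
    uncond = args["uncond"]          # x - uncond denoised
    cond_scale = args["cond_scale"]
    sigma = args["sigma"]            # noise level, available for sigma-dependent rescaling;
                                     # the "timestep" key is gone after this commit
    return uncond + (cond - uncond) * cond_scale

model_options = {"sampler_cfg_function": my_cfg_function}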