mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
Rename timestep to sigma where it makes sense
This commit is contained in:
parent
97015b6b38
commit
10b791ae4a
@ -54,8 +54,7 @@ class BaseModel(torch.nn.Module):
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
def apply_model(self, x, sigma, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
@ -70,7 +69,7 @@ class BaseModel(torch.nn.Module):
|
||||
dtype = torch.float32
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
context = context.to(dtype)
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
@ -80,7 +79,7 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds[o] = extra
|
||||
|
||||
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
model_output = self.diffusion_model(xc, timestep, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ import comfy.conds
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def sampling_function(model, x, sigmas, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
@ -134,7 +134,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, sigmas, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
@ -146,14 +146,14 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
p = get_area_and_mult(x, x_in, sigmas)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
if uncond is not None:
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
p = get_area_and_mult(x, x_in, sigmas)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
@ -199,10 +199,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
sigma_ = torch.cat([sigmas] * batch_chunks)
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
c['control'] = control.get_control(input_x, sigma_, c, len(cond_or_uncond))
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
@ -220,14 +220,14 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
transformer_options["sigmas"] = sigmas
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "sigma": sigma_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
output = model.apply_model(input_x, sigma_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
@ -249,9 +249,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
if math.isclose(cond_scale, 1.0):
|
||||
uncond = None
|
||||
|
||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, sigmas, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "input": x, "sigma": sigmas}
|
||||
return x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
Loading…
Reference in New Issue
Block a user