Rename timestep to sigma where it makes sense

This commit is contained in:
Maël Kerbiriou 2023-12-08 15:17:18 +01:00
parent 97015b6b38
commit 10b791ae4a
2 changed files with 14 additions and 15 deletions

View File

@ -54,8 +54,7 @@ class BaseModel(torch.nn.Module):
print("model_type", model_type.name) print("model_type", model_type.name)
print("adm", self.adm_channels) print("adm", self.adm_channels)
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, sigma, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x) xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None: if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1) xc = torch.cat([xc] + [c_concat], dim=1)
@ -70,7 +69,7 @@ class BaseModel(torch.nn.Module):
dtype = torch.float32 dtype = torch.float32
xc = xc.to(dtype) xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float() timestep = self.model_sampling.timestep(sigma).float()
context = context.to(dtype) context = context.to(dtype)
extra_conds = {} extra_conds = {}
for o in kwargs: for o in kwargs:
@ -80,7 +79,7 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra extra_conds[o] = extra
with precision_scope(comfy.model_management.get_autocast_device(xc.device)): with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() model_output = self.diffusion_model(xc, timestep, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x) return self.model_sampling.calculate_denoised(sigma, model_output, x)

View File

@ -11,7 +11,7 @@ import comfy.conds
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns denoised #Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def sampling_function(model, x, sigmas, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
@ -134,7 +134,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return out return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): def calc_cond_uncond_batch(model, cond, uncond, x_in, sigmas, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37 out_count = torch.ones_like(x_in) * 1e-37
@ -146,14 +146,14 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
to_run = [] to_run = []
for x in cond: for x in cond:
p = get_area_and_mult(x, x_in, timestep) p = get_area_and_mult(x, x_in, sigmas)
if p is None: if p is None:
continue continue
to_run += [(p, COND)] to_run += [(p, COND)]
if uncond is not None: if uncond is not None:
for x in uncond: for x in uncond:
p = get_area_and_mult(x, x_in, timestep) p = get_area_and_mult(x, x_in, sigmas)
if p is None: if p is None:
continue continue
@ -199,10 +199,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
batch_chunks = len(cond_or_uncond) batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x) input_x = torch.cat(input_x)
c = cond_cat(c) c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks) sigma_ = torch.cat([sigmas] * batch_chunks)
if control is not None: if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) c['control'] = control.get_control(input_x, sigma_, c, len(cond_or_uncond))
transformer_options = {} transformer_options = {}
if 'transformer_options' in model_options: if 'transformer_options' in model_options:
@ -220,14 +220,14 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
transformer_options["patches"] = patches transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep transformer_options["sigmas"] = sigmas
c['transformer_options'] = transformer_options c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options: if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "sigma": sigma_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else: else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) output = model.apply_model(input_x, sigma_, **c).chunk(batch_chunks)
del input_x del input_x
for o in range(batch_chunks): for o in range(batch_chunks):
@ -249,9 +249,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
if math.isclose(cond_scale, 1.0): if math.isclose(cond_scale, 1.0):
uncond = None uncond = None
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, sigmas, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "input": x, "sigma": sigmas}
return x - model_options["sampler_cfg_function"](args) return x - model_options["sampler_cfg_function"](args)
else: else:
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale