Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit 7907b8d6be
@@ -310,11 +310,13 @@ class ControlLoraOps:
             self.bias = None

         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+                x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
-                return torch.nn.functional.linear(input, weight, bias)
+                x = torch.nn.functional.linear(input, weight, bias)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

     class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(
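The effective weight applied in the branch above is W + up·down, reconstructed on the fly and reshaped back to W's shape. A minimal standalone sketch of that reconstruction (toy 2-D shapes, not taken from ComfyUI; flatten(start_dim=1) only matters for 4-D conv weights and is a no-op here):

import torch

out_f, in_f, rank = 8, 16, 4
W = torch.randn(out_f, in_f)
up = torch.randn(out_f, rank)      # low-rank factors
down = torch.randn(rank, in_f)
W_eff = W + torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(W.shape)
y = torch.nn.functional.linear(torch.randn(2, in_f), W_eff)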
@@ -350,12 +352,13 @@ class ControlLoraOps:


         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
-                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
@@ -1014,6 +1014,16 @@ if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1022,21 +1032,17 @@ def get_offload_stream(device):

     if device in STREAMS:
         ss = STREAMS[device]
-        s = ss[stream_counter]
+        #Sync the oldest stream in the queue with the current
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
         stream_counters[device] = stream_counter
-        return s
+        return ss[stream_counter]
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
             ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     elif is_device_xpu(device):
@@ -1045,18 +1051,14 @@ def get_offload_stream(device):
             ss.append(torch.xpu.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     return None

 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)

 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
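The two hunks above funnel all per-backend stream queries through the new current_stream() helper: a dequeued offload stream first waits on the compute stream, and sync_stream() later makes the compute stream wait on the offload stream before the casted tensors are consumed. A minimal standalone sketch of that handshake, assuming a CUDA device (the helper name h2d_on_side_stream is hypothetical, not ComfyUI API):

import torch

def h2d_on_side_stream(cpu_tensor, device):
    # Side stream waits for work already queued on the compute stream ...
    side = torch.cuda.Stream(device=device)
    side.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(side):
        # ... runs the host-to-device copy asynchronously ...
        gpu_tensor = cpu_tensor.to(device, non_blocking=True)
    # ... and the compute stream waits for the copy before using the result.
    torch.cuda.current_stream(device).wait_stream(side)
    gpu_tensor.record_stream(torch.cuda.current_stream(device))
    return gpu_tensor

The copy only truly overlaps with compute when cpu_tensor lives in pinned memory, which is what the ModelPatcher hunks below arrange via pin_weight_to_device.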
@@ -655,9 +655,11 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
         loading = self._load_list()

         load_completely = []
+        offloaded = []
         loading.sort(reverse=True)
         for x in loading:
             n = x[1]
@@ -674,6 +676,7 @@
             if mem_counter + module_mem >= lowvram_model_memory:
                 lowvram_weight = True
                 lowvram_counter += 1
+                lowvram_mem_counter += module_mem
                 if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                     continue

@@ -699,8 +702,7 @@
                     patch_counter += 1

                 cast_weight = True
-                for param in params:
-                    self.pin_weight_to_device("{}.{}".format(n, param))
+                offloaded.append((module_mem, n, m, params))
             else:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
@@ -741,11 +743,17 @@
         for x in load_completely:
             x[2].to(device_to)

+        for x in offloaded:
+            n = x[1]
+            params = x[3]
+            for param in params:
+                self.pin_weight_to_device("{}.{}".format(n, param))
+
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
         if full_load:
             self.model.to(device_to)
@@ -1283,5 +1291,6 @@ class ModelPatcher:
         self.clear_cached_hook_weights()

     def __del__(self):
+        self.unpin_all_weights()
         self.detach(unpatch_all=False)

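The new offloaded list defers pin_weight_to_device until the fully loaded modules have been moved, and __del__ now releases the pins via unpin_all_weights(). Pinning matters because only page-locked host memory can be copied to the GPU asynchronously; a minimal illustration in plain PyTorch (not the ModelPatcher implementation):

import torch

w_cpu = torch.randn(4096, 4096)   # ordinary pageable host memory
w_pinned = w_cpu.pin_memory()     # page-locked copy; the DMA engine can stream it
if torch.cuda.is_available():
    # non_blocking=True only overlaps with compute when the source is pinned
    w_gpu = w_pinned.to("cuda", non_blocking=True)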
comfy/ops.py (127 changed lines)
@@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is a legacy default; if you are a custom node author reading this, please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device

-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,7 +113,24 @@
         weight = f(weight)

     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        #Legacy function signature
+        return weight, bias
+
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+

 class CastWeightBiasOp:
     comfy_cast_weights = False
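Together with the NOTE added above, the intended pattern for custom node authors is: cast with offloadable=True, use the returned tensors, then hand them back with uncast_bias_weight(). A minimal sketch (the class MyLinear is hypothetical; cast_bias_weight and uncast_bias_weight are the comfy.ops functions from this commit):

import torch
import comfy.ops

class MyLinear(torch.nn.Linear, comfy.ops.CastWeightBiasOp):
    def forward_comfy_cast_weights(self, input):
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
        x = torch.nn.functional.linear(input, weight, bias)
        # Hand the casted tensors back after their last use so the offload
        # stream can recycle their memory without racing the compute stream.
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
        return x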
@@ -118,8 +143,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@ class disable_weight_init:
             return super()._conv_forward(input, weight, bias, *args, **kwargs)

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@ class disable_weight_init:

         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@ class disable_weight_init:

         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@ class disable_weight_init:
             output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
+

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -361,7 +410,7 @@ def fp8_linear(self, input):
     input_dtype = input.dtype

     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -372,16 +421,22 @@

         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
+            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+            quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
         else:
             scale_input = scale_input.to(input.device)
+            quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)

         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
         layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
         quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
-        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)

         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

+        uncast_bias_weight(self, w, bias, offload_stream)
+
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
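The new clamp to ±448 pins activations to the representable range of the fp8 E4M3 format before quantization (this assumes dtype here is torch.float8_e4m3fn, whose largest finite value is 448; out-of-range values would otherwise be mangled by the conversion). A quick standalone check:

import torch

print(torch.finfo(torch.float8_e4m3fn).max)           # 448.0
x = torch.tensor([1000.0]).clamp(min=-448, max=448)
print(x.to(torch.float8_e4m3fn).float())              # tensor([448.])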
@@ -404,8 +459,10 @@ class fp8_ops(manual_cast):
             except Exception as e:
                 logging.info("Exception during fp8 op: {}".format(e))

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -433,12 +490,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
                     if out is not None:
                         return out

-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

                 if weight.numel() < input.numel(): #TODO: optimize
-                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
                 else:
-                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                uncast_bias_weight(self, weight, bias, offload_stream)
+                return x

             def convert_weight(self, weight, inplace=False, **kwargs):
                 if inplace:
@@ -577,8 +636,10 @@ class MixedPrecisionOps(disable_weight_init):
             return torch.nn.functional.linear(input, weight, bias)

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, input, *args, **kwargs):
             run_every_op()
@@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)

-        lp_amax = torch.finfo(dtype).max
-        tensor_scaled = tensor.float() / scale
-        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
+        # TODO: uncomment this if it's actually needed because the clamp has a small performance penalty
+        # lp_amax = torch.finfo(dtype).max
+        # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
         qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

         layout_params = {
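The rewritten scaling multiplies by the reciprocal scale in the tensor's own dtype instead of upcasting everything to float32 and dividing, avoiding a full-size fp32 temporary; the range clamp is commented out pending evidence that it is needed. A minimal comparison of the two formulations (standalone, illustrative values):

import torch

t = torch.randn(8, dtype=torch.bfloat16)
scale = torch.tensor(2.0, dtype=torch.float32)

old = t.float() / scale                  # fp32 temporary of the whole tensor
new = t * (1.0 / scale).to(t.dtype)      # stays in bf16 end to end
print(torch.allclose(old.to(t.dtype), new, rtol=1e-2, atol=1e-2))  # True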