more ao and nn assert removal (#170576)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170576
Approved by: https://github.com/liangel-02
ghstack dependencies: #170327, #170328
Author: albanD
Date: 2025-12-16 14:59:11 -05:00
Committed by: PyTorch MergeBot
Parent: 9eb2fc0e3f
Commit: 4e3695b61f

22 changed files with 320 additions and 185 deletions
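Editor's note on the motivation (not part of the original PR description): `assert` statements are compiled out when Python runs with `-O`/`-OO` or `PYTHONOPTIMIZE`, so validation written as asserts silently vanishes in optimized builds, and ruff flags such asserts via rule S101 ("assert detected", from the flake8-bandit set). Replacing each assert with an explicit `raise AssertionError(...)` keeps the check and its message unconditionally, which is presumably why the first hunk below can drop the S101 per-file ignores for the now-clean directories. A minimal, hypothetical sketch of the behavioral difference (function names are illustrative, not from the PR; the message text is taken from the QAT hunks below):

def check_with_assert(qconfig):
    assert qconfig, "qconfig must be provided for QAT module"  # stripped under python -O

def check_with_raise(qconfig):
    if not qconfig:  # runs regardless of optimization flags
        raise AssertionError("qconfig must be provided for QAT module")

# Under `python -O`, only the second call still enforces the check:
try:
    check_with_assert(None)  # no-op with -O; raises AssertionError without it
except AssertionError as e:
    print("assert version raised:", e)

try:
    check_with_raise(None)
except AssertionError as e:
    print("raise version raised:", e)  # always reached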

View File

@@ -370,8 +370,6 @@ keep-runtime-typing = true
 "torch/fx/**" = ["S101"]
 "torch/jit/**" = ["S101"]
 "torch/nn/modules/**" = ["S101"]
-"torch/nn/parallel/**" = ["S101"]
-"torch/nn/utils/**" = ["S101"]
 "torch/onnx/_internal/**" = ["S101"]
 "torch/testing/**" = ["S101"]
 "tools/**" = [

View File

@@ -44,13 +44,15 @@ class ConvReLU1d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, relu):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv1d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__} and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(conv, relu)
@@ -59,13 +61,15 @@ class ConvReLU2d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, relu):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv2d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__} and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(conv, relu)
@@ -74,13 +78,15 @@ class ConvReLU3d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, relu):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv3d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__} and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(conv, relu)
@@ -89,13 +95,15 @@ class LinearReLU(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, linear, relu):
-        assert (
+        if not (
             type_before_parametrizations(linear) == Linear
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(linear)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(linear).__name__} and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(linear, relu)
@@ -104,13 +112,15 @@ class ConvBn1d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, bn):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv1d
             and type_before_parametrizations(bn) == BatchNorm1d
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(bn)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__} and "
+                f"{type_before_parametrizations(bn).__name__}"
+            )
         super().__init__(conv, bn)
@@ -119,13 +129,15 @@ class ConvBn2d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, bn):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv2d
             and type_before_parametrizations(bn) == BatchNorm2d
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(bn)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__} and "
+                f"{type_before_parametrizations(bn).__name__}"
+            )
         super().__init__(conv, bn)
@@ -134,15 +146,17 @@ class ConvBnReLU1d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, bn, relu):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv1d
             and type_before_parametrizations(bn) == BatchNorm1d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(bn)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__}, "
+                f"{type_before_parametrizations(bn).__name__}, and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(conv, bn, relu)
@@ -151,15 +165,17 @@ class ConvBnReLU2d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, bn, relu):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv2d
             and type_before_parametrizations(bn) == BatchNorm2d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(bn)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__}, "
+                f"{type_before_parametrizations(bn).__name__}, and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(conv, bn, relu)
@@ -168,13 +184,15 @@ class ConvBn3d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, bn):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv3d
             and type_before_parametrizations(bn) == BatchNorm3d
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(bn)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__} and "
+                f"{type_before_parametrizations(bn).__name__}"
+            )
         super().__init__(conv, bn)
@@ -183,15 +201,17 @@ class ConvBnReLU3d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, conv, bn, relu):
-        assert (
+        if not (
             type_before_parametrizations(conv) == Conv3d
             and type_before_parametrizations(bn) == BatchNorm3d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(conv)}"
-            f"{type_before_parametrizations(bn)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(conv).__name__}, "
+                f"{type_before_parametrizations(bn).__name__}, and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(conv, bn, relu)
@@ -200,13 +220,15 @@ class BNReLU2d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, batch_norm, relu):
-        assert (
+        if not (
             type_before_parametrizations(batch_norm) == BatchNorm2d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(batch_norm).__name__} and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(batch_norm, relu)
@@ -215,13 +237,15 @@ class BNReLU3d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, batch_norm, relu):
-        assert (
+        if not (
             type_before_parametrizations(batch_norm) == BatchNorm3d
             and type_before_parametrizations(relu) == ReLU
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}"
-            f"{type_before_parametrizations(relu)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(batch_norm).__name__} and "
+                f"{type_before_parametrizations(relu).__name__}"
+            )
         super().__init__(batch_norm, relu)
@@ -230,13 +254,15 @@ class LinearBn1d(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, linear, bn):
-        assert (
+        if not (
             type_before_parametrizations(linear) == Linear
             and type_before_parametrizations(bn) == BatchNorm1d
-        ), (
-            f"Incorrect types for input modules{type_before_parametrizations(linear)}"
-            f"{type_before_parametrizations(bn)}"
-        )
+        ):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type_before_parametrizations(linear).__name__} and "
+                f"{type_before_parametrizations(bn).__name__}"
+            )
         super().__init__(linear, bn)
@@ -245,9 +271,11 @@ class LinearLeakyReLU(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
     def __init__(self, linear, leaky_relu):
-        assert type(linear) is Linear and type(leaky_relu) is torch.nn.LeakyReLU, (
-            f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}"
-        )
+        if not (type(linear) is Linear and type(leaky_relu) is torch.nn.LeakyReLU):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type(linear).__name__} and {type(leaky_relu).__name__}"
+            )
         super().__init__(linear, leaky_relu)
@@ -256,9 +284,11 @@ class LinearTanh(_FusedModule):
     During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, tanh):
-        assert type(linear) is Linear and type(tanh) is torch.nn.Tanh, (
-            f"Incorrect types for input modules{type(linear)}{type(tanh)}"
-        )
+        if not (type(linear) is Linear and type(tanh) is torch.nn.Tanh):
+            raise AssertionError(
+                f"Incorrect types for input modules: "
+                f"{type(linear).__name__} and {type(tanh).__name__}"
+            )
         super().__init__(linear, tanh)
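Illustrative usage (editor's sketch, not part of the diff): with the checks above, constructing a fused wrapper from mismatched submodules still raises an AssertionError, now with a message that names the offending classes, assuming `LinearReLU` is re-exported from `torch.ao.nn.intrinsic` as in current PyTorch:

import torch.nn as nn
from torch.ao.nn.intrinsic import LinearReLU  # assumed public export of the patched class

try:
    LinearReLU(nn.Linear(4, 4), nn.Sigmoid())  # second module is deliberately the wrong type
except AssertionError as e:
    print(e)  # expected, per the new check: "Incorrect types for input modules: Linear and Sigmoid"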

View File

@@ -47,7 +47,8 @@ class _ConvNd(nn.modules.conv._ConvNd):
             padding_mode,
             **factory_kwargs,
         )
-        assert qconfig, "qconfig must be provided for QAT module"
+        if not qconfig:
+            raise AssertionError("qconfig must be provided for QAT module")
         self.qconfig = qconfig
         self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
@@ -62,14 +63,15 @@ class _ConvNd(nn.modules.conv._ConvNd):
         `mod`: a float module, either produced by torch.ao.quantization utilities
         or directly from user
         """
-        assert type(mod) is cls._FLOAT_MODULE, (
-            "qat."
-            + cls.__name__
-            + ".from_float only works for "
-            + cls._FLOAT_MODULE.__name__
-        )
-        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-        assert mod.qconfig, "Input float module must have a valid qconfig"
+        if type(mod) is not cls._FLOAT_MODULE:
+            raise AssertionError(
+                f"qat.{cls.__name__}.from_float only works for "
+                f"{cls._FLOAT_MODULE.__name__}, got {type(mod).__name__}"
+            )
+        if not hasattr(mod, "qconfig"):
+            raise AssertionError("Input float module must have qconfig defined")
+        if not mod.qconfig:
+            raise AssertionError("Input float module must have a valid qconfig")
         if issubclass(type(mod), _FusedModule):
             mod = mod[0]
         qconfig = mod.qconfig
@@ -111,7 +113,10 @@ class _ConvNd(nn.modules.conv._ConvNd):
         # conv relu
         if issubclass(cls, _FusedModule):
             modules = [conv]
-            assert hasattr(cls, "_FLOAT_RELU_MODULE")
+            if not hasattr(cls, "_FLOAT_RELU_MODULE"):
+                raise AssertionError(
+                    f"{cls.__name__} must have _FLOAT_RELU_MODULE attribute"
+                )
             relu = cls._FLOAT_RELU_MODULE()
             modules.append(relu)
             # pyrefly: ignore [missing-attribute]

View File

@@ -42,7 +42,8 @@ class Linear(nn.Linear):
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__(in_features, out_features, bias, **factory_kwargs)
-        assert qconfig, "qconfig must be provided for QAT module"
+        if not qconfig:
+            raise AssertionError("qconfig must be provided for QAT module")
         self.qconfig = qconfig
         self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
@@ -55,14 +56,15 @@ class Linear(nn.Linear):
         Args: `mod` a float module, either produced by torch.ao.quantization utilities
         or directly from user
         """
-        assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, (
-            " qat."
-            + cls.__name__
-            + ".from_float only works for "
-            + cls._FLOAT_MODULE.__name__
-        )
-        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-        assert mod.qconfig, "Input float module must have a valid qconfig"
+        if type_before_parametrizations(mod) != cls._FLOAT_MODULE:
+            raise AssertionError(
+                f"qat.{cls.__name__}.from_float only works for "
+                f"{cls._FLOAT_MODULE.__name__}, got {type_before_parametrizations(mod).__name__}"
+            )
+        if not hasattr(mod, "qconfig"):
+            raise AssertionError("Input float module must have qconfig defined")
+        if not mod.qconfig:
+            raise AssertionError("Input float module must have a valid qconfig")
         if type_before_parametrizations(mod) == LinearReLU:
             mod = mod[0]

View File

@@ -244,7 +244,8 @@ class Conv3d(nnq.Conv3d):
             f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended",  # noqa: B950
             stacklevel=2,
         )
-        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
+        if padding_mode == "reflect":
+            raise AssertionError("Conv3d does not support reflection padding")
         factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _triple(kernel_size)
         stride = _triple(stride)

View File

@@ -576,7 +576,8 @@ def leaky_relu(
     See :class:`~torch.nn.LeakyReLU` for more details.
     """
     if scale is not None and zero_point is not None:
-        assert not inplace, "Cannot rescale with `inplace`"
+        if inplace:
+            raise AssertionError("Cannot rescale with `inplace`")
         output = torch._empty_affine_quantized(
             input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype
         )

View File

@@ -126,7 +126,10 @@ class Quantize(torch.nn.Module):
     @staticmethod
     def from_float(mod, use_precomputed_fake_quant=False):
-        assert hasattr(mod, "activation_post_process")
+        if not hasattr(mod, "activation_post_process"):
+            raise AssertionError(
+                f"Module {type(mod).__name__} must have activation_post_process attribute"
+            )
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return Quantize(
             scale.float().item(),

View File

@@ -247,9 +247,10 @@ class _ConvNd(WeightedQuantizedModule):
         if weight_post_process is None:
             weight_post_process = mod.qconfig.weight()
         weight_post_process(mod.weight)
-        assert weight_post_process.dtype == torch.qint8, (
-            "Weight observer must have a dtype of qint8"
-        )
+        if weight_post_process.dtype != torch.qint8:
+            raise AssertionError(
+                f"Weight observer must have a dtype of qint8, got {weight_post_process.dtype}"
+            )
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
         # the __init__ call used is the one from derived classes and not the one from _ConvNd
         qconv = cls(
@@ -290,23 +291,18 @@ class _ConvNd(WeightedQuantizedModule):
                     mod.bn.weight,
                    mod.bn.bias,
                )
-            assert hasattr(mod, "activation_post_process"), (
-                "Input QAT module must have observer attached"
-            )
+            if not hasattr(mod, "activation_post_process"):
+                raise AssertionError("Input QAT module must have observer attached")
             weight_post_process = mod.weight_fake_quant
             activation_post_process = mod.activation_post_process
         else:
-            assert type(mod) is cls._FLOAT_MODULE, (
-                " nnq."
-                + cls.__name__
-                + ".from_float only works for "
-                + cls._FLOAT_MODULE.__name__
-                + " but got:"
-                + str(type(mod))
-            )
-            assert hasattr(mod, "qconfig"), (
-                "Input float module must have qconfig defined."
-            )
+            if type(mod) is not cls._FLOAT_MODULE:
+                raise AssertionError(
+                    f"nnq.{cls.__name__}.from_float only works for "
+                    f"{cls._FLOAT_MODULE.__name__} but got: {type(mod).__name__}"
+                )
+            if not hasattr(mod, "qconfig"):
+                raise AssertionError("Input float module must have qconfig defined.")
             activation_post_process = (
                 None
                 if not hasattr(mod, "activation_post_process")
@@ -669,7 +665,8 @@ class Conv3d(_ConvNd):
         device=None,
         dtype=None,
     ):
-        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
+        if padding_mode == "reflect":
+            raise AssertionError("Conv3d does not support reflection padding")
         factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _triple(kernel_size)
         stride = _triple(stride)

View File

@@ -288,9 +288,11 @@ class QFunctional(torch.nn.Module):
     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):
-        assert type(mod) is FloatFunctional, (
-            "QFunctional.from_float expects an instance of FloatFunctional"
-        )
+        if type(mod) is not FloatFunctional:
+            raise AssertionError(
+                f"QFunctional.from_float expects an instance of FloatFunctional, "
+                f"got {type(mod).__name__}"
+            )
         scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]
         new_mod = QFunctional()
         new_mod.scale = float(scale)

View File

@@ -26,7 +26,11 @@ class LinearBlockSparsePattern:
     prev_col_block_size: int = 4
     def __init__(self, row_block_size: int = 1, col_block_size: int = 4):
-        assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
+        if not _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
+            raise AssertionError(
+                f"Invalid linear block sparse pattern: "
+                f"row_block_size={row_block_size}, col_block_size={col_block_size}"
+            )
         LinearBlockSparsePattern.rlock.acquire()
         LinearBlockSparsePattern.prev_row_block_size = (
             LinearBlockSparsePattern.row_block_size

View File

@@ -10,9 +10,8 @@ from torch.nn.parallel import comm
 class Broadcast(Function):
     @staticmethod
     def forward(ctx, target_gpus, *inputs):
-        assert all(i.device.type != "cpu" for i in inputs), (
-            "Broadcast function not implemented for CPU tensors"
-        )
+        if not all(i.device.type != "cpu" for i in inputs):
+            raise AssertionError("Broadcast function not implemented for CPU tensors")
         target_gpus = [_get_device_index(x, True) for x in target_gpus]
         ctx.target_gpus = target_gpus
         if len(inputs) == 0:
@@ -55,9 +54,8 @@ class ReduceAddCoalesced(Function):
 class Gather(Function):
     @staticmethod
     def forward(ctx, target_device, dim, *inputs):
-        assert all(i.device.type != "cpu" for i in inputs), (
-            "Gather function not implemented for CPU tensors"
-        )
+        if not all(i.device.type != "cpu" for i in inputs):
+            raise AssertionError("Gather function not implemented for CPU tensors")
         if target_device == "cpu":
             ctx.target_device = "cpu"
         else:
@@ -123,7 +121,11 @@ def _get_stream(device: torch.device):
     global _streams
     if device.type == "cpu" or not torch.accelerator.is_available():
         return None
-    assert torch.accelerator.current_accelerator().type == device.type
+    if torch.accelerator.current_accelerator().type != device.type:
+        raise AssertionError(
+            f"Expected current accelerator type {torch.accelerator.current_accelerator().type} "
+            f"to match device type {device.type}"
+        )
     if _streams is None:
         _streams = [None] * torch.accelerator.device_count()
     if _streams[device.index] is None:

View File

@@ -86,7 +86,10 @@ def reduce_add(inputs, destination=None):
     input_size = inputs[0].size()
     root_index = None  # index of input tensor that already is on the correct device
     for i, inp in enumerate(inputs):
-        assert inp.device.type != "cpu", "reduce_add expects all inputs to be on GPUs"
+        if inp.device.type == "cpu":
+            raise AssertionError(
+                f"reduce_add expects all inputs to be on GPUs, but input {i} is on CPU"
+            )
         if inp.get_device() == destination:
             root_index = i
         if inp.size() != input_size:

View File

@@ -279,7 +279,8 @@ def data_parallel(
         inputs = ((),)
         module_kwargs = ({},)
-    assert module_kwargs is not None
+    if module_kwargs is None:
+        raise AssertionError("module_kwargs should not be None after scatter_kwargs")
     if len(device_ids) == 1:
         return module(*inputs[0], **module_kwargs[0])

View File

@@ -276,11 +276,13 @@ class _DDPSink(Function):
 class _DDPJoinHook(JoinHook):
     def __init__(self, ddp, divide_by_initial_world_size):
         """Set config variables for internal usage."""
-        assert isinstance(ddp, DistributedDataParallel), (
-            "DDP join hook requires passing in a DistributedDataParallel "
-            "instance as the state"
-        )
-        assert ddp.logger is not None
+        if not isinstance(ddp, DistributedDataParallel):
+            raise AssertionError(
+                "DDP join hook requires passing in a DistributedDataParallel "
+                f"instance as the state, got {type(ddp).__name__}"
+            )
+        if ddp.logger is None:
+            raise AssertionError("ddp.logger must not be None")
         ddp.logger._set_uneven_input_join()
         self.ddp = ddp
         self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
@@ -1300,7 +1302,10 @@ class DistributedDataParallel(Module, Joinable):
         )
         if self.static_graph:
             self.reducer._set_static_graph()
-            assert self.logger is not None
+            if self.logger is None:
+                raise AssertionError(
+                    "self.logger must not be None when static_graph is True"
+                )
             self.logger._set_static_graph()
     def _build_params_for_reducer(self):
@@ -1529,7 +1534,8 @@ class DistributedDataParallel(Module, Joinable):
             return inputs, kwargs
         if torch.is_grad_enabled() and self.require_backward_grad_sync:
-            assert self.logger is not None
+            if self.logger is None:
+                raise AssertionError("self.logger must not be None")
             self.logger.set_runtime_stats_and_log()
             self.reducer.prepare_for_forward()
@@ -1943,7 +1949,8 @@ class DistributedDataParallel(Module, Joinable):
                ensure appropriate synchronization when manipulating GPU
                buffers in the forward pass.
         """
-        assert callable(hook)
+        if not callable(hook):
+            raise AssertionError(f"hook must be callable, got {type(hook).__name__}")
         self.buffer_hook = _BufferCommHook(
             buffer_comm_hook=hook,
             buffer_comm_hook_state=state,
@@ -2026,7 +2033,8 @@ class DistributedDataParallel(Module, Joinable):
            >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
         """
         self._check_comm_hook(hook)
-        assert self.logger is not None
+        if self.logger is None:
+            raise AssertionError("self.logger must not be None")
         self.logger._set_comm_hook_name(hook.__qualname__)
         self._comm_hooks.append((hook, state))
         dist._register_comm_hook(self.reducer, state, hook)
@@ -2054,7 +2062,8 @@ class DistributedDataParallel(Module, Joinable):
            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
         """
-        assert self.logger is not None
+        if self.logger is None:
+            raise AssertionError("self.logger must not be None")
         self.logger._set_comm_hook_name(str(comm_hook_type))
         dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
@@ -2331,7 +2340,8 @@ class DistributedDataParallel(Module, Joinable):
         these metrics are.
         This is a prototype interface and subject to change in the future.
         """
-        assert self.logger is not None
+        if self.logger is None:
+            raise AssertionError("self.logger must not be None")
         ddp_logging_data = self.logger._get_ddp_logging_data()
         return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
@@ -2372,7 +2382,8 @@ class DistributedDataParallel(Module, Joinable):
         self.static_graph = True
         self._static_graph_delay_allreduce_enqueued = False
         self.reducer._set_static_graph()
-        assert self.logger is not None
+        if self.logger is None:
+            raise AssertionError("self.logger must not be None")
         self.logger._set_static_graph()
         if self.find_unused_parameters:
             warnings.warn(

View File

@@ -46,20 +46,31 @@ def parallel_apply(
     element of :attr:`inputs` can either be a single object as the only argument
     to a module, or a collection of positional arguments.
     """
-    assert len(modules) == len(inputs), (
-        f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
-    )
+    if len(modules) != len(inputs):
+        raise AssertionError(
+            f"The number of modules {len(modules)} is not equal to "
+            f"the number of inputs {len(inputs)}"
+        )
     if kwargs_tup is not None:
-        assert len(modules) == len(kwargs_tup)
+        if len(modules) != len(kwargs_tup):
+            raise AssertionError(
+                f"The number of modules {len(modules)} is not equal to "
+                f"the number of kwargs_tup {len(kwargs_tup)}"
+            )
     else:
         kwargs_tup = (cast(dict[str, Any], {}),) * len(modules)
     if devices is not None:
-        assert len(modules) == len(devices)
+        if len(modules) != len(devices):
+            raise AssertionError(
+                f"The number of modules {len(modules)} is not equal to "
+                f"the number of devices {len(devices)}"
+            )
     else:
         devices = [None] * len(modules)
     devices = [_get_device_index(x, True) for x in devices]
     streams = [torch.accelerator.current_stream(x) for x in devices]
-    assert torch.accelerator.is_available(), "No available accelerator found."
+    if not torch.accelerator.is_available():
+        raise AssertionError("No available accelerator found.")
     device_type = torch.accelerator.current_accelerator().type  # type: ignore[union-attr]
     lock = threading.Lock()
     results = {}

View File

@@ -18,7 +18,10 @@ def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
     if func is F.conv2d:
         return conv2dOpt
     else:
-        assert func is F.conv3d
+        if func is not F.conv3d:
+            raise AssertionError(
+                f"Expected func to be F.conv1d, F.conv2d, or F.conv3d, got {func}"
+            )
         return conv3dOpt

View File

@@ -248,7 +248,11 @@ class NamedMemberAccessor:
             names = list(names)
         if not isinstance(values, (list, tuple)):
             values = list(values)
-        assert len(names) == len(values), "names and values must have the same length"
+        if len(names) != len(values):
+            raise AssertionError(
+                f"names and values must have the same length, "
+                f"got {len(names)} names and {len(values)} values"
+            )
         for name, value in zip(names, values, strict=True):
             self.set_tensor(name, value)
@@ -294,7 +298,11 @@ class NamedMemberAccessor:
             names = list(names)
         if not isinstance(values, (list, tuple)):
             values = list(values)
-        assert len(names) == len(values), "names and values must have the same length"
+        if len(names) != len(values):
+            raise AssertionError(
+                f"names and values must have the same length, "
+                f"got {len(names)} names and {len(values)} values"
+            )
         return [
             self.swap_tensor(name, value, allow_missing=allow_missing)

View File

@@ -35,10 +35,12 @@ def fuse_conv_bn_eval(
     .. note::
         Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
     """
-    assert not (conv.training or bn.training), "Fusion only for eval!"
+    if conv.training or bn.training:
+        raise AssertionError("Fusion only for eval!")
     fused_conv = copy.deepcopy(conv)
-    assert bn.running_mean is not None and bn.running_var is not None
+    if bn.running_mean is None or bn.running_var is None:
+        raise AssertionError("bn.running_mean and bn.running_var must not be None")
     fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
         fused_conv.weight,
         fused_conv.bias,
@@ -122,7 +124,8 @@ def fuse_linear_bn_eval(
     .. note::
         Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
     """
-    assert not (linear.training or bn.training), "Fusion only for eval!"
+    if linear.training or bn.training:
+        raise AssertionError("Fusion only for eval!")
     fused_linear = copy.deepcopy(linear)
     """
@@ -135,11 +138,14 @@ def fuse_linear_bn_eval(
     2. the number of features in bn is 1
     Otherwise, skip the folding path
     """
-    assert linear.out_features == bn.num_features or bn.num_features == 1, (
-        "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
-    )
+    if linear.out_features != bn.num_features and bn.num_features != 1:
+        raise AssertionError(
+            f"To fuse, linear.out_features == bn.num_features or bn.num_features == 1, "
+            f"got linear.out_features={linear.out_features} and bn.num_features={bn.num_features}"
+        )
-    assert bn.running_mean is not None and bn.running_var is not None
+    if bn.running_mean is None or bn.running_var is None:
+        raise AssertionError("bn.running_mean and bn.running_var must not be None")
     fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
         fused_linear.weight,
         fused_linear.bias,

View File

@@ -440,7 +440,10 @@ class _SpectralNorm(Module):
     def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
         # Precondition
-        assert weight.ndim > 1
+        if weight.ndim <= 1:
+            raise AssertionError(
+                f"Expected weight to have more than 1 dimension, got {weight.ndim}"
+            )
         if self.dim != 0:
             # permute dim to front
@@ -484,7 +487,10 @@ class _SpectralNorm(Module):
         # (i.e., the `u` and `v` vectors) are changed in the second forward.
         # Precondition
-        assert weight_mat.ndim > 1
+        if weight_mat.ndim <= 1:
+            raise AssertionError(
+                f"Expected weight_mat to have more than 1 dimension, got {weight_mat.ndim}"
+            )
         for _ in range(n_power_iterations):
             # Spectral norm of weight equals to `u^T W v`, where `u` and `v`

View File

@@ -384,7 +384,8 @@ def _inject_property(module: Module, tensor_name: str) -> None:
     """
     # We check the precondition.
     # This should never fire if register_parametrization is correctly implemented
-    assert not hasattr(module, tensor_name)
+    if hasattr(module, tensor_name):
+        raise AssertionError(f"Module already has an attribute named '{tensor_name}'")
     @torch.jit.unused
     def get_cached_parametrization(parametrization) -> Tensor:
@@ -609,7 +610,11 @@ def register_parametrization(
         # else right_inverse is assumed to be the identity
         # add the new parametrization to the parametrization list
-        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
+        if not isinstance(module.parametrizations, ModuleDict):
+            raise AssertionError(
+                f"Expected module.parametrizations to be a ModuleDict, "
+                f"got {type(module.parametrizations).__name__}"
+            )
         module.parametrizations[tensor_name].append(parametrization)  # type: ignore[operator]
         # If unsafe was True in previous parametrization, keep it enabled
         module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr, operator]
@@ -633,7 +638,11 @@ def register_parametrization(
         # Add a property into the class
         _inject_property(module, tensor_name)
         # Add a ParametrizationList
-        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
+        if not isinstance(module.parametrizations, ModuleDict):
+            raise AssertionError(
+                f"Expected module.parametrizations to be a ModuleDict, "
+                f"got {type(module.parametrizations).__name__}"
+            )
         module.parametrizations[tensor_name] = parametrizations
     else:
         raise ValueError(
@@ -698,12 +707,20 @@ def remove_parametrizations(
         )
     # Fetch the original tensor
-    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
+    if not isinstance(module.parametrizations, ModuleDict):
+        raise AssertionError(
+            f"Expected module.parametrizations to be a ModuleDict, "
+            f"got {type(module.parametrizations).__name__}"
+        )
     parametrizations = module.parametrizations[tensor_name]
     # pyrefly: ignore [invalid-argument]
     if parametrizations.is_tensor:
         original = parametrizations.original
-        assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
+        if not isinstance(original, torch.Tensor):
+            raise AssertionError(
+                f"Expected original to be a Tensor (is_tensor promised us a Tensor), "
+                f"got {type(original).__name__}"
+            )
         if leave_parametrized:
             with torch.no_grad():
                 t = getattr(module, tensor_name)
@@ -792,14 +809,22 @@ def transfer_parametrizations_and_params(
         Module: to_module
     """
     if is_parametrized(from_module):
-        assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy
+        if not isinstance(from_module.parametrizations, ModuleDict):
+            raise AssertionError(
+                f"Expected from_module.parametrizations to be a ModuleDict, "
+                f"got {type(from_module.parametrizations).__name__}"
+            )
         # get list of all params or the single param to transfer
         parameters_to_transfer: list | ModuleDict = (
             from_module.parametrizations if tensor_name is None else [tensor_name]
        )
-        assert hasattr(parameters_to_transfer, "__iter__")  # for mypy
+        if not hasattr(parameters_to_transfer, "__iter__"):
+            raise AssertionError(
+                f"Expected parameters_to_transfer to be iterable, "
+                f"got {type(parameters_to_transfer).__name__}"
+            )
         for parameter_name in parameters_to_transfer:
             # initialize the to-be-transferred param in to_module if it doesn't exist already
             if not hasattr(to_module, parameter_name):
@@ -814,7 +839,11 @@ def transfer_parametrizations_and_params(
                 parameter_name
             ]:
                 register_parametrization(to_module, parameter_name, param_func)
-            assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy
+            if not isinstance(to_module.parametrizations, ModuleDict):
+                raise AssertionError(
+                    f"Expected to_module.parametrizations to be a ModuleDict, "
+                    f"got {type(to_module.parametrizations).__name__}"
+                )
             # make values match, original values can be stored in either original or
             # original0, original1..., need to check both cases

View File

@@ -64,9 +64,10 @@ class BasePruningMethod(ABC):
         """
         # to carry out the multiplication, the mask needs to have been computed,
         # so the pruning method must know what tensor it's operating on
-        assert self._tensor_name is not None, (
-            f"Module {module} has to be pruned"
-        )  # this gets set in apply()
+        if self._tensor_name is None:
+            raise AssertionError(
+                f"Module {module} has to be pruned"
+            )  # this gets set in apply()
         mask = getattr(module, self._tensor_name + "_mask")
         orig = getattr(module, self._tensor_name + "_orig")
         pruned_tensor = mask.to(dtype=orig.dtype) * orig
@@ -110,10 +111,11 @@ class BasePruningMethod(ABC):
                     old_method = hook
                     hooks_to_remove.append(k)
                     found += 1
-            assert found <= 1, (
-                f"Avoid adding multiple pruning hooks to the\
-                same tensor {name} of module {module}. Use a PruningContainer."
-            )
+            if found > 1:
+                raise AssertionError(
+                    f"Avoid adding multiple pruning hooks to the "
+                    f"same tensor {name} of module {module}. Use a PruningContainer."
+                )
             for k in hooks_to_remove:
                 del module._forward_pre_hooks[k]
@@ -154,9 +156,11 @@ class BasePruningMethod(ABC):
         orig = getattr(module, name)
         if importance_scores is not None:
-            assert importance_scores.shape == orig.shape, (
-                f"importance_scores should have the same shape as parameter {name} of {module}"
-            )
+            if importance_scores.shape != orig.shape:
+                raise AssertionError(
+                    f"importance_scores should have the same shape as parameter "
+                    f"{name} of {module}, got {importance_scores.shape} vs {orig.shape}"
+                )
         else:
             importance_scores = orig
@@ -223,9 +227,11 @@ class BasePruningMethod(ABC):
             pruned version of tensor ``t``.
         """
         if importance_scores is not None:
-            assert importance_scores.shape == t.shape, (
-                "importance_scores should have the same shape as tensor t"
-            )
+            if importance_scores.shape != t.shape:
+                raise AssertionError(
+                    f"importance_scores should have the same shape as tensor t, "
+                    f"got {importance_scores.shape} vs {t.shape}"
+                )
         else:
             importance_scores = t
         default_mask = default_mask if default_mask is not None else torch.ones_like(t)
@@ -242,9 +248,10 @@ class BasePruningMethod(ABC):
            Pruning itself is NOT undone or reversed!
         """
         # before removing pruning from a tensor, it has to have been applied
-        assert self._tensor_name is not None, (
-            f"Module {module} has to be pruned before pruning can be removed"
-        )  # this gets set in apply()
+        if self._tensor_name is None:
+            raise AssertionError(
+                f"Module {module} has to be pruned before pruning can be removed"
+            )  # this gets set in apply()
         # to update module[name] to latest trained weights
         weight = self.apply_mask(module)  # masked weights
@@ -803,7 +810,11 @@ class CustomFromMask(BasePruningMethod):
         self.mask = mask
     def compute_mask(self, t, default_mask):
-        assert default_mask.shape == self.mask.shape
+        if default_mask.shape != self.mask.shape:
+            raise AssertionError(
+                f"default_mask shape {default_mask.shape} must match "
+                f"self.mask shape {self.mask.shape}"
+            )
         mask = default_mask * self.mask.to(dtype=default_mask.dtype)
         return mask

View File

@@ -238,7 +238,8 @@ def _packed_sequence_init_args(
     # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
     else:
-        assert isinstance(data, (list, tuple)) and len(data) == 2
+        if not (isinstance(data, (list, tuple)) and len(data) == 2):
+            raise AssertionError("Expected data to be a list or tuple of length 2")
         return data[0], data[1], sorted_indices, unsorted_indices