diff --git a/benchmarks/functional_autograd_benchmark/README.md b/benchmarks/functional_autograd_benchmark/README.md new file mode 100644 index 00000000000..a5f106fec67 --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/README.md @@ -0,0 +1,48 @@ +# Benchmarking tool for the autograd API + +This folder contains a set of self-contained scripts to benchmark the autograd API with different common models. +It is designed to be run before and after your change, generating a table that you can share on the PR. + +To do so, you can use `functional_autograd_benchmark.py` to run the benchmarks before your change (writing the output to `before.txt`) and after your change (writing the output to `after.txt`). +You can then use `compare.py` to get a markdown table comparing the two runs. + +In general, the default arguments of `functional_autograd_benchmark.py` should be used. You can change them to force a given device or to run even the (very) slow settings. + +### Sample usage + +```bash +# Make sure you compile pytorch in release mode and with the same flags before/after +export DEBUG=0 +# When running on CPU, it might be required to limit the number of cores to avoid oversubscription +export OMP_NUM_THREADS=10 + +# Compile pytorch with the base revision +git checkout master +python setup.py develop + +# Run the benchmark for the base +# This will use the GPU if available. +pushd benchmarks/functional_autograd_benchmark +python functional_autograd_benchmark.py --output before.txt + +# Compile pytorch with your change +popd +git checkout your_feature_branch +python setup.py develop + +# Run the benchmark for the new version +pushd benchmarks/functional_autograd_benchmark +python functional_autograd_benchmark.py --output after.txt + +# Get the markdown table that you can paste in your github PR +python compare.py + +popd + +``` + +### Files in this folder: +- `functional_autograd_benchmark.py` is the main entry point to run the benchmark. +- `compare.py` is the entry point of the comparison script that generates a markdown table. +- `torchaudio_models.py` and `torchvision_models.py` contain code extracted from torchaudio and torchvision so that the models can be run without a specific version of these libraries installed. +- `ppl_models.py`, `vision_models.py` and `audio_text_models.py` contain all the getter functions used for the benchmark.
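For reference, every getter in those files follows the same contract: it takes a `torch.device` and returns a `(forward, params)` pair, where `params` is a tuple of tensors and `forward` recomputes a scalar loss from them without side effects, so the pair can be handed directly to `torch.autograd.functional`. The snippet below is a minimal sketch of that contract; the toy model and the name `get_toy_linear` are made up for illustration and are not part of the benchmark, and it assumes the folder's `utils.extract_weights`/`utils.load_weights` helpers.

```python
import torch
from torch import nn, Tensor

from utils import extract_weights, load_weights, GetterReturnType

def get_toy_linear(device: torch.device) -> GetterReturnType:
    # Hypothetical getter mirroring the structure of the real ones
    # (e.g. get_wav2letter): build a model, pull out its weights, and
    # return a side-effect-free forward plus the parameter tensors.
    model = nn.Linear(10, 1).to(device)
    criterion = nn.MSELoss()
    params, names = extract_weights(model)

    inputs = torch.rand(8, 10, device=device)
    targets = torch.rand(8, 1, device=device)

    def forward(*new_params: Tensor) -> Tensor:
        # Reload the candidate parameters into the model before
        # recomputing the loss.
        load_weights(model, names, new_params)
        return criterion(model(inputs), targets)

    return forward, params
```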
diff --git a/benchmarks/functional_autograd_benchmark/audio_text_models.py b/benchmarks/functional_autograd_benchmark/audio_text_models.py new file mode 100644 index 00000000000..938e677ac38 --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/audio_text_models.py @@ -0,0 +1,122 @@ +import torch +from torch import nn, Tensor + +import torchaudio_models as models + +from utils import extract_weights, load_weights, GetterReturnType + +def get_wav2letter(device: torch.device) -> GetterReturnType: + N = 10 + input_frames = 700 + vocab_size = 28 + model = models.Wav2Letter(num_classes=vocab_size) + criterion = torch.nn.NLLLoss() + model.to(device) + params, names = extract_weights(model) + + inputs = torch.rand([N, 1, input_frames], device=device) + labels = torch.rand(N, 3, device=device).mul(vocab_size).long() + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + out = model(inputs) + + loss = criterion(out, labels) + return loss + + return forward, params + +def get_deepspeech(device: torch.device) -> GetterReturnType: + sample_rate = 16000 + window_size = 0.02 + window = "hamming" + audio_conf = dict(sample_rate=sample_rate, + window_size=window_size, + window=window, + noise_dir=None) + + N = 10 + num_classes = 10 + spectrogram_size = 161 + # Commented are the original sizes in the code + seq_length = 500 # 1343 + target_length = 10 # 50 + labels = torch.rand(num_classes, device=device) + inputs = torch.rand(N, 1, spectrogram_size, seq_length, device=device) + # Sequence length for each input + inputs_sizes = torch.rand(N, device=device).mul(seq_length * 0.1).add(seq_length * 0.8) + targets = torch.rand(N, target_length, device=device) + targets_sizes = torch.full((N,), target_length, dtype=torch.int, device=device) + + model = models.DeepSpeech(rnn_type=nn.LSTM, labels=labels, rnn_hidden_size=1024, nb_layers=5, + audio_conf=audio_conf, bidirectional=True) + model = model.to(device) + criterion = nn.CTCLoss() + params, names = extract_weights(model) + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + out, out_sizes = model(inputs, inputs_sizes) + out = out.transpose(0, 1) # For ctc loss + + loss = criterion(out, targets, out_sizes, targets_sizes) + return loss + + return forward, params + +def get_transformer(device: torch.device) -> GetterReturnType: + # For most SOTA research, you would like to have embed to 720, nhead to 12, bsz to 64, tgt_len/src_len to 128. 
+ N = 64 + seq_length = 128 + ntoken = 50 + model = models.TransformerModel(ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2) + model.to(device) + criterion = nn.NLLLoss() + params, names = extract_weights(model) + + data = torch.rand(N, seq_length + 1, device=device).mul(ntoken).long() + inputs = data.narrow(1, 0, seq_length) + targets = data.narrow(1, 1, seq_length) + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + out = model(inputs) + + loss = criterion(out.reshape(N * seq_length, ntoken), targets.reshape(N * seq_length)) + return loss + + return forward, params + +def get_multiheadattn(device: torch.device) -> GetterReturnType: + # From https://github.com/pytorch/text/blob/master/test/data/test_modules.py#L10 + embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 + # Build torchtext MultiheadAttention module + in_proj = models.InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False)) + + model = models.MultiheadAttentionContainer(nhead, in_proj, + models.ScaledDotProduct(), + torch.nn.Linear(embed_dim, embed_dim, bias=False)) + model.to(device) + params, names = extract_weights(model) + + query = torch.rand((tgt_len, bsz, embed_dim), device=device) + key = value = torch.rand((src_len, bsz, embed_dim), device=device) + attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len), device=device).to(torch.bool) + bias_k = bias_v = torch.rand((1, 1, embed_dim), device=device) + + attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead)) + bias_k = bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1) + bias_v = bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1) + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + mha_output, attn_weights = model(query, key, value, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v) + + # Don't test any specific loss, just backprop ones for both outputs + loss = mha_output.sum() + attn_weights.sum() + + return loss + + return forward, params diff --git a/benchmarks/functional_autograd_benchmark/compare.py b/benchmarks/functional_autograd_benchmark/compare.py new file mode 100644 index 00000000000..c2c4ef6c95d --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/compare.py @@ -0,0 +1,45 @@ +import argparse +from collections import defaultdict + +from utils import to_markdown_table, from_markdown_table + +def main(): + parser = argparse.ArgumentParser("Main script to compare results from the benchmarks") + parser.add_argument("--before", type=str, default="before.txt", help="Text file containing the times to use as base") + parser.add_argument("--after", type=str, default="after.txt", help="Text file containing the times to use as new version") + parser.add_argument("--output", type=str, default="", help="Text file where to write the output") + args = parser.parse_args() + + with open(args.before, "r") as f: + content = f.read() + res_before = from_markdown_table(content) + + with open(args.after, "r") as f: + content = f.read() + res_after = from_markdown_table(content) + + diff = defaultdict(defaultdict) + for model in res_before: + for task in res_before[model]: + mean_before, var_before = res_before[model][task] + if task not in res_after[model]: + diff[model][task] = (None, mean_before, var_before, None, None) + else: + mean_after, var_after = res_after[model][task] + diff[model][task] = (mean_before / mean_after, mean_before, var_before, mean_after, 
var_after) + for model in res_after: + for task in res_after[model]: + if task not in res_before[model]: + mean_after, var_after = res_after[model][task] + diff[model][task] = (None, None, None, mean_after, var_after) + + header = ("model", "task", "speedup", "mean (before)", "var (before)", "mean (after)", "var (after)") + out = to_markdown_table(diff, header=header) + + print(out) + if args.output: + with open(args.output, "w") as f: + f.write(out) + +if __name__ == "__main__": + main() diff --git a/benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py b/benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py new file mode 100644 index 00000000000..3eeda15f1af --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py @@ -0,0 +1,153 @@ +import torch +from torch.autograd import functional + +import time +from argparse import ArgumentParser +from collections import defaultdict +from typing import NamedTuple, Callable, List, Any + +import ppl_models +import vision_models +import audio_text_models + +from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType + +# Listing of the different tasks +FAST_TASKS_NO_DOUBLE_BACK = [ + "vjp", +] + +FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [ + "vhp", + "jvp", +] + +ALL_TASKS = FAST_TASKS + [ + "hvp", + "jacobian", + "hessian" +] + +DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"] + +# Model definition which contains: +# - name: a string with the model name. +# - getter: a function to get the model. It takes as input the device on which the model +# will run. It should return the forward function and the parameters (Tensors) used as +# input for the forward function. Note that the forward must *not* have any side effect. +# - tasks: the list of recommended tasks that can run in a reasonable amount of time with this model. +# - unsupported: the list of tasks that this model cannot run. 
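# As a rough sketch (illustration only, not part of this file), a task name such as
# "vjp" is resolved to the matching entry point in torch.autograd.functional and
# called on what the getter returns, e.g.:
#
#   forward, params = vision_models.get_resnet18(torch.device("cpu"))
#   out = forward(*params)
#   v = torch.rand_like(out)
#   _, grads = functional.vjp(forward, params, v=v, strict=True)
#
# run_once below performs this lookup generically via getattr(functional, task).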
+class ModelDef(NamedTuple): + name: str + getter: GetterType + tasks: List[str] + unsupported: List[str] + +MODELS = [ + ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []), + ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []), + ModelDef("detr", vision_models.get_detr, FAST_TASKS, []), + ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []), + ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []), + ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []), + ModelDef("deepspeech", audio_text_models.get_deepspeech, FAST_TASKS_NO_DOUBLE_BACK, DOUBLE_BACKWARD_TASKS), + ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []), + ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []), +] + +def get_v_for(model: Callable, inp: InputsType, task: str) -> VType: + v: VType + + if task in ["vjp"]: + out = model(*inp) + v = torch.rand_like(out) + elif task in ["jvp", "hvp", "vhp"]: + if isinstance(inp, tuple): + v = tuple(torch.rand_like(i) for i in inp) + else: + v = torch.rand_like(inp) + else: + v = None + + return v + +def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None: + func = getattr(functional, task) + + if v is not None: + res = func(model, inp, v=v, strict=True) + else: + res = func(model, inp, strict=True) + +def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]: + if args.gpu == -1: + device = torch.device("cpu") + + def noop(): + pass + do_sync = noop + else: + device = torch.device("cuda:{}".format(args.gpu)) + do_sync = torch.cuda.synchronize + + model, inp = model_getter(device) + + v = get_v_for(model, inp, task) + # Warmup + run_once(model, inp, task, v) + + elapsed = [] + for it in range(args.num_iters): + do_sync() + start = time.time() + run_once(model, inp, task, v) + do_sync() + elapsed.append(time.time() - start) + + return elapsed + +def main(): + parser = ArgumentParser("Main script to benchmark functional API of the autograd.") + parser.add_argument("--output", type=str, default="", help="Text file where to write the output") + parser.add_argument("--num-iters", type=int, default=10) + parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect") + parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks") + parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter") + parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter") + parser.add_argument("--num-threads", type=int, default=10, + help="Number of concurrent threads to use when running on cpu") + parser.add_argument("--seed", type=int, default=0, help="The random seed to use.") + args = parser.parse_args() + + results: TimingResultType = defaultdict(defaultdict) + torch.set_num_threads(args.num_threads) + torch.set_num_interop_threads(args.num_threads) + + # This automatically seed cuda if it is available + torch.manual_seed(args.seed) + + if args.gpu == -2: + args.gpu = 0 if torch.cuda.is_available() else -1 + + for name, model_getter, recommended_tasks, unsupported_tasks in MODELS: + if args.model_filter and name not in args.model_filter: + continue + tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks + for task in tasks: + if task in unsupported_tasks: + continue + if args.task_filter and task not in args.task_filter: + continue + runtimes = 
run_model(model_getter, args, task) + + runtimes = torch.tensor(runtimes) + mean, var = runtimes.mean(), runtimes.var() + results[name][task] = (mean.item(), var.item()) + print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var)) + + if args.output: + with open(args.output, "w") as f: + f.write(to_markdown_table(results)) + +if __name__ == "__main__": + main() diff --git a/benchmarks/functional_autograd_benchmark/ppl_models.py b/benchmarks/functional_autograd_benchmark/ppl_models.py new file mode 100644 index 00000000000..906ebac5d41 --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/ppl_models.py @@ -0,0 +1,93 @@ +import torch +from torch import Tensor +import torch.distributions as dist + +from utils import GetterReturnType + +def get_simple_regression(device: torch.device) -> GetterReturnType: + N = 10 + K = 10 + + loc_beta = 0. + scale_beta = 1. + + beta_prior = dist.Normal(loc_beta, scale_beta) + + X = torch.rand(N, K + 1, device=device) + Y = torch.rand(N, 1, device=device) + + # X.shape: (N, K + 1), Y.shape: (N, 1), beta_value.shape: (K + 1, 1) + beta_value = beta_prior.sample((K + 1, 1)) + beta_value.requires_grad_(True) + + def forward(beta_value: Tensor) -> Tensor: + mu = X.mm(beta_value) + + # We need to compute the first and second gradient of this score with respect + # to beta_value. + score = dist.Bernoulli(logits=mu).log_prob(Y).sum() + beta_prior.log_prob(beta_value).sum() + return score + + return forward, (beta_value.to(device),) + + +def get_robust_regression(device: torch.device) -> GetterReturnType: + N = 10 + K = 10 + + # X.shape: (N, K + 1), Y.shape: (N, 1) + X = torch.rand(N, K + 1, device=device) + Y = torch.rand(N, 1, device=device) + + # Predefined nu_alpha and nu_beta, nu_alpha.shape: (1, 1), nu_beta.shape: (1, 1) + nu_alpha = torch.randn(1, 1, device=device) + nu_beta = torch.rand(1, 1, device=device) + nu = dist.Gamma(nu_alpha, nu_beta) + + # Predefined sigma_rate: sigma_rate.shape: (N, 1) + sigma_rate = torch.rand(N, 1, device=device) + sigma = dist.Exponential(sigma_rate) + + # Predefined beta_mean and beta_sigma: beta_mean.shape: (K + 1, 1), beta_sigma.shape: (K + 1, 1) + beta_mean = torch.rand(K + 1, 1, device=device) + beta_sigma = torch.rand(K + 1, 1, device=device) + beta = dist.Normal(beta_mean, beta_sigma) + + nu_value = nu.sample() + nu_value.requires_grad_(True) + + sigma_value = sigma.sample() + sigma_unconstrained_value = sigma_value.log() + sigma_unconstrained_value.requires_grad_(True) + + beta_value = beta.sample() + beta_value.requires_grad_(True) + + def forward(nu_value: Tensor, sigma_unconstrained_value: Tensor, beta_value: Tensor) -> Tensor: + sigma_constrained_value = sigma_unconstrained_value.exp() + mu = X.mm(beta_value) + + # For this model, we need to compute the following three scores: + # We need to compute the first and second gradient of this score with respect + # to nu_value. + nu_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \ + + nu.log_prob(nu_value) + + + + # We need to compute the first and second gradient of this score with respect + # to sigma_unconstrained_value. + sigma_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \ + + sigma.log_prob(sigma_constrained_value) \ + + sigma_unconstrained_value + + + + # We need to compute the first and second gradient of this score with respect + # to beta_value. 
+ beta_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \ + + beta.log_prob(beta_value) + + return nu_score.sum() + sigma_score.sum() + beta_score.sum() + + return forward, (nu_value.to(device), sigma_unconstrained_value.to(device), beta_value.to(device)) diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py new file mode 100644 index 00000000000..1e4cc747b0f --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -0,0 +1,556 @@ +# Taken from https://github.com/pytorch/audio/blob/master/torchaudio/models/wav2letter.py +# So that we don't need torchaudio to be installed + +import torch +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +import math +from collections import OrderedDict +from typing import Tuple, Optional + +__all__ = ["Wav2Letter"] + + +class Wav2Letter(nn.Module): + r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System" + `_ paper. + :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}` + Args: + num_classes (int, optional): Number of classes to be classified. (Default: ``40``) + input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum`` + or ``mfcc`` (Default: ``waveform``). + num_features (int, optional): Number of input features that the network will receive (Default: ``1``). + """ + + def __init__(self, num_classes: int = 40, + input_type: str = "waveform", + num_features: int = 1) -> None: + super(Wav2Letter, self).__init__() + + acoustic_num_features = 250 if input_type == "waveform" else num_features + acoustic_model = nn.Sequential( + nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True) + ) + + if input_type == "waveform": + waveform_model = nn.Sequential( + nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45), + nn.ReLU(inplace=True) + ) + self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) + + if input_type in ["power_spectrum", "mfcc"]: + self.acoustic_model = acoustic_model + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x (Tensor): Tensor of dimension (batch_size, num_features, input_length). 
+ Returns: + Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length). + """ + + x = self.acoustic_model(x) + x = nn.functional.log_softmax(x, dim=1) + return x + +# Taken from https://github.com/SeanNaren/deepspeech.pytorch with modifications +class SequenceWise(nn.Module): + def __init__(self, module): + """ + Collapses input of dim T*N*H to (T*N)*H, and applies to a module. + Allows handling of variable sequence lengths and minibatch sizes. + :param module: Module to apply input to. + """ + super(SequenceWise, self).__init__() + self.module = module + + def forward(self, x): + t, n = x.size(0), x.size(1) + x = x.view(t * n, -1) + x = self.module(x) + x = x.view(t, n, -1) + return x + + def __repr__(self): + tmpstr = self.__class__.__name__ + ' (\n' + tmpstr += self.module.__repr__() + tmpstr += ')' + return tmpstr + + +class MaskConv(nn.Module): + def __init__(self, seq_module): + """ + Adds padding to the output of the module based on the given lengths. This is to ensure that the + results of the model do not change when batch sizes change during inference. + Input needs to be in the shape of (BxCxDxT) + :param seq_module: The sequential module containing the conv stack. + """ + super(MaskConv, self).__init__() + self.seq_module = seq_module + + def forward(self, x, lengths): + """ + :param x: The input of size BxCxDxT + :param lengths: The actual length of each sequence in the batch + :return: Masked output from the module + """ + for module in self.seq_module: + x = module(x) + mask = torch.BoolTensor(x.size()).fill_(0) + if x.is_cuda: + mask = mask.cuda() + for i, length in enumerate(lengths): + length = length.item() + if (mask[i].size(2) - length) > 0: + mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1) + x = x.masked_fill(mask, 0) + return x, lengths + + +class InferenceBatchSoftmax(nn.Module): + def forward(self, input_): + if not self.training: + return F.softmax(input_, dim=-1) + else: + return input_ + + +class BatchRNN(nn.Module): + def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True): + super(BatchRNN, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bidirectional = bidirectional + self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None + self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, + bidirectional=bidirectional, bias=True) + self.num_directions = 2 if bidirectional else 1 + + def flatten_parameters(self): + self.rnn.flatten_parameters() + + def forward(self, x, output_lengths): + if self.batch_norm is not None: + x = self.batch_norm(x) + x = nn.utils.rnn.pack_padded_sequence(x, output_lengths, enforce_sorted=False) + x, h = self.rnn(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x) + if self.bidirectional: + x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum + return x + + +class Lookahead(nn.Module): + # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks + # input shape - sequence, batch, feature - TxNxH + # output shape - same as input + def __init__(self, n_features, context): + super(Lookahead, self).__init__() + assert context > 0 + self.context = context + self.n_features = n_features + self.pad = (0, self.context - 1) + self.conv = nn.Conv1d(self.n_features, self.n_features, kernel_size=self.context, stride=1, + groups=self.n_features, padding=0, bias=None) + + def forward(self, x): + x = 
x.transpose(0, 1).transpose(1, 2) + x = F.pad(x, pad=self.pad, value=0) + x = self.conv(x) + x = x.transpose(1, 2).transpose(0, 1).contiguous() + return x + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + 'n_features=' + str(self.n_features) \ + + ', context=' + str(self.context) + ')' + +class DeepSpeech(nn.Module): + def __init__(self, rnn_type, labels, rnn_hidden_size, nb_layers, audio_conf, + bidirectional, context=20): + super(DeepSpeech, self).__init__() + + self.hidden_size = rnn_hidden_size + self.hidden_layers = nb_layers + self.rnn_type = rnn_type + self.audio_conf = audio_conf + self.labels = labels + self.bidirectional = bidirectional + + sample_rate = self.audio_conf["sample_rate"] + window_size = self.audio_conf["window_size"] + num_classes = len(self.labels) + + self.conv = MaskConv(nn.Sequential( + nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)), + nn.BatchNorm2d(32), + nn.Hardtanh(0, 20, inplace=True), + nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)), + nn.BatchNorm2d(32), + nn.Hardtanh(0, 20, inplace=True) + )) + # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1 + rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1) + rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1) + rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1) + rnn_input_size *= 32 + + rnns = [] + rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, + bidirectional=bidirectional, batch_norm=False) + rnns.append(('0', rnn)) + for x in range(nb_layers - 1): + rnn = BatchRNN(input_size=rnn_hidden_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, + bidirectional=bidirectional) + rnns.append(('%d' % (x + 1), rnn)) + self.rnns = nn.Sequential(OrderedDict(rnns)) + self.lookahead = nn.Sequential( + # consider adding batch norm? + Lookahead(rnn_hidden_size, context=context), + nn.Hardtanh(0, 20, inplace=True) + ) if not bidirectional else None + + fully_connected = nn.Sequential( + nn.BatchNorm1d(rnn_hidden_size), + nn.Linear(rnn_hidden_size, num_classes, bias=False) + ) + self.fc = nn.Sequential( + SequenceWise(fully_connected), + ) + self.inference_softmax = InferenceBatchSoftmax() + + def forward(self, x, lengths): + lengths = lengths.cpu().int() + output_lengths = self.get_seq_lens(lengths) + x, _ = self.conv(x, output_lengths) + + sizes = x.size() + x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension + x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH + + for rnn in self.rnns: + x = rnn(x, output_lengths) + + if not self.bidirectional: # no need for lookahead layer in bidirectional + x = self.lookahead(x) + + x = self.fc(x) + x = x.transpose(0, 1) + # identity in training mode, softmax in eval mode + x = self.inference_softmax(x) + return x, output_lengths + + def get_seq_lens(self, input_length): + """ + Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable + containing the size sequences that will be output by the network. 
+ :param input_length: 1D Tensor + :return: 1D Tensor scaled by model + """ + seq_len = input_length + for m in self.conv.modules(): + if type(m) == nn.modules.conv.Conv2d: + seq_len = seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1 + seq_len = seq_len.true_divide(m.stride[1]) + 1 + return seq_len.int() + +# Taken from https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L108-L152 +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). 
+ Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) + +class TransformerModel(nn.Module): + """Container module with an encoder, a recurrent or transformer module, and a decoder.""" + + def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): + super(TransformerModel, self).__init__() + try: + from torch.nn import TransformerEncoder, TransformerEncoderLayer + except Exception: + raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.') + self.model_type = 'Transformer' + self.src_mask = None + self.pos_encoder = PositionalEncoding(ninp, dropout) + encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + self.encoder = nn.Embedding(ntoken, ninp) + self.ninp = ninp + self.decoder = nn.Linear(ninp, ntoken) + + self.init_weights() + + def _generate_square_subsequent_mask(self, sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def init_weights(self): + initrange = 0.1 + nn.init.uniform_(self.encoder.weight, -initrange, initrange) + # Not sure how this works in the original code + # nn.init.zeros_(self.decoder) + nn.init.uniform_(self.decoder.weight, -initrange, initrange) + + def forward(self, src, has_mask=True): + if has_mask: + device = src.device + # This will be created once during warmup + if self.src_mask is None or self.src_mask.size(0) != len(src): + mask = self._generate_square_subsequent_mask(len(src)).to(device) + self.src_mask = mask + else: + self.src_mask = None + + src = self.encoder(src) * math.sqrt(self.ninp) + src = self.pos_encoder(src) + output = self.transformer_encoder(src, self.src_mask) + output = self.decoder(output) + return F.log_softmax(output, dim=-1) + +# From https://github.com/pytorch/text/blob/master/torchtext/modules +class MultiheadAttentionContainer(torch.nn.Module): + def __init__(self, nhead, in_proj_container, attention_layer, out_proj): + r""" A multi-head attention container + Args: + nhead: the number of heads in the multiheadattention model + in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear). + attention_layer: The attention layer. + out_proj: The multi-head out-projection layer (a.k.a nn.Linear). 
+ Examples:: + >>> import torch + >>> embed_dim, num_heads, bsz = 10, 5, 64 + >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)) + >>> MHA = MultiheadAttentionContainer(num_heads, + in_proj_container, + ScaledDotProduct(), + torch.nn.Linear(embed_dim, embed_dim)) + >>> query = torch.rand((21, bsz, embed_dim)) + >>> key = value = torch.rand((16, bsz, embed_dim)) + >>> attn_output, attn_weights = MHA(query, key, value) + >>> print(attn_output.shape) + >>> torch.Size([21, 64, 10]) + """ + super(MultiheadAttentionContainer, self).__init__() + self.nhead = nhead + self.in_proj_container = in_proj_container + self.attention_layer = attention_layer + self.out_proj = out_proj + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + bias_k: Optional[torch.Tensor] = None, + bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + query, key, value (Tensor): map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + attn_mask, bias_k and bias_v (Tensor, optional): keyword arguments passed to the attention layer. + See the definitions in the attention. + Shape: + - Inputs: + - query: :math:`(L, N, E)` + - key: :math:`(S, N, E)` + - value: :math:`(S, N, E)` + - attn_mask, bias_k and bias_v: same with the shape of the corresponding args in attention layer. + - Outputs: + - attn_output: :math:`(L, N, E)` + - attn_output_weights: :math:`(N * H, L, S)` + where where L is the target length, S is the sequence length, H is the number of attention heads, + N is the batch size, and E is the embedding dimension. + """ + tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1) + q, k, v = self.in_proj_container(query, key, value) + assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads" + head_dim = q.size(-1) // self.nhead + q = q.reshape(tgt_len, bsz * self.nhead, head_dim) + + assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads" + head_dim = k.size(-1) // self.nhead + k = k.reshape(src_len, bsz * self.nhead, head_dim) + + assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads" + head_dim = v.size(-1) // self.nhead + v = v.reshape(src_len, bsz * self.nhead, head_dim) + + attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, + bias_k=bias_k, bias_v=bias_v) + attn_output = attn_output.reshape(tgt_len, bsz, embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, attn_output_weights + + +class ScaledDotProduct(torch.nn.Module): + + def __init__(self, dropout=0.0): + r"""Processes a projected query and key-value pair to apply + scaled dot product attention. + Args: + dropout (float): probability of dropping an attention weight. 
+ Examples:: + >>> SDP = torchtext.models.ScaledDotProduct(0.1) + >>> q = torch.randn(256, 21, 3) + >>> k = v = torch.randn(256, 21, 3) + >>> attn_output, attn_weights = SDP(q, k, v) + >>> print(attn_output.shape, attn_weights.shape) + torch.Size([256, 21, 3]) torch.Size([256, 21, 21]) + """ + super(ScaledDotProduct, self).__init__() + self.dropout = dropout + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + bias_k: Optional[torch.Tensor] = None, + bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Uses a scaled dot product with the projected key-value pair to update + the projected query. + Args: + query (Tensor): Projected query + key (Tensor): Projected key + value (Tensor): Projected value + attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. + bias_k and bias_v: (Tensor, optional): one more key and value sequence to be added at + sequence dim (dim=-3). Those are used for incremental decoding. Users should provide + non-None to both arguments in order to activate them. + Shape: + - query: :math:`(L, N * H, E / H)` + - key: :math:`(S, N * H, E / H)` + - value: :math:`(S, N * H, E / H)` + - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend + while ``False`` values will be unchanged. + - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` + - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` + where L is the target length, S is the source length, H is the number + of attention heads, N is the batch size, and E is the embedding dimension. + """ + if bias_k is not None and bias_v is not None: + assert key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2) and bias_k.size(-3) == 1, \ + "Shape of bias_k is not supported" + assert value.size(-1) == bias_v.size(-1) and value.size(-2) == bias_v.size(-2) and bias_v.size(-3) == 1, \ + "Shape of bias_v is not supported" + key = torch.cat([key, bias_k]) + value = torch.cat([value, bias_v]) + if attn_mask is not None: + _attn_mask = attn_mask + attn_mask = torch.nn.functional.pad(_attn_mask, [0, 1]) + + tgt_len, head_dim = query.size(-3), query.size(-1) + assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." 
+ assert key.size() == value.size(), "Shape of key, value must match" + src_len = key.size(-3) + batch_heads = max(query.size(-2), key.size(-2)) + + # Scale query + query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3) + query = query * (float(head_dim) ** -0.5) + if attn_mask is not None: + if attn_mask.dim() != 3: + raise RuntimeError('attn_mask must be a 3D tensor.') + if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \ + (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads): + raise RuntimeError('The size of the attn_mask is not correct.') + if attn_mask.dtype != torch.bool: + raise RuntimeError('Only bool tensor is supported for attn_mask') + + # Dot product of q, k + attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) + if attn_mask is not None: + attn_output_weights.masked_fill_(attn_mask, -1e8,) + attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_output_weights, value) + return attn_output.transpose(-2, -3), attn_output_weights + + +class InProjContainer(torch.nn.Module): + def __init__(self, query_proj, key_proj, value_proj): + r"""A in-proj container to process inputs. + Args: + query_proj: a proj layer for query. + key_proj: a proj layer for key. + value_proj: a proj layer for value. + """ + + super(InProjContainer, self).__init__() + self.query_proj = query_proj + self.key_proj = key_proj + self.value_proj = value_proj + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Projects the input sequences using in-proj layers. + Args: + query, key, value (Tensors): sequence to be projected + Shape: + - query, key, value: :math:`(S, N, E)` + - Output: :math:`(S, N, E)` + where S is the sequence length, N is the batch size, and E is the embedding dimension. 
+ """ + return self.query_proj(query), self.key_proj(key), self.value_proj(value) diff --git a/benchmarks/functional_autograd_benchmark/torchvision_models.py b/benchmarks/functional_autograd_benchmark/torchvision_models.py new file mode 100644 index 00000000000..25361af7766 --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/torchvision_models.py @@ -0,0 +1,803 @@ +# Taken from https://github.com/pytorch/vision +# So that we don't need torchvision to be installed +import torch +from torch import nn +from torch.nn import functional as F + +from torch.jit.annotations import Dict +from collections import OrderedDict + +try: + from scipy.optimize import linear_sum_assignment # type: ignore + scipy_available = True +except Exception: + scipy_available = False + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + # if pretrained: + # state_dict = load_state_dict_from_url(model_urls[arch], + # progress=progress) + # model.load_state_dict(state_dict) + return model + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + Arguments: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). 
+ Examples:: + >>> m = torchvision.models.resnet18(pretrained=True) + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + _version = 2 + __annotations__ = { + "return_layers": Dict[str, str], + } + + def __init__(self, model, return_layers): + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + orig_return_layers = return_layers + return_layers = {str(k): str(v) for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out + +class _SimpleSegmentationModel(nn.Module): + __constants__ = ['aux_classifier'] + + def __init__(self, backbone, classifier, aux_classifier=None): + super(_SimpleSegmentationModel, self).__init__() + self.backbone = backbone + self.classifier = classifier + self.aux_classifier = aux_classifier + + def forward(self, x): + input_shape = x.shape[-2:] + # contract: features is a dict of tensors + features = self.backbone(x) + + result = OrderedDict() + x = features["out"] + x = self.classifier(x) + x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + result["out"] = x + + if self.aux_classifier is not None: + x = features["aux"] + x = self.aux_classifier(x) + x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + result["aux"] = x + + return result + +class FCN(_SimpleSegmentationModel): + """ + Implements a Fully-Convolutional Network for semantic segmentation. + Arguments: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. 
+ aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + pass + +class FCNHead(nn.Sequential): + def __init__(self, in_channels, channels): + inter_channels = in_channels // 4 + layers = [ + nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(inter_channels), + nn.ReLU(), + nn.Dropout(0.1), + nn.Conv2d(inter_channels, channels, 1) + ] + + super(FCNHead, self).__init__(*layers) + +def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True): + # backbone = resnet.__dict__[backbone_name]( + # pretrained=pretrained_backbone, + # replace_stride_with_dilation=[False, True, True]) + # Hardcoded resnet 50 + assert backbone_name == "resnet50" + backbone = resnet50( + pretrained=pretrained_backbone, + replace_stride_with_dilation=[False, True, True]) + + return_layers = {'layer4': 'out'} + if aux: + return_layers['layer3'] = 'aux' + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = None + if aux: + inplanes = 1024 + aux_classifier = FCNHead(inplanes, num_classes) + + model_map = { + # 'deeplabv3': (DeepLabHead, DeepLabV3), # Not used + 'fcn': (FCNHead, FCN), + } + inplanes = 2048 + classifier = model_map[name][0](inplanes, num_classes) + base_model = model_map[name][1] + + model = base_model(backbone, classifier, aux_classifier) + return model + +def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs): + if pretrained: + aux_loss = True + model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs) + # if pretrained: + # arch = arch_type + '_' + backbone + '_coco' + # model_url = model_urls[arch] + # if model_url is None: + # raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) + # else: + # state_dict = load_state_dict_from_url(model_url, progress=progress) + # model.load_state_dict(state_dict) + return model + +def fcn_resnet50(pretrained=False, progress=True, + num_classes=21, aux_loss=None, **kwargs): + """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) + + +# Taken from @fmassa example slides and https://github.com/facebookresearch/detr +class DETR(nn.Module): + """ + Demo DETR implementation. + + Demo implementation of DETR in minimal number of lines, with the + following differences wrt DETR in the paper: + * learned positional encoding (instead of sine) + * positional encoding is passed at input (instead of attention) + * fc bbox predictor (instead of MLP) + The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100. + Only batch size 1 supported. 
+ """ + def __init__(self, num_classes, hidden_dim=256, nheads=8, + num_encoder_layers=6, num_decoder_layers=6): + super().__init__() + + # create ResNet-50 backbone + self.backbone = resnet50() + del self.backbone.fc + + # create conversion layer + self.conv = nn.Conv2d(2048, hidden_dim, 1) + + # create a default PyTorch transformer + self.transformer = nn.Transformer( + hidden_dim, nheads, num_encoder_layers, num_decoder_layers) + + # prediction heads, one extra class for predicting non-empty slots + # note that in baseline DETR linear_bbox layer is 3-layer MLP + self.linear_class = nn.Linear(hidden_dim, num_classes + 1) + self.linear_bbox = nn.Linear(hidden_dim, 4) + + # output positional encodings (object queries) + self.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) + + # spatial positional encodings + # note that in baseline DETR we use sine positional encodings + self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) + self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) + + def forward(self, inputs): + # propagate inputs through ResNet-50 up to avg-pool layer + x = self.backbone.conv1(inputs) + x = self.backbone.bn1(x) + x = self.backbone.relu(x) + x = self.backbone.maxpool(x) + + x = self.backbone.layer1(x) + x = self.backbone.layer2(x) + x = self.backbone.layer3(x) + x = self.backbone.layer4(x) + + # convert from 2048 to 256 feature planes for the transformer + h = self.conv(x) + + # construct positional encodings + H, W = h.shape[-2:] + pos = torch.cat([ + self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), + self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), + ], dim=-1).flatten(0, 1).unsqueeze(1) + + # propagate through the transformer + # TODO (alband) Why this is not automatically broadcasted? (had to add the repeat) + f = pos + 0.1 * h.flatten(2).permute(2, 0, 1) + s = self.query_pos.unsqueeze(1) + s = s.expand(s.size(0), inputs.size(0), s.size(2)) + h = self.transformer(f, s).transpose(0, 1) + + # finally project transformer outputs to class labels and bounding boxes + return {'pred_logits': self.linear_class(h), + 'pred_boxes': self.linear_bbox(h).sigmoid()} + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + The boxes should be in [x0, y0, x1, y1] format + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + +def box_area(boxes): + """ + Computes the area of a set of bounding boxes, which are specified by its + (x1, y1, x2, y2) coordinates. + Arguments: + boxes (Tensor[N, 4]): boxes for which the area will be computed. 
They + are expected to be in (x1, y1, x2, y2) format + Returns: + area (Tensor[N]): area for each box + """ + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + +def is_dist_avail_and_initialized(): + return False + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        self.losses = losses
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = self.eos_coef
+        self.register_buffer('empty_weight', empty_weight)
+
+    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        losses = {'loss_ce': loss_ce}
+
+        if log:
+            # TODO this should probably be a separate loss, not hacked in this one here
+            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+        return losses
+
+    @torch.no_grad()
+    def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes
+        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+        """
+        pred_logits = outputs['pred_logits']
+        device = pred_logits.device
+        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
+        # Count the number of predictions that are NOT "no-object" (which is the last class)
+        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+        losses = {'cardinality_error': card_err}
+        return losses
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
+        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+        losses = {}
+        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(generalized_box_iou(
+            box_cxcywh_to_xyxy(src_boxes),
+            box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+        return losses
+
+    def loss_masks(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the masks: the focal loss and the dice loss.
+        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+        """
+        assert "pred_masks" in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+
+        src_masks = outputs["pred_masks"]
+
+        # TODO use valid to mask invalid areas due to padding in loss
+        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
+        target_masks = target_masks.to(src_masks)
+
+        src_masks = src_masks[src_idx]
+        # upsample predictions to the target size
+        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
+                                mode="bilinear", align_corners=False)
+        src_masks = src_masks[:, 0].flatten(1)
+
+        target_masks = target_masks[tgt_idx].flatten(1)
+
+        losses = {
+            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+        }
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'labels': self.loss_labels,
+            'cardinality': self.loss_cardinality,
+            'boxes': self.loss_boxes,
+            'masks': self.loss_masks
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    def forward(self, outputs, targets):
+        """ This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied, see each loss's doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    if loss == 'masks':
+                        # Intermediate masks losses are too costly to compute, we ignore them.
+                        continue
+                    kwargs = {}
+                    if loss == 'labels':
+                        # Logging is enabled only for the last layer
+                        kwargs = {'log': False}
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are unmatched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
+        """Creates the matcher
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox = cost_bbox
+        self.cost_giou = cost_giou
+        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """ Performs the matching
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+        # but approximate it by 1 - proba[target class].
+        # The 1 is a constant that doesn't change the matching, it can be omitted.
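+        # Indexing out_prob with tgt_ids gathers, for every flattened (batch, query) prediction, the predicted
+        # probability of each target's class, giving a [batch_size * num_queries, total_num_target_boxes] cost matrix.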
+        cost_class = -out_prob[:, tgt_ids]
+
+        # Compute the L1 cost between boxes
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+        # Compute the giou cost between boxes
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+
+        # Final cost matrix
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        C = C.view(bs, num_queries, -1).cpu()
+
+        sizes = [len(v["boxes"]) for v in targets]
+        if not scipy_available:
+            raise RuntimeError("The 'detr' model requires scipy to run. Please make sure you have it installed"
+                               " if you enable the 'detr' model.")
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
diff --git a/benchmarks/functional_autograd_benchmark/utils.py b/benchmarks/functional_autograd_benchmark/utils.py
new file mode 100644
index 00000000000..c7aeb29d157
--- /dev/null
+++ b/benchmarks/functional_autograd_benchmark/utils.py
@@ -0,0 +1,103 @@
+import torch
+
+from collections import defaultdict
+
+from torch import nn, Tensor
+from typing import List, Tuple, Dict, Union, Callable
+
+# Type helpers
+InputsType = Union[Tensor, Tuple[Tensor, ...]]
+# A Getter takes in a device and returns a callable and the inputs to that callable
+GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
+GetterType = Callable[[torch.device], GetterReturnType]
+# V here refers to the v in either vjp, jvp, vhp or hvp
+VType = Union[None, Tensor, Tuple[Tensor, ...]]
+# Type used to store timing results. The first key is the model name, the second key
+# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
+TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
+
+# Utilities to make nn.Module "functional"
+# In particular the goal is to be able to provide a function that takes the parameters
+# as input and evaluates the nn.Module using fixed inputs.
+def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
+    """
+    Deletes the attribute specified by the given list of names.
+    For example, to delete the attribute obj.conv.weight,
+    use _del_nested_attr(obj, ['conv', 'weight'])
+    """
+    if len(names) == 1:
+        delattr(obj, names[0])
+    else:
+        _del_nested_attr(getattr(obj, names[0]), names[1:])
+
+def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
+    """
+    Set the attribute specified by the given list of names to value.
+    For example, to set the attribute obj.conv.weight,
+    use _set_nested_attr(obj, ['conv', 'weight'], value)
+    """
+    if len(names) == 1:
+        setattr(obj, names[0], value)
+    else:
+        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
+
+def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+    """
+    This function removes all the Parameters from the model and
+    returns them as a tuple as well as their original attribute names.
+    The weights must be re-loaded with `load_weights` before the model
+    can be used again.
+    Note that this function modifies the model in place and after this
+    call, mod.parameters() will be empty.
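+    The returned tensors are detached from the original Parameters and marked as requiring grad,
+    so they can be passed explicitly to the functions being benchmarked.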
+ """ + orig_params = tuple(mod.parameters()) + # Remove all the parameters in the model + names = [] + for name, p in list(mod.named_parameters()): + _del_nested_attr(mod, name.split(".")) + names.append(name) + + # Make params regular Tensors instead of nn.Parameter + params = tuple(p.detach().requires_grad_() for p in orig_params) + return params, names + +def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None: + """ + Reload a set of weights so that `mod` can be used again to perform a forward pass. + Note that the `params` are regular Tensors (that can have history) and so are left + as Tensors. This means that mod.parameters() will still be empty after this call. + """ + for name, p in zip(names, params): + _set_nested_attr(mod, name.split("."), p) + +# Utilities to read/write markdown table-like content. +def to_markdown_table(res: TimingResultType, header: Tuple[str, ...] = None) -> str: + if header is None: + header = ("model", "task", "mean", "var") + out = "" + + def write_line(*args): + nonlocal out + out += "| {} |\n".format(" | ".join(str(a) for a in args)) + + # Make it a markdown table + write_line(*header) + write_line(*["--"] * len(header)) + for model, tasks in res.items(): + for task, line in tasks.items(): + write_line(*(model, task) + line) + + return out + +def from_markdown_table(data: str) -> TimingResultType: + out = data.strip().split("\n") + out = out[2:] # Ignore the header lines + + res: TimingResultType + res = defaultdict(defaultdict) + + for line in out: + model, task, mean, var = [f.strip() for f in line.strip().split("|") if f] + res[model][task] = (float(mean), float(var)) + + return res diff --git a/benchmarks/functional_autograd_benchmark/vision_models.py b/benchmarks/functional_autograd_benchmark/vision_models.py new file mode 100644 index 00000000000..cd2f84e638a --- /dev/null +++ b/benchmarks/functional_autograd_benchmark/vision_models.py @@ -0,0 +1,97 @@ +import torch +from torch import Tensor +import torchvision_models as models + +from utils import extract_weights, load_weights, GetterReturnType + +from typing import cast + +def get_resnet18(device: torch.device) -> GetterReturnType: + N = 32 + model = models.resnet18(pretrained=False) + criterion = torch.nn.CrossEntropyLoss() + model.to(device) + params, names = extract_weights(model) + + inputs = torch.rand([N, 3, 224, 224], device=device) + labels = torch.rand(N, device=device).mul(10).long() + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + out = model(inputs) + + loss = criterion(out, labels) + return loss + + return forward, params + +def get_fcn_resnet(device: torch.device) -> GetterReturnType: + N = 8 + criterion = torch.nn.MSELoss() + model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False) + model.to(device) + params, names = extract_weights(model) + + inputs = torch.rand([N, 3, 480, 480], device=device) + # Given model has 21 classes + labels = torch.rand([N, 21, 480, 480], device=device) + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + out = model(inputs)['out'] + + loss = criterion(out, labels) + return loss + + return forward, params + +def get_detr(device: torch.device) -> GetterReturnType: + # All values below are from CLI defaults in https://github.com/facebookresearch/detr + N = 2 + num_classes = 91 + hidden_dim = 256 + nheads = 8 + num_encoder_layers = 6 + num_decoder_layers = 6 + + model = models.DETR(num_classes=num_classes, hidden_dim=hidden_dim, 
nheads=nheads, + num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers) + losses = ['labels', 'boxes', 'cardinality'] + eos_coef = 0.1 + bbox_loss_coef = 5 + giou_loss_coef = 2 + weight_dict = {'loss_ce': 1, 'loss_bbox': bbox_loss_coef, 'loss_giou': giou_loss_coef} + matcher = models.HungarianMatcher(1, 5, 2) + criterion = models.SetCriterion(num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, + eos_coef=eos_coef, losses=losses) + + model = model.to(device) + criterion = criterion.to(device) + params, names = extract_weights(model) + + inputs = torch.rand(N, 3, 800, 1200, device=device) + labels = [] + for idx in range(N): + targets = {} + n_targets: int = int(torch.randint(5, 10, size=tuple()).item()) + label = torch.randint(5, 10, size=(n_targets,)) + targets["labels"] = label + boxes = torch.randint(100, 800, size=(n_targets, 4)) + for t in range(n_targets): + if boxes[t, 0] > boxes[t, 2]: + boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0] + if boxes[t, 1] > boxes[t, 3]: + boxes[t, 1], boxes[t, 3] = boxes[t, 3], boxes[t, 1] + targets["boxes"] = boxes.float() + labels.append(targets) + + def forward(*new_params: Tensor) -> Tensor: + load_weights(model, names, new_params) + out = model(inputs) + + loss = criterion(out, labels) + weight_dict = criterion.weight_dict + final_loss = cast(Tensor, sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict)) + return final_loss + + return forward, params diff --git a/test/run_test.py b/test/run_test.py index da5c9f56b29..f2bac98d000 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -87,6 +87,7 @@ TESTS = [ 'test_determination', 'test_futures', 'test_fx', + 'test_functional_autograd_benchmark' ] WINDOWS_BLOCKLIST = [ diff --git a/test/test_functional_autograd_benchmark.py b/test/test_functional_autograd_benchmark.py new file mode 100644 index 00000000000..8c8e06754b6 --- /dev/null +++ b/test/test_functional_autograd_benchmark.py @@ -0,0 +1,57 @@ +from torch.testing._internal.common_utils import TestCase, run_tests, slowTest, IS_WINDOWS + +import subprocess +import tempfile +import os +import unittest + +# This is a very simple smoke test for the functional autograd benchmarking script. +class TestFunctionalAutogradBenchmark(TestCase): + def _test_runner(self, model, disable_gpu=False): + # Note about windows: + # The temporary file is exclusively open by this process and the child process + # is not allowed to open it again. As this is a simple smoke test, we choose for now + # not to run this on windows and keep the code here simple. 
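+        # The temporary file serves both as the --output target for the benchmark script and as a way to
+        # check afterwards that some results were actually written.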
+        with tempfile.NamedTemporaryFile() as out_file:
+            cmd = ['python', '../benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py']
+            # Only run the warmup
+            cmd += ['--num-iters', '0']
+            # Only run the vjp task (fastest one)
+            cmd += ['--task-filter', 'vjp']
+            # Only run the specified model
+            cmd += ['--model-filter', model]
+            # Output file
+            cmd += ['--output', out_file.name]
+            if disable_gpu:
+                cmd += ['--gpu', '-1']
+
+            res = subprocess.run(cmd)
+
+            self.assertTrue(res.returncode == 0)
+            # Check that something was written to the file
+            out_file.seek(0, os.SEEK_END)
+            self.assertTrue(out_file.tell() > 0)
+
+
+    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on Windows does not have all the features we need.")
+    def test_fast_tasks(self):
+        fast_tasks = ['resnet18', 'ppl_simple_reg', 'ppl_robust_reg', 'wav2letter',
+                      'transformer', 'multiheadattn']
+
+        for task in fast_tasks:
+            self._test_runner(task)
+
+    @slowTest
+    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on Windows does not have all the features we need.")
+    def test_slow_tasks(self):
+        slow_tasks = ['fcn_resnet', 'detr']
+        # deepspeech is deliberately excluded as it takes too long to run without
+        # proper tuning of the number of threads it should use.
+
+        for task in slow_tasks:
+            # Disable the GPU for the slow tests as the CI GPUs don't have enough memory
+            self._test_runner(task, disable_gpu=True)
+
+
+if __name__ == '__main__':
+    run_tests()