From 228e082ca9cef631f9fd73e43d94b90c59b451c1 Mon Sep 17 00:00:00 2001
From: asl3
Date: Fri, 17 Jun 2022 16:10:58 -0700
Subject: [PATCH] [quant] Refactor nonuniform quantization mapping functions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79790
Approved by: https://github.com/dzdang
---
 mypy.ini                                      |  3 ++
 .../quantization/experimental/apot_utils.py   | 50 +++++++++++++++++++
 .../ao/quantization/experimental/observer.py  | 27 +---------
 3 files changed, 54 insertions(+), 26 deletions(-)
 create mode 100644 torch/ao/quantization/experimental/apot_utils.py

diff --git a/mypy.ini b/mypy.ini
index 61442c1a7d6..f57a08022f1 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -58,6 +58,9 @@ ignore_missing_imports = True
 [mypy-torch.for_onnx.onnx]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.apot_utils]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.

diff --git a/torch/ao/quantization/experimental/apot_utils.py b/torch/ao/quantization/experimental/apot_utils.py
new file mode 100644
index 00000000000..bac48c85d5d
--- /dev/null
+++ b/torch/ao/quantization/experimental/apot_utils.py
@@ -0,0 +1,50 @@
+r"""
+This file contains utility functions to convert values
+using APoT nonuniform quantization methods.
+"""
+
+import math
+
+r"""Converts floating point input into int4 APoT2 number
+    based on quantization levels
+"""
+def float_to_apot(x, levels, indices):
+    levels_lst = list(levels)
+    indices_lst = list(indices)
+
+    min_delta = math.inf
+    best_idx = 0
+
+    for level, idx in zip(levels_lst, indices_lst):
+        cur_delta = abs(level - x)
+        if cur_delta < min_delta:
+            min_delta = cur_delta
+            best_idx = idx
+
+    return best_idx
+
+r"""Converts floating point input into
+    reduced precision floating point value
+    based on quantization levels
+"""
+def float_to_reduced_precision(x, levels, indices):
+    levels_lst = list(levels)
+    indices_lst = list(indices)
+
+    min_delta = math.inf
+    best_fp = 0.0
+
+    for level, idx in zip(levels_lst, indices_lst):
+        cur_delta = abs(level - x)
+        if cur_delta < min_delta:
+            min_delta = cur_delta
+            best_fp = level
+
+    return best_fp
+
+r"""Converts int4 APoT2 input into floating point number
+based on quantization levels
+"""
+def apot_to_float(x_apot, levels, indices):
+    idx = list(indices).index(x_apot)
+    return levels[idx]

diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py
index 8df7efc11d9..85313f646ce 100644
--- a/torch/ao/quantization/experimental/observer.py
+++ b/torch/ao/quantization/experimental/observer.py
@@ -5,9 +5,9 @@ the values observed during calibration (PTQ) or training (QAT).
 
 import torch
 import itertools
-import math
 import matplotlib.pyplot as plt
 from torch.ao.quantization.observer import ObserverBase
+from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
 
 # TODO: Consider adding NonUniformQuantizationObserverBase class
 # when more than one non-uniform method is implemented
@@ -128,28 +128,3 @@ class APoTObserver(ObserverBase):
         plt.xlabel("Full Precision")
         plt.ylabel("Quantized")
         plt.show()
-
-r"""Converts floating point input into int4 APoT2 number
-    based on quantization levels
-"""
-def float_to_apot(x, levels, indices):
-    levels_lst = list(levels)
-    indices_lst = list(indices)
-
-    min_delta = math.inf
-    best_idx = 0
-
-    for level, idx in zip(levels_lst, indices_lst):
-        cur_delta = abs(level - x)
-        if cur_delta < min_delta:
-            min_delta = cur_delta
-            best_idx = idx
-
-    return best_idx
-
-r"""Converts int4 APoT2 input into floating point number
-based on quantization levels
-"""
-def apot_to_float(x_apot, levels, indices):
-    idx = list(indices).index(x_apot)
-    return levels[idx]
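
Note (not part of the patch): the sketch below shows how the relocated helpers could be called once the patch is applied. The levels/indices values here are made up for illustration; in practice they would come from the APoT quantization levels computed by the observer.

# Minimal usage sketch; the levels/indices below are hypothetical examples,
# not values produced by APoTObserver.
from torch.ao.quantization.experimental.apot_utils import (
    float_to_apot,
    float_to_reduced_precision,
    apot_to_float,
)

levels = [0.0, 0.25, 0.5, 0.75, 1.0]   # hypothetical quantization levels
indices = [0, 1, 2, 3, 4]              # integer codes paired with each level

x = 0.6
code = float_to_apot(x, levels, indices)                   # 2: index of the nearest level (0.5)
reduced = float_to_reduced_precision(x, levels, indices)   # 0.5: the nearest level itself
restored = apot_to_float(code, levels, indices)            # 0.5: maps the code back to its level

float_to_reduced_precision returns the matched level directly, which spares callers the round trip through float_to_apot and apot_to_float when only the dequantized value is needed.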