[quant] Refactor nonuniform quantization mapping functions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79790

Approved by: https://github.com/dzdang
Author: asl3
Date: 2022-06-17 16:10:58 -07:00
Committed by: PyTorch MergeBot
parent 660d9ddef4
commit 228e082ca9
3 changed files with 54 additions and 26 deletions

mypy.ini

@@ -58,6 +58,9 @@ ignore_missing_imports = True
[mypy-torch.for_onnx.onnx]
ignore_missing_imports = True
[mypy-torch.ao.quantization.experimental.apot_utils]
ignore_missing_imports = True

#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.
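
For context, ignore_missing_imports is a per-module mypy option that silences "Cannot find implementation or library stub" errors for imports matching the section's module pattern. A minimal sketch of the effect on a hypothetical caller (the file below is illustrative, not part of this commit):

# hypothetical_caller.py -- type-checked by mypy
# With the new [mypy-torch.ao.quantization.experimental.apot_utils] section,
# mypy no longer errors if it cannot resolve this import to typed sources.
from torch.ao.quantization.experimental.apot_utils import float_to_apot

code_value = float_to_apot(0.72, [0.0, 0.25, 0.5, 1.0], [0, 1, 2, 3])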

torch/ao/quantization/experimental/apot_utils.py (new file)

@@ -0,0 +1,50 @@
r"""
This file contains utility functions to convert values
using APoT nonuniform quantization methods.
"""
import math
r"""Converts floating point input into int4 APoT2 number
based on quantization levels
"""
def float_to_apot(x, levels, indices):
levels_lst = list(levels)
indices_lst = list(indices)
min_delta = math.inf
best_idx = 0
for level, idx in zip(levels_lst, indices_lst):
cur_delta = abs(level - x)
if cur_delta < min_delta:
min_delta = cur_delta
best_idx = idx
return best_idx
r"""Converts floating point input into
reduced precision floating point value
based on quantization levels
"""
def float_to_reduced_precision(x, levels, indices):
levels_lst = list(levels)
indices_lst = list(indices)
min_delta = math.inf
best_fp = 0.0
for level, idx in zip(levels_lst, indices_lst):
cur_delta = abs(level - x)
if cur_delta < min_delta:
min_delta = cur_delta
best_fp = level
return best_fp
r"""Converts int4 APoT2 input into floating point number
based on quantization levels
"""
def apot_to_float(x_apot, levels, indices):
idx = list(indices).index(x_apot)
return levels[idx]
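
A small usage sketch of the three helpers. The levels and indices below are hypothetical stand-ins for the quantization grid an observer would produce, not the actual int4 APoT2 levels:

from torch.ao.quantization.experimental.apot_utils import (
    float_to_apot,
    float_to_reduced_precision,
    apot_to_float,
)

# hypothetical quantization grid: four levels and their assigned indices
levels = [0.0, 0.25, 0.75, 1.0]
indices = [0, 1, 2, 3]

x = 0.8
idx = float_to_apot(x, levels, indices)              # nearest level is 0.75 -> index 2
fp = float_to_reduced_precision(x, levels, indices)  # snaps x to 0.75
x_hat = apot_to_float(idx, levels, indices)          # index 2 -> level 0.75

assert fp == x_hat == 0.75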

torch/ao/quantization/experimental/observer.py

@@ -5,9 +5,9 @@ the values observed during calibration (PTQ) or training (QAT).
import torch
import itertools
import math
import matplotlib.pyplot as plt
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented
@@ -128,28 +128,3 @@ class APoTObserver(ObserverBase):
plt.xlabel("Full Precision")
plt.ylabel("Quantized")
plt.show()
r"""Converts floating point input into int4 APoT2 number
based on quantization levels
"""
def float_to_apot(x, levels, indices):
levels_lst = list(levels)
indices_lst = list(indices)
min_delta = math.inf
best_idx = 0
for level, idx in zip(levels_lst, indices_lst):
cur_delta = abs(level - x)
if cur_delta < min_delta:
min_delta = cur_delta
best_idx = idx
return best_idx
r"""Converts int4 APoT2 input into floating point number
based on quantization levels
"""
def apot_to_float(x_apot, levels, indices):
idx = list(indices).index(x_apot)
return levels[idx]
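
After this refactor, observer.py reaches the mapping functions through the import above rather than through its own local copies (removed in the hunk here). The helpers preserve a simple round-trip invariant: dequantizing a quantized value lands on the nearest level, which is exactly what float_to_reduced_precision returns. A sketch under the same hypothetical grid as above:

from torch.ao.quantization.experimental.apot_utils import (
    float_to_apot,
    float_to_reduced_precision,
    apot_to_float,
)

levels = [0.0, 0.25, 0.75, 1.0]  # hypothetical grid, as above
indices = [0, 1, 2, 3]

# quantize -> dequantize agrees with direct reduced-precision rounding
for x in [-0.3, 0.1, 0.5, 0.9, 1.7]:
    rt = apot_to_float(float_to_apot(x, levels, indices), levels, indices)
    assert rt == float_to_reduced_precision(x, levels, indices)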