[quant] Refactor nonuniform quantization mapping functions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79790
Approved by: https://github.com/dzdang
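For background: APoT here stands for "additive powers-of-two" quantization, a nonuniform scheme in which each quantization level is a sum of power-of-two terms. A rough sketch of how such a level grid can be built (illustration only, not the construction APoTObserver actually uses; the term set below is an assumption):

import itertools

# Each term is zero or a negative power of two (assumed set, for illustration).
terms = [0.0, 2 ** -1, 2 ** -2, 2 ** -3]

# An "APoT2"-style level is a sum of two such terms; dedupe, sort,
# and normalize the grid to [0, 1].
raw = sorted({a + b for a, b in itertools.product(terms, terms)})
levels = [v / max(raw) for v in raw]
indices = list(range(len(levels)))

The mapping functions this commit moves into apot_utils.py then only need to translate floats to the nearest such level (and back) by index.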
mypy.ini (+3)
@@ -58,6 +58,9 @@ ignore_missing_imports = True
 [mypy-torch.for_onnx.onnx]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.apot_utils]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.
torch/ao/quantization/experimental/apot_utils.py (new file, +50)
@@ -0,0 +1,50 @@
+r"""
+This file contains utility functions to convert values
+using APoT nonuniform quantization methods.
+"""
+
+import math
+
+r"""Converts floating point input into int4 APoT2 number
+based on quantization levels
+"""
+def float_to_apot(x, levels, indices):
+    levels_lst = list(levels)
+    indices_lst = list(indices)
+
+    min_delta = math.inf
+    best_idx = 0
+
+    for level, idx in zip(levels_lst, indices_lst):
+        cur_delta = abs(level - x)
+        if cur_delta < min_delta:
+            min_delta = cur_delta
+            best_idx = idx
+
+    return best_idx
+
+r"""Converts floating point input into
+reduced precision floating point value
+based on quantization levels
+"""
+def float_to_reduced_precision(x, levels, indices):
+    levels_lst = list(levels)
+    indices_lst = list(indices)
+
+    min_delta = math.inf
+    best_fp = 0.0
+
+    for level, idx in zip(levels_lst, indices_lst):
+        cur_delta = abs(level - x)
+        if cur_delta < min_delta:
+            min_delta = cur_delta
+            best_fp = level
+
+    return best_fp
+
+r"""Converts int4 APoT2 input into floating point number
+based on quantization levels
+"""
+def apot_to_float(x_apot, levels, indices):
+    idx = list(indices).index(x_apot)
+    return levels[idx]
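A minimal usage sketch for the three helpers above (the levels and indices values are placeholders; in practice they come from the APoT observer's calibration, not hardcoded):

from torch.ao.quantization.experimental.apot_utils import (
    apot_to_float,
    float_to_apot,
    float_to_reduced_precision,
)

levels = [0.0, 0.25, 0.5, 0.75, 1.0]   # placeholder quantization levels
indices = [0, 1, 2, 3, 4]              # matching integer codes

code = float_to_apot(0.6, levels, indices)     # nearest level is 0.5 -> returns 2
value = apot_to_float(code, levels, indices)   # back to the level value -> 0.5
approx = float_to_reduced_precision(0.6, levels, indices)  # nearest level -> 0.5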
torch/ao/quantization/experimental/observer.py
@@ -5,9 +5,9 @@ the values observed during calibration (PTQ) or training (QAT).
 import torch
 import itertools
 import math
 import matplotlib.pyplot as plt
 from torch.ao.quantization.observer import ObserverBase
+from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
 
 # TODO: Consider adding NonUniformQuantizationObserverBase class
 # when more than one non-uniform method is implemented
@@ -128,28 +128,3 @@ class APoTObserver(ObserverBase):
         plt.xlabel("Full Precision")
         plt.ylabel("Quantized")
         plt.show()
-
-r"""Converts floating point input into int4 APoT2 number
-based on quantization levels
-"""
-def float_to_apot(x, levels, indices):
-    levels_lst = list(levels)
-    indices_lst = list(indices)
-
-    min_delta = math.inf
-    best_idx = 0
-
-    for level, idx in zip(levels_lst, indices_lst):
-        cur_delta = abs(level - x)
-        if cur_delta < min_delta:
-            min_delta = cur_delta
-            best_idx = idx
-
-    return best_idx
-
-r"""Converts int4 APoT2 input into floating point number
-based on quantization levels
-"""
-def apot_to_float(x_apot, levels, indices):
-    idx = list(indices).index(x_apot)
-    return levels[idx]
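Worth noting about the moved helpers: quantizing a float to its code and dequantizing again lands on exactly the level that float_to_reduced_precision returns, since all three functions share the same nearest-level search. A quick check with placeholder values:

from torch.ao.quantization.experimental.apot_utils import (
    apot_to_float,
    float_to_apot,
    float_to_reduced_precision,
)

levels = [0.0, 0.5, 1.0]   # placeholder grid
indices = [0, 1, 2]

for x in (0.1, 0.4, 0.8):
    code = float_to_apot(x, levels, indices)
    assert apot_to_float(code, levels, indices) == float_to_reduced_precision(x, levels, indices)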