support fp16 shgemm under openblas (#169042)

# Purpose
This PR adds fp16 GEMM support via OpenBLAS's shgemm kernel. We benchmarked the change with vLLM on the platform below; with this patch, vLLM runs fp16 inference roughly 1.9x faster (see Result).
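
For reference, a minimal way to exercise this path from Python is a plain fp16 matmul on CPU (a sketch; it assumes PyTorch was built against an OpenBLAS that exports shgemm_, so that BLAS_HAS_SHGEMM is defined):

import torch

# fp16 CPU matmul; when BLAS_HAS_SHGEMM is defined at build time this can
# dispatch to the new OpenBLAS shgemm path instead of the generic gemm stub.
a = torch.randn(1024, 1024, dtype=torch.float16)
b = torch.randn(1024, 1024, dtype=torch.float16)
c = a @ b
print(c.dtype, c.shape)  # torch.float16 torch.Size([1024, 1024])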

**Platform info:**
Architecture:             riscv64
  Byte Order:             Little Endian
CPU(s):                   64
  On-line CPU(s) list:    0-63
Vendor ID:                0x5b7
  BIOS Vendor ID:         SOPHGO
  Model name:             -
    BIOS Model name:      SG2044 Not Set CPU @ 2.6GHz
    BIOS CPU family:      513
    CPU family:           0x80000000090c0d00
    Model:                0x2047000
    Thread(s) per core:   1
    Core(s) per socket:   64
    Socket(s):            1
    Frequency boost:      disabled
    CPU(s) scaling MHz:   100%
    CPU max MHz:          2600.0000
    CPU min MHz:          1000.0000
Caches (sum of all):
  L1d:                    4 MiB (64 instances)
  L1i:                    4 MiB (64 instances)
  L2:                     32 MiB (16 instances)
  L3:                     64 MiB (1 instance)
Vulnerabilities:
  Gather data sampling:   Not affected
  Itlb multihit:          Not affected
  L1tf:                   Not affected
  Mds:                    Not affected
  Meltdown:               Not affected
  Mmio stale data:        Not affected
  Reg file data sampling: Not affected
  Retbleed:               Not affected

ISA: rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintntl_zihintpause_zihpm_zawrs_zfa_zfh_zfhmin_zca_zcb_zcd_zba_zbb_zbc_zbs_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_sscofpmf_sstc_svinval_svnapot_svpbmt

**Branch**
openblas: develop
torch: develop
vllm: main

# Test Plan
Base: without this PR.
PyTorch using OpenBLAS FP16 GEMM: with this PR applied.
Both runs use the same environment and benchmark command; only the output JSON file name differs.

**Base**
export VLLM_CPU_OMP_THREADS_BIND=0-63
export VLLM_CPU_KVCACHE_SPACE=60

vllm bench latency \
  --model /home/models/Qwen2.5-7B-Instruct \
  --tensor-parallel-size 1 \
  --dtype float16 \
  --input-len 16 \
  --output-len 16 \
  --enforce-eager \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --batch-size 1 \
  --n 1 \
  --num-iters-warmup 5 \
  --num-iters 8 \
  --seed 42 \
  --output-json ./latency_results_fp16_latency_base.json

**PyTorch using OpenBLAS FP16 GEMM**
export VLLM_CPU_OMP_THREADS_BIND=0-63
export VLLM_CPU_KVCACHE_SPACE=60

vllm bench latency \
  --model /home/models/Qwen2.5-7B-Instruct \
  --tensor-parallel-size 1 \
  --dtype float16 \
  --input-len 16 \
  --output-len 16 \
  --enforce-eager \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --batch-size 1 \
  --n 1 \
  --num-iters-warmup 5 \
  --num-iters 8 \
  --seed 42 \
  --output-json ./latency_results_fp16_latency_with_openblas_support.json

# Result
Average latency drops from 62.54 s (Base) to 32.43 s (OpenBLAS FP16 GEMM), roughly a 1.93x speedup.

**Base**
{
    "avg_latency": 62.53946338250171,
    "latencies": [
        58.46783778001554,
        58.230652199999895,
        58.335780619992875,
        59.77051957999356,
        58.587668860011036,
        59.31567866000114,
        58.460076240007766,
        89.14749311999185
    ],
    "percentiles": {
        "10": 58.30424209399498,
        "25": 58.42900233500404,
        "50": 58.52775332001329,
        "75": 59.429388889999245,
        "90": 68.58361164199304,
        "99": 87.09110497219196
    }
}

**PyTorch using OpenBLAS FP16 GEMM**
{
    "avg_latency": 32.42863222499727,
    "latencies": [
        30.742418120033108,
        33.67000828002347,
        29.747197599965148,
        32.11275753995869,
        34.566938299976755,
        30.849812360014766,
        34.46360486000776,
        33.27632073999848
    ],
    "percentiles": {
        "10": 30.44385196401272,
        "25": 30.82296380001935,
        "50": 32.69453913997859,
        "75": 33.86840742501954,
        "90": 34.49460489199846,
        "99": 34.55970495917892
    }
}
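
To compare the two runs, a small helper sketch (the file names are the ones written by the benchmark commands above):

import json

# Load the two latency reports produced by "vllm bench latency --output-json".
with open("latency_results_fp16_latency_base.json") as f:
    base = json.load(f)
with open("latency_results_fp16_latency_with_openblas_support.json") as f:
    patched = json.load(f)

speedup = base["avg_latency"] / patched["avg_latency"]
print(f"avg latency: {base['avg_latency']:.2f}s -> {patched['avg_latency']:.2f}s "
      f"({speedup:.2f}x speedup)")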

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169042
Approved by: https://github.com/aditew01, https://github.com/albanD
Author: chenlang
Date: 2025-12-11 12:59:48 +00:00
Committed by: PyTorch MergeBot
Parent: 8121f2c5d0
Commit: df8b6bd5a3
3 changed files with 63 additions and 0 deletions


@@ -28,6 +28,14 @@ extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
                float *beta,
                float *c, int *ldc);
#endif // BLAS_HAS_SBGEMM
#ifdef BLAS_HAS_SHGEMM
extern "C" void shgemm_(char *transa, char *transb, int *m, int *n, int *k,
                float *alpha,
                const at::Half *a, int *lda,
                const at::Half *b, int *ldb,
                float *beta,
                float *c, int *ldc);
#endif // BLAS_HAS_SHGEMM
extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy);
extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy);
@@ -413,6 +421,34 @@ void gemm(
      mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SHGEMM)
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    float alpha_ = alpha, beta_ = beta;
    int c_size = n_ * m_;
    // The C matrix in OpenBLAS shgemm is of type "float", so we have to convert, copy and copy back.
    std::vector<float> float_v(c_size, 0.0f);
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        float_v[j * m_ + i] = c10::convert<float>(c[j * ldc_ + i]);
      }
    }
    shgemm_(&transa_, &transb_,
            &m_, &n_, &k_,
            &alpha_,
            a, &lda_,
            b, &ldb_,
            &beta_,
            float_v.data(), &m_);
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c[j * ldc_ + i] = c10::convert<at::Half>(float_v[j * m_ + i]);
      }
    }
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kHalf,
@@ -471,6 +507,21 @@ void gemm(
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SHGEMM)
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    float alpha_ = alpha, beta_ = beta;
    shgemm_(&transa_, &transb_,
            &m_, &n_, &k_,
            &alpha_,
            a, &lda_,
            b, &ldb_,
            &beta_,
            c, &ldc_);
    return;
  }
#endif
#ifdef MKL_HAS_SHGEMM
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;


@@ -346,3 +346,14 @@ IF(BLAS_LIBRARIES)
    add_compile_options(-DBLAS_HAS_SBGEMM)
  ENDIF(BLAS_HAS_SBGEMM)
ENDIF(BLAS_LIBRARIES)
# Blas has fp16 (half precision) support?
IF(BLAS_LIBRARIES)
  INCLUDE(CheckFunctionExists)
  SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
  check_function_exists("shgemm_" BLAS_HAS_SHGEMM)
  set(CMAKE_REQUIRED_LIBRARIES)
  IF(BLAS_HAS_SHGEMM)
    add_compile_options(-DBLAS_HAS_SHGEMM)
  ENDIF(BLAS_HAS_SHGEMM)
ENDIF(BLAS_LIBRARIES)


@@ -60,6 +60,7 @@ function(caffe2_print_configuration_summary)
  if(${USE_BLAS})
    message(STATUS "    BLAS                  : ${BLAS_INFO}")
    message(STATUS "    BLAS_HAS_SBGEMM       : ${BLAS_HAS_SBGEMM}")
    message(STATUS "    BLAS_HAS_SHGEMM       : ${BLAS_HAS_SHGEMM}")
  endif()
  message(STATUS "  USE_LAPACK            : ${USE_LAPACK}")
  if(${USE_LAPACK})
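
As a post-build sanity check, torch.__config__.show() prints the compile-time configuration, including the detected BLAS; whether BLAS_HAS_SHGEMM itself appears in that output is build-dependent, so treat this as a rough check rather than a guarantee:

import torch

# Dump the compile-time configuration; the BLAS entries indicate which BLAS
# library (e.g. OpenBLAS) this PyTorch build is linked against.
print(torch.__config__.show())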