support fp16 shgemm under openblas (#169042)
# Purpose
This PR adds support for FP16 `shgemm` under OpenBLAS. We tested with vLLM on the platform below; with this patch, vLLM shows faster inference in fp16.
**Platform info:**
Architecture: riscv64
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: 0x5b7
BIOS Vendor ID: SOPHGO
Model name: -
BIOS Model name: SG2044 Not Set CPU @ 2.6GHz
BIOS CPU family: 513
CPU family: 0x80000000090c0d00
Model: 0x2047000
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 1
Frequency boost: disabled
CPU(s) scaling MHz: 100%
CPU max MHz: 2600.0000
CPU min MHz: 1000.0000
Caches (sum of all):
L1d: 4 MiB (64 instances)
L1i: 4 MiB (64 instances)
L2: 32 MiB (16 instances)
L3: 64 MiB (1 instance)
Vulnerabilities:
Gather data sampling: Not affected
Itlb multihit: Not affected
L1tf: Not affected
Mds: Not affected
Meltdown: Not affected
Mmio stale data: Not affected
Reg file data sampling: Not affected
Retbleed: Not affected
ISA: rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintntl_zihintpause_zihpm_zawrs_zfa_zfh_zfhmin_zca_zcb_zcd_zba_zbb_zbc_zbs_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_sscofpmf_sstc_svinval_svnapot_svpbmt
**Branch**
openblas: develop
torch: develop
vllm: main
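For context (not part of the PR): a minimal libtorch sketch of the kind of CPU fp16 matmul this change targets. Whether it actually reaches the OpenBLAS `shgemm_` path depends on PyTorch having been built against a BLAS that exports the symbol (i.e. `BLAS_HAS_SHGEMM` was detected), so treat it as an illustration rather than a guarantee of dispatch.

```cpp
// Sketch only: a CPU half-precision matmul in libtorch. With this PR and an
// OpenBLAS that exports shgemm_, the Half gemm can take the new BLAS path;
// otherwise it falls back to the existing kernels.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::randn({1024, 4096}).to(torch::kHalf);  // fp16 operands on CPU
  auto b = torch::randn({4096, 1024}).to(torch::kHalf);
  auto c = torch::matmul(a, b);                          // fp16 GEMM
  std::cout << c.sizes() << std::endl;                   // [1024, 1024], dtype Half
  return 0;
}
```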
# Test Plan
Base: without this PR
PyTorch with OpenBLAS FP16 GEMM: with this PR
**Base**
export VLLM_CPU_OMP_THREADS_BIND=0-63
export VLLM_CPU_KVCACHE_SPACE=60
vllm bench latency \
--model /home/models/Qwen2.5-7B-Instruct \
--tensor-parallel-size 1 \
--dtype float16 \
--input-len 16 \
--output-len 16 \
--enforce-eager \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--batch-size 1 \
--n 1 \
--num-iters-warmup 5 \
--num-iters 8 \
--seed 42 \
--output-json ./latency_results_fp16_latency_base.json
**PyTorch with OpenBLAS FP16 GEMM**
export VLLM_CPU_OMP_THREADS_BIND=0-63
export VLLM_CPU_KVCACHE_SPACE=60
vllm bench latency \
--model /home/models/Qwen2.5-7B-Instruct \
--tensor-parallel-size 1 \
--dtype float16 \
--input-len 16 \
--output-len 16 \
--enforce-eager \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--batch-size 1 \
--n 1 \
--num-iters-warmup 5 \
--num-iters 8 \
--seed 42 \
--output-json ./latency_results_fp16_latency_with_openblas_support.json
# Result
Average latency drops from 62.54 s (Base) to 32.43 s (PyTorch with OpenBLAS FP16 GEMM), roughly a 1.9x speedup.
**Base**
{
"avg_latency": 62.53946338250171,
"latencies": [
58.46783778001554,
58.230652199999895,
58.335780619992875,
59.77051957999356,
58.587668860011036,
59.31567866000114,
58.460076240007766,
89.14749311999185
],
"percentiles": {
"10": 58.30424209399498,
"25": 58.42900233500404,
"50": 58.52775332001329,
"75": 59.429388889999245,
"90": 68.58361164199304,
"99": 87.09110497219196
}
}
**PyTorch with OpenBLAS FP16 GEMM**
{
"avg_latency": 32.42863222499727,
"latencies": [
30.742418120033108,
33.67000828002347,
29.747197599965148,
32.11275753995869,
34.566938299976755,
30.849812360014766,
34.46360486000776,
33.27632073999848
],
"percentiles": {
"10": 30.44385196401272,
"25": 30.82296380001935,
"50": 32.69453913997859,
"75": 33.86840742501954,
"90": 34.49460489199846,
"99": 34.55970495917892
}
}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/169042
Approved by: https://github.com/aditew01, https://github.com/albanD
Committed by PyTorch MergeBot. Parent: 8121f2c5d0. Commit: df8b6bd5a3
@@ -28,6 +28,14 @@ extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
                         float *beta,
                         float *c, int *ldc);
 #endif // BLAS_HAS_SBGEMM
+#ifdef BLAS_HAS_SHGEMM
+extern "C" void shgemm_(char *transa, char *transb, int *m, int *n, int *k,
+                        float *alpha,
+                        const at::Half *a, int *lda,
+                        const at::Half *b, int *ldb,
+                        float *beta,
+                        float *c, int *ldc);
+#endif // BLAS_HAS_SHGEMM
 extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy);
 extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy);
 extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy);
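The declaration above follows the OpenBLAS `shgemm` convention: half-precision A and B, but single-precision alpha, beta, and C, with Fortran-style pointer arguments and column-major storage. As a rough standalone illustration (not from the PR), a direct call could look like the sketch below; it assumes an OpenBLAS build that exports `shgemm_` and uses the compiler's `_Float16` as a stand-in for `at::Half`.

```cpp
// Sketch only: calling shgemm_ directly. Requires an OpenBLAS built with fp16
// GEMM support and a compiler/target that provides _Float16.
#include <cstdio>
#include <vector>

extern "C" void shgemm_(char *transa, char *transb, int *m, int *n, int *k,
                        float *alpha, const _Float16 *a, int *lda,
                        const _Float16 *b, int *ldb, float *beta,
                        float *c, int *ldc);

int main() {
  int m = 2, n = 2, k = 2;
  std::vector<_Float16> a = {1, 2, 3, 4};  // m x k, column-major
  std::vector<_Float16> b = {1, 0, 0, 1};  // k x n, column-major (identity)
  std::vector<float> c(m * n, 0.0f);       // note: C is float, not half
  float alpha = 1.0f, beta = 0.0f;
  char no_trans = 'N';
  shgemm_(&no_trans, &no_trans, &m, &n, &k, &alpha,
          a.data(), &m, b.data(), &k, &beta, c.data(), &m);
  std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);  // expect 1 2 3 4
  return 0;
}
```

Build with something like `g++ demo.cpp -lopenblas`; exact flags depend on the toolchain and where OpenBLAS is installed.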
@@ -413,6 +421,34 @@ void gemm(
       mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
     return;
   }
 #endif
+#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SHGEMM)
+  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
+    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
+    char transa_ = to_blas(transa), transb_ = to_blas(transb);
+    float alpha_ = alpha, beta_ = beta;
+    int c_size = n_ * m_;
+    // C matrix in OpenBLAS shgemm are of type "float" so we have to convert, copy and copy back.
+    std::vector<float> float_v(c_size, 0.0f);
+    for (const auto j : c10::irange(n)) {
+      for (const auto i : c10::irange(m)) {
+        float_v[j * m_ + i] = c10::convert<float>(c[j * ldc_ + i]);
+      }
+    }
+    shgemm_(&transa_, &transb_,
+            &m_, &n_, &k_,
+            &alpha_,
+            a, &lda_,
+            b, &ldb_,
+            &beta_,
+            float_v.data(), &m_);
+    for (const auto j : c10::irange(n)) {
+      for (const auto i : c10::irange(m)) {
+        c[j * ldc_ + i] = c10::convert<at::Half>(float_v[j * m_ + i]);
+      }
+    }
+    return;
+  }
+#endif
   gemm_stub(
       at::kCPU, at::kHalf,
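Because OpenBLAS's `shgemm` writes a `float` C matrix while this `gemm` overload's C is `at::Half`, the hunk above stages C through a temporary float buffer with leading dimension `m_`: convert in, run the GEMM, convert back. The sketch below (not from the PR) shows that staging pattern in isolation; a naive column-major loop stands in for `shgemm_` so it runs without OpenBLAS, and `_Float16` again stands in for `at::Half`.

```cpp
// Sketch only: convert / compute-in-float / convert-back staging for a
// half-precision destination matrix, as used around the shgemm_ call.
#include <cstdio>
#include <vector>

// Naive column-major GEMM (no transposes) standing in for the BLAS call.
static void ref_gemm(int m, int n, int k, float beta,
                     const _Float16 *a, int lda,
                     const _Float16 *b, int ldb,
                     float *c, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += static_cast<float>(a[p * lda + i]) * static_cast<float>(b[j * ldb + p]);
      }
      c[j * ldc + i] = beta * c[j * ldc + i] + acc;
    }
  }
}

int main() {
  int m = 2, n = 2, k = 2, ldc = 4;        // C stored with a padded leading dimension
  std::vector<_Float16> a = {1, 2, 3, 4};  // m x k, column-major
  std::vector<_Float16> b = {1, 0, 0, 1};  // k x n, column-major (identity)
  std::vector<_Float16> c(ldc * n);        // zero-initialized half C matrix

  // Stage C into a dense float buffer with leading dimension m, as the PR does.
  std::vector<float> c_f(m * n, 0.0f);
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      c_f[j * m + i] = static_cast<float>(c[j * ldc + i]);

  ref_gemm(m, n, k, /*beta=*/0.0f, a.data(), m, b.data(), k, c_f.data(), m);

  // Copy back, converting float -> half and restoring the original ldc layout.
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      c[j * ldc + i] = static_cast<_Float16>(c_f[j * m + i]);

  std::printf("%g %g %g %g\n",
              static_cast<float>(c[0]), static_cast<float>(c[1]),
              static_cast<float>(c[4]), static_cast<float>(c[5]));  // expect 1 2 3 4
  return 0;
}
```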
@@ -471,6 +507,21 @@ void gemm(
     const float beta,
     float *c, int64_t ldc) {
   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
+#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SHGEMM)
+  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
+    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
+    char transa_ = to_blas(transa), transb_ = to_blas(transb);
+    float alpha_ = alpha, beta_ = beta;
+    shgemm_(&transa_, &transb_,
+            &m_, &n_, &k_,
+            &alpha_,
+            a, &lda_,
+            b, &ldb_,
+            &beta_,
+            c, &ldc_);
+    return;
+  }
+#endif
 #ifdef MKL_HAS_SHGEMM
   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
@@ -346,3 +346,14 @@ IF(BLAS_LIBRARIES)
     add_compile_options(-DBLAS_HAS_SBGEMM)
   ENDIF(BLAS_HAS_SBGEMM)
 ENDIF(BLAS_LIBRARIES)
+
+# Blas has fp16 (half precision) support?
+IF(BLAS_LIBRARIES)
+  INCLUDE(CheckFunctionExists)
+  SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
+  check_function_exists("shgemm_" BLAS_HAS_SHGEMM)
+  set(CMAKE_REQUIRED_LIBRARIES)
+  IF(BLAS_HAS_SHGEMM)
+    add_compile_options(-DBLAS_HAS_SHGEMM)
+  ENDIF(BLAS_HAS_SHGEMM)
+ENDIF(BLAS_LIBRARIES)
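`check_function_exists("shgemm_" BLAS_HAS_SHGEMM)` compiles and links a small probe against `${CMAKE_REQUIRED_LIBRARIES}` (here the BLAS libraries) that merely references the symbol; if linking succeeds, `BLAS_HAS_SHGEMM` is set and the `-DBLAS_HAS_SHGEMM` definition is added. Conceptually (an illustration, not the literal CMake-generated source) the probe amounts to:

```cpp
// Illustrative stand-in for the CheckFunctionExists probe: the declaration is
// deliberately type-agnostic; only successful linking against the BLAS matters.
extern "C" char shgemm_();

int main() {
  shgemm_();  // reference the symbol so the linker must resolve it
  return 0;
}
```

The probe is only built and linked, never run, so the mismatched signature is harmless.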
@@ -60,6 +60,7 @@ function(caffe2_print_configuration_summary)
   if(${USE_BLAS})
     message(STATUS " BLAS : ${BLAS_INFO}")
     message(STATUS " BLAS_HAS_SBGEMM : ${BLAS_HAS_SBGEMM}")
+    message(STATUS " BLAS_HAS_SHGEMM : ${BLAS_HAS_SHGEMM}")
   endif()
   message(STATUS " USE_LAPACK : ${USE_LAPACK}")
   if(${USE_LAPACK})