mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Optimize half conversion for SYCL kernel
## Motivation: Add support for SYCL half implicit/explicit conversion in SYCL kernels. ## Additional Context: Macro `SYCL_LANGUAGE_VERSION` is suggested by SYCL compiler to instead of `__SYCL_DEVICE_ONLY__` in current version unless device and host specific implementation of the same function is necessary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76515 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ed18181d83
commit
68e012b023
@@ -12,7 +12,7 @@
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#ifdef __SYCL_DEVICE_ONLY__
|
||||
#ifdef SYCL_LANGUAGE_VERSION
|
||||
#include <CL/sycl.hpp>
|
||||
#endif
|
||||
|
||||
@@ -56,6 +56,15 @@ inline C10_HOST_DEVICE Half::operator __half() const {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef SYCL_LANGUAGE_VERSION
|
||||
inline C10_HOST_DEVICE Half::Half(const sycl::half& value) {
|
||||
x = *reinterpret_cast<const unsigned short*>(&value);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half::operator sycl::half() const {
|
||||
return *reinterpret_cast<const sycl::half*>(&x);
|
||||
}
|
||||
#endif
|
||||
|
||||
// CUDA intrinsics
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
|
||||
@@ -88,6 +97,8 @@ inline C10_HOST_DEVICE Half operator-(const Half& a) {
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
|
||||
defined(__HIP_DEVICE_COMPILE__)
|
||||
return __hneg(a);
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
return -static_cast<sycl::half>(a);
|
||||
#else
|
||||
return -static_cast<float>(a);
|
||||
#endif
|
||||
|
||||
@@ -45,6 +45,10 @@
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#ifdef SYCL_LANGUAGE_VERSION
|
||||
#include <CL/sycl.hpp>
|
||||
#endif
|
||||
|
||||
// Standard check for compiling CUDA with clang
|
||||
#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
|
||||
#define C10_DEVICE_HOST_FUNCTION __device__ __host__
|
||||
@@ -390,6 +394,10 @@ struct alignas(2) Half {
|
||||
inline C10_HOST_DEVICE Half(const __half& value);
|
||||
inline C10_HOST_DEVICE operator __half() const;
|
||||
#endif
|
||||
#ifdef SYCL_LANGUAGE_VERSION
|
||||
inline C10_HOST_DEVICE Half(const sycl::half& value);
|
||||
inline C10_HOST_DEVICE operator sycl::half() const;
|
||||
#endif
|
||||
};
|
||||
|
||||
// TODO : move to complex.h
|
||||
|
||||
Reference in New Issue
Block a user