Optimize half conversion for SYCL kernel

## Motivation:
Add support for SYCL half implicit/explicit conversion in SYCL kernels.

## Additional Context:
The SYCL compiler currently recommends using the `SYCL_LANGUAGE_VERSION` macro instead of `__SYCL_DEVICE_ONLY__`, unless separate device-specific and host-specific implementations of the same function are necessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76515
Approved by: https://github.com/ezyang
This commit is contained in:
cchheennhhaaoo
2022-05-04 00:57:03 +00:00
committed by PyTorch MergeBot
parent ed18181d83
commit 68e012b023
2 changed files with 20 additions and 1 deletions

View File

@@ -12,7 +12,7 @@
#include <hip/hip_fp16.h>
#endif
#ifdef __SYCL_DEVICE_ONLY__
#ifdef SYCL_LANGUAGE_VERSION
#include <CL/sycl.hpp>
#endif
@@ -56,6 +56,15 @@ inline C10_HOST_DEVICE Half::operator __half() const {
}
#endif
#ifdef SYCL_LANGUAGE_VERSION
/// Construct a Half from a SYCL half value.
///
/// The storage of `value` is viewed as an `unsigned short` and that raw
/// value is copied into `x`; no numeric conversion is performed.
/// NOTE(review): presumably sycl::half and unsigned short have the same
/// size and a compatible layout -- confirm against the SYCL implementation.
inline C10_HOST_DEVICE Half::Half(const sycl::half& value) {
  const auto* raw_bits = reinterpret_cast<const unsigned short*>(&value);
  x = *raw_bits;
}
/// Convert this Half to a SYCL half value.
///
/// The storage of `x` is viewed as a `sycl::half` and returned by value;
/// no numeric conversion is performed.
/// NOTE(review): presumably `x` has the same size and a compatible layout
/// with sycl::half -- confirm against the SYCL implementation.
inline C10_HOST_DEVICE Half::operator sycl::half() const {
  const auto* as_sycl_half = reinterpret_cast<const sycl::half*>(&x);
  return *as_sycl_half;
}
#endif
// CUDA intrinsics
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
@@ -88,6 +97,8 @@ inline C10_HOST_DEVICE Half operator-(const Half& a) {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
defined(__HIP_DEVICE_COMPILE__)
return __hneg(a);
#elif defined(__SYCL_DEVICE_ONLY__)
return -static_cast<sycl::half>(a);
#else
return -static_cast<float>(a);
#endif

View File

@@ -45,6 +45,10 @@
#include <hip/hip_fp16.h>
#endif
#ifdef SYCL_LANGUAGE_VERSION
#include <CL/sycl.hpp>
#endif
// Standard check for compiling CUDA with clang
#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
#define C10_DEVICE_HOST_FUNCTION __device__ __host__
@@ -390,6 +394,10 @@ struct alignas(2) Half {
inline C10_HOST_DEVICE Half(const __half& value);
inline C10_HOST_DEVICE operator __half() const;
#endif
#ifdef SYCL_LANGUAGE_VERSION
inline C10_HOST_DEVICE Half(const sycl::half& value);
inline C10_HOST_DEVICE operator sycl::half() const;
#endif
};
// TODO : move to complex.h