Move tanh function to math (#9328)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9328

Move tanh function to math

Reviewed By: houseroad

Differential Revision: D8794745

fbshipit-source-id: ea525dedde6f53592b06c2caffd6426688dea5fc
This commit is contained in:
Xiaomeng Yang
2018-07-11 13:42:04 -07:00
committed by Facebook Github Bot
parent 7d8b532c1f
commit cbcf45274b
7 changed files with 24 additions and 31 deletions

View File

@@ -2,6 +2,7 @@
namespace caffe2 {
#ifdef CAFFE2_USE_ACCELERATE
template <>
template <>
bool TanhFunctor<CPUContext>::operator()<float>(
@@ -9,14 +10,10 @@ bool TanhFunctor<CPUContext>::operator()<float>(
const float* X,
float* Y,
CPUContext* /* context */) const {
#ifdef CAFFE2_USE_ACCELERATE
vvtanhf(Y, X, &N);
#else
ConstEigenVectorArrayMap<float> X_arr(X, N);
EigenVectorMap<float>(Y, N) = 1 - 2 * ((X_arr * 2).exp() + 1).inverse();
#endif
return true;
}
#endif // CAFFE2_USE_ACCELERATE
REGISTER_CPU_OPERATOR(
Tanh,

View File

@@ -9,17 +9,6 @@ namespace caffe2 {
namespace {
// Elementwise hyperbolic tangent kernel: computes Y[i] = tanh(X[i])
// for every i in [0, N) using a CUDA_1D_KERNEL_LOOP over the grid.
template <typename T>
__global__ void TanhCUDAKernel(const int N, const T* X, T* Y) {
CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
// __ldg is only available on sm_35+; it loads X[i] through the
// read-only data cache, which is safe here since X is never written.
Y[i] = tanh(__ldg(X + i));
#else
Y[i] = tanh(X[i]);
#endif
}
}
template <typename T>
__global__ void
TanhGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
@@ -34,18 +23,6 @@ TanhGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
} // namespace
// Host-side launcher for the CUDA Tanh functor: dispatches
// TanhCUDAKernel<T> on the stream owned by `context`.
template <>
template <typename T>
bool TanhFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
TanhCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(N, X, Y);
// Unconditionally reports success; no launch-error check is done here.
return true;
}
template <>
template <typename T>
bool TanhGradientFunctor<CUDAContext>::Forward(

View File

@@ -11,7 +11,10 @@ namespace caffe2 {
template <class Context>
struct TanhFunctor {
template <typename T>
bool operator()(const int N, const T* X, T* Y, Context* context) const;
bool operator()(const int N, const T* X, T* Y, Context* context) const {
math::Tanh<T, Context>(N, X, Y, context);
return true;
}
};
template <class Context>

View File

@@ -81,7 +81,7 @@ class BrewTest(unittest.TestCase):
workspace.RunNetOnce(model.net)
out = workspace.FetchBlob("out_tanh")
self.assertAlmostEqual(out.mean(), 0.46211711)
self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)
def test_validate(self):
model = ModelHelper(name="test_model")
@@ -325,4 +325,4 @@ class BrewGPUTest(unittest.TestCase):
workspace.RunNetOnce(model.net)
out = workspace.FetchBlob("out_tanh")
self.assertAlmostEqual(out.mean(), 0.46211711)
self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)

View File

@@ -78,6 +78,8 @@ void Cosh(const int N, const T* x, T* y, Context* context);
template <typename T, class Context>
void SinCos(const int N, const T* x, T* ys, T* yc, Context* context);
template <typename T, class Context>
void Tanh(const int N, const T* x, T* y, Context* context);
template <typename T, class Context>
void Abs(const int N, const T* x, T* y, Context* context);
template <typename T, class Context>
void Sqr(const int N, const T* x, T* y, Context* context);

View File

@@ -556,6 +556,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sinh, vsSinh)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sinh, vdSinh)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cosh, vsCosh)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cosh, vdCosh)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, vsTanh)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, vdTanh)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Abs, vsAbs)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Abs, vdAbs)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr)
@@ -636,6 +638,17 @@ DELEGATE_SINCOS_FUNCTION(float)
DELEGATE_SINCOS_FUNCTION(double)
#undef DELEGATE_SINCOS_FUNCTION
// Eigen-based CPU implementation of math::Tanh, generated per scalar type.
// Uses the algebraic identity tanh(x) = 1 - 2 / (exp(2x) + 1), evaluated
// as a single vectorized Eigen expression over all N elements.
// (Comments are kept outside the macro so the `\` continuations are intact.)
#define DELEGATE_TANH_FUNCTION(T) \
template <> \
void Tanh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
EigenVectorMap<T>(Y, N) = T(1) - \
((ConstEigenVectorArrayMap<T>(X, N) * T(2)).exp() + T(1)).inverse() * \
T(2); \
}
// Instantiate for both supported floating-point types, then retire the macro.
DELEGATE_TANH_FUNCTION(float)
DELEGATE_TANH_FUNCTION(double)
#undef DELEGATE_TANH_FUNCTION
#define DELEGATE_CBRT_FUNCTION(T) \
template <> \
void Cbrt<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \

View File

@@ -331,6 +331,7 @@ DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square<float>)
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf)