From cbcf45274b04989e13fff00a078de95768aac3b7 Mon Sep 17 00:00:00 2001
From: Xiaomeng Yang
Date: Wed, 11 Jul 2018 13:42:04 -0700
Subject: [PATCH] Move tanh function to math (#9328)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9328

Move tanh function to math

Reviewed By: houseroad

Differential Revision: D8794745

fbshipit-source-id: ea525dedde6f53592b06c2caffd6426688dea5fc
---
 caffe2/operators/tanh_op.cc |  7 ++-----
 caffe2/operators/tanh_op.cu | 23 -----------------------
 caffe2/operators/tanh_op.h  |  5 ++++-
 caffe2/python/brew_test.py  |  4 ++--
 caffe2/utils/math.h         |  2 ++
 caffe2/utils/math_cpu.cc    | 13 +++++++++++++
 caffe2/utils/math_gpu.cu    |  1 +
 7 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/caffe2/operators/tanh_op.cc b/caffe2/operators/tanh_op.cc
index 28ca87c13a2..312fdab36fd 100644
--- a/caffe2/operators/tanh_op.cc
+++ b/caffe2/operators/tanh_op.cc
@@ -2,6 +2,7 @@
 
 namespace caffe2 {
 
+#ifdef CAFFE2_USE_ACCELERATE
 template <>
 template <>
 bool TanhFunctor<CPUContext>::operator()<float>(
@@ -9,14 +10,10 @@ bool TanhFunctor<CPUContext>::operator()<float>(
     const float* X,
     float* Y,
     CPUContext* /* context */) const {
-#ifdef CAFFE2_USE_ACCELERATE
   vvtanhf(Y, X, &N);
-#else
-  ConstEigenVectorArrayMap<float> X_arr(X, N);
-  EigenVectorMap<float>(Y, N) = 1 - 2 * ((X_arr * 2).exp() + 1).inverse();
-#endif
   return true;
 }
+#endif // CAFFE2_USE_ACCELERATE
 
 REGISTER_CPU_OPERATOR(
     Tanh,
diff --git a/caffe2/operators/tanh_op.cu b/caffe2/operators/tanh_op.cu
index ff0ab2b558a..17ebac1ed16 100644
--- a/caffe2/operators/tanh_op.cu
+++ b/caffe2/operators/tanh_op.cu
@@ -9,17 +9,6 @@ namespace caffe2 {
 
 namespace {
 
-template <typename T>
-__global__ void TanhCUDAKernel(const int N, const T* X, T* Y) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
-#if __CUDA_ARCH__ >= 350
-    Y[i] = tanh(__ldg(X + i));
-#else
-    Y[i] = tanh(X[i]);
-#endif
-  }
-}
-
 template <typename T>
 __global__ void
 TanhGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
@@ -34,18 +23,6 @@ TanhGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
 
 } // namespace
 
-template <>
-template <typename T>
-bool TanhFunctor<CUDAContext>::
-operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
-  TanhCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(N),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(N, X, Y);
-  return true;
-}
-
 template <>
 template <typename T>
 bool TanhGradientFunctor<CUDAContext>::Forward(
diff --git a/caffe2/operators/tanh_op.h b/caffe2/operators/tanh_op.h
index 117767b3c6e..123773dfff0 100644
--- a/caffe2/operators/tanh_op.h
+++ b/caffe2/operators/tanh_op.h
@@ -11,7 +11,10 @@ namespace caffe2 {
 template <class Context>
 struct TanhFunctor {
   template <typename T>
-  bool operator()(const int N, const T* X, T* Y, Context* context) const;
+  bool operator()(const int N, const T* X, T* Y, Context* context) const {
+    math::Tanh<T, Context>(N, X, Y, context);
+    return true;
+  }
 };
 
 template <class Context>
diff --git a/caffe2/python/brew_test.py b/caffe2/python/brew_test.py
index 17b2f57c508..8b3d08977c2 100644
--- a/caffe2/python/brew_test.py
+++ b/caffe2/python/brew_test.py
@@ -81,7 +81,7 @@ class BrewTest(unittest.TestCase):
         workspace.RunNetOnce(model.net)
         out = workspace.FetchBlob("out_tanh")
-        self.assertAlmostEqual(out.mean(), 0.46211711)
+        self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)
 
     def test_validate(self):
         model = ModelHelper(name="test_model")
@@ -325,4 +325,4 @@ class BrewGPUTest(unittest.TestCase):
         workspace.RunNetOnce(model.net)
         out = workspace.FetchBlob("out_tanh")
-        self.assertAlmostEqual(out.mean(), 0.46211711)
+        self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index fe1da186474..0e7bc2b56f6 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -78,6 +78,8 @@ void Cosh(const int N, const T* x, T* y, Context* context);
 template <typename T, class Context>
 void SinCos(const int N, const T* x, T* ys, T* yc, Context* context);
 template <typename T, class Context>
+void Tanh(const int N, const T* x, T* y, Context* context);
+template <typename T, class Context>
 void Abs(const int N, const T* x, T* y, Context* context);
 template <typename T, class Context>
 void Sqr(const int N, const T* x, T* y, Context* context);
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index feb44a762c3..c958f9f43a2 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -556,6 +556,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sinh, vsSinh)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sinh, vdSinh)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cosh, vsCosh)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cosh, vdCosh)
+DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, vsTanh)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, vdTanh)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Abs, vsAbs)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Abs, vdAbs)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr)
@@ -636,6 +638,17 @@ DELEGATE_SINCOS_FUNCTION(float)
 DELEGATE_SINCOS_FUNCTION(double)
 #undef DELEGATE_SINCOS_FUNCTION
 
+#define DELEGATE_TANH_FUNCTION(T)                                             \
+  template <>                                                                 \
+  void Tanh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) {      \
+    EigenVectorMap<T>(Y, N) = T(1) -                                          \
+        ((ConstEigenVectorArrayMap<T>(X, N) * T(2)).exp() + T(1)).inverse() * \
+        T(2);                                                                 \
+  }
+DELEGATE_TANH_FUNCTION(float)
+DELEGATE_TANH_FUNCTION(double)
+#undef DELEGATE_TANH_FUNCTION
+
 #define DELEGATE_CBRT_FUNCTION(T)                                        \
   template <>                                                            \
   void Cbrt<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 17e941c9758..4b2a0661642 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -331,6 +331,7 @@ DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf)
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf)
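
Reviewer note (not part of the patch): the Eigen fallback added by
DELEGATE_TANH_FUNCTION in math_cpu.cc computes tanh via the identity
tanh(x) = 1 - 2 / (exp(2x) + 1), the same formula the removed tanh_op.cc
code used. Below is a minimal standalone C++ sketch that checks this
identity against std::tanh; it is illustrative only and uses plain scalar
code in place of the Eigen vectorized expression.

  #include <cassert>
  #include <cmath>
  #include <cstdio>

  int main() {
    for (double x = -5.0; x <= 5.0; x += 0.25) {
      // Scalar form of the macro's elementwise expression:
      // Y = 1 - 2 * inverse(exp(2 * X) + 1)
      const double via_exp = 1.0 - 2.0 / (std::exp(2.0 * x) + 1.0);
      assert(std::abs(via_exp - std::tanh(x)) < 1e-12);
    }
    std::printf("identity holds\n");
    return 0;
  }

The expression also degrades gracefully at the extremes: exp(2x) underflows
to 0 for very negative x (yielding -1) and overflows to +inf for very
positive x (yielding 1), matching tanh's asymptotes.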