diff --git a/tensorflow/core/kernels/image/image_ops.cc b/tensorflow/core/kernels/image/image_ops.cc index 649ad187c47439..166aeb56b451a1 100644 --- a/tensorflow/core/kernels/image/image_ops.cc +++ b/tensorflow/core/kernels/image/image_ops.cc @@ -15,9 +15,9 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/image/image_ops.h" @@ -192,7 +192,7 @@ TF_CALL_bfloat16(REGISTER); #undef REGISTER -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM typedef Eigen::GpuDevice GPUDevice; typedef generator::Mode Mode; @@ -266,6 +266,6 @@ TF_CALL_double(REGISTER); #undef REGISTER -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // end namespace tensorflow diff --git a/tensorflow/core/kernels/image/image_ops_gpu.cu.cc b/tensorflow/core/kernels/image/image_ops_gpu.cu.cc index dd94559ffd7d69..b602ba67df6277 100644 --- a/tensorflow/core/kernels/image/image_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/image/image_ops_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -28,7 +28,6 @@ namespace functor { // Explicit instantiation of the GPU functor. typedef Eigen::GpuDevice GPUDevice; - template class FillProjectiveTransform; template class FillProjectiveTransform; template class FillProjectiveTransform; @@ -40,4 +39,4 @@ template class FillProjectiveTransform; } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index 6a0560632cb278..75087ab6c9ccc9 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -116,9 +116,7 @@ TF_CALL_half(DECLARE_GPU_NO_MLIR_SPEC); TF_CALL_float(DECLARE_GPU_NO_MLIR_SPEC); TF_CALL_double(DECLARE_GPU_NO_MLIR_SPEC); #endif -#if GOOGLE_CUDA //No Rocm for now TF_CALL_bfloat16(DECLARE_GPU_NO_MLIR_SPEC); -#endif #undef DECLARE_GPU_NO_MLIR_SPEC } // namespace functor @@ -138,9 +136,7 @@ TF_CALL_half(REGISTER_GPU_NO_MLIR_KERNELS); TF_CALL_float(REGISTER_GPU_NO_MLIR_KERNELS); TF_CALL_double(REGISTER_GPU_NO_MLIR_KERNELS); #endif -#if GOOGLE_CUDA //No Rocm for now TF_CALL_bfloat16(REGISTER_GPU_NO_MLIR_KERNELS); -#endif #undef REGISTER_GPU_NO_MLIR_KERNELS // Forward declarations of the functor specializations for GPU. @@ -210,9 +206,7 @@ void Relu::operator()( extern template struct Relu; TF_CALL_GPU_NUMBER_TYPES_NO_BF16(DECLARE_GPU_SPEC); -#if GOOGLE_CUDA TF_CALL_bfloat16(DECLARE_GPU_SPEC); -#endif } // namespace functor // Registration of the GPU implementations. @@ -246,9 +240,7 @@ TF_CALL_bfloat16(DECLARE_GPU_SPEC); SeluGradOp) TF_CALL_GPU_NUMBER_TYPES_NO_BF16(REGISTER_GPU_KERNELS); -#if GOOGLE_CUDA TF_CALL_bfloat16(REGISTER_GPU_KERNELS); -#endif #undef REGISTER_GPU_KERNELS template diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index d4ed4980841d5f..a1bba19fc27506 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -234,68 +234,65 @@ struct Relu { reinterpret_cast(output.data()))); } }; -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template -__global__ void GeluKernel(const T* in, T* out, int32 count) { +__global__ void GeluKernel(const T* __restrict__ in, + T* __restrict__ out, int32 count) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= count) return; - const auto scale = static_cast(0.7978845608028654); - const auto p1 = scale; - const auto p3 = static_cast(0.044715 * 0.7978845608028654); - T x = in[i]; - out[i] = 0.5 * x * (1 + tanh(p1 * x + p3 * x * x * x)); + + constexpr bool is_half = std::is_same_v; + if constexpr(is_half || std::is_same_v) { + using NT = std::conditional_t< is_half, half, bfloat16 >; + auto *xin = reinterpret_cast(in); + auto *xout = reinterpret_cast(out); + const float scale = 0.7978845608028654; + const float p1 = scale; + const float p3 = 0.044715 * 0.7978845608028654; + float x = xin[i]; + float out = 0.5f * x * (1.f + tanh(p1 * x + p3 * x * x * x)); + xout[i] = static_cast(out); + } else { + const auto scale = static_cast(0.7978845608028654); + const auto p1 = scale; + const auto p3 = static_cast(0.044715 * 0.7978845608028654); + T x = in[i]; + out[i] = 0.5 * x * (1. + tanh(p1 * x + p3 * x * x * x)); + } } template -__global__ void GeluGradKernel(const T* gradient, const T* feature, T* backprop, - int32 count) { +__global__ void GeluGradKernel(const T* __restrict__ gradient, + const T* __restrict__ feature, T* __restrict__ backprop, int32 count) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= count) return; - const T p1 = static_cast(0.7978845608028654); - const T p3 = static_cast(0.044715 * 0.7978845608028654); - T x = feature[i]; - T z = p1 * x + p3 * x * x * x; - T g = gradient[i]; - T cz = 1. / cosh(z); - backprop[i] = static_cast( + constexpr bool is_half = std::is_same_v; + if constexpr(is_half || std::is_same_v) { + using NT = std::conditional_t< is_half, half, bfloat16 >; + const float scale = 0.7978845608028654; + const float p1 = scale; + const float p3 = 0.044715 * 0.7978845608028654; + auto *xgrad = reinterpret_cast(gradient); + auto *xfeature = reinterpret_cast(feature); + auto *xbackprop = reinterpret_cast(backprop); + float x = xfeature[i]; + float z = p1 * x + p3 * x * x * x; + float g = xgrad[i]; + float cz = 1.f / cosh(z); + float out = g * 0.5f * (1.f + tanh(z) + + x * (p1 + 3 * p3 * x * x) * cz * cz); + xbackprop[i] = static_cast< NT >(out); + } else { + const T p1 = static_cast(0.7978845608028654); + const T p3 = static_cast(0.044715 * 0.7978845608028654); + T x = feature[i]; + T z = p1 * x + p3 * x * x * x; + T g = gradient[i]; + T cz = 1. / cosh(z); + backprop[i] = static_cast( g * 0.5 * (1. + tanh(z) + x * (p1 + 3 * p3 * x * x) * cz * cz)); -} - -template <> -__global__ void GeluKernel(const Eigen::half* _in, - Eigen::half* _out, int32 count) { - int i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= count) return; - const half* in = reinterpret_cast(_in); - half* out = reinterpret_cast(_out); - const float scale = 0.7978845608028654; - const float p1 = scale; - const float p3 = 0.044715 * 0.7978845608028654; - float x = in[i]; - out[i] = 0.5 * x * (1 + tanh(p1 * x + p3 * x * x * x)); -} - -template <> -__global__ void GeluGradKernel(const Eigen::half* _gradient, - const Eigen::half* _feature, - Eigen::half* _backprop, - int32 count) { - int i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= count) return; - const float scale = 0.7978845608028654; - const float p1 = scale; - const float p3 = 0.044715 * 0.7978845608028654; - const half* gradient = reinterpret_cast(_gradient); - const half* feature = reinterpret_cast(_feature); - half* backprop = reinterpret_cast(_backprop); - float x = feature[i]; - float z = p1 * x + p3 * x * x * x; - float g = gradient[i]; - float cz = 1. / cosh(z); - backprop[i] = g * 0.5 * (1. + tanh(z) + x * (p1 + 3 * p3 * x * x) * cz * cz); + } } template @@ -338,9 +335,7 @@ TF_CALL_half(DEFINE_GPU_NO_MLIR_KERNELS); TF_CALL_float(DEFINE_GPU_NO_MLIR_KERNELS); TF_CALL_double(DEFINE_GPU_NO_MLIR_KERNELS); #endif -#if GOOGLE_CUDA TF_CALL_bfloat16(DEFINE_GPU_NO_MLIR_KERNELS); -#endif #undef DEFINE_GPU_NO_MLIR_KERNELS // Definition of the GPU implementations declared in relu_op.cc. @@ -356,9 +351,7 @@ TF_CALL_bfloat16(DEFINE_GPU_NO_MLIR_KERNELS); template struct functor::GeluGrad; TF_CALL_GPU_NUMBER_TYPES_NO_BF16(DEFINE_GPU_KERNELS); -#if GOOGLE_CUDA TF_CALL_bfloat16(DEFINE_GPU_KERNELS); -#endif template struct functor::Relu; } // end namespace tensorflow diff --git a/tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc b/tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc index a2b25f846ef1d3..e3430af3319616 100644 --- a/tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc +++ b/tensorflow/core/kernels/stateless_random_gamma_op_gpu.cu.cc @@ -183,4 +183,4 @@ TF_CALL_double(REGISTER_GPU_SPEC); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/unique_op_gpu.cu.h b/tensorflow/core/kernels/unique_op_gpu.cu.h index 735af617dc218c..26c0a606f6843b 100644 --- a/tensorflow/core/kernels/unique_op_gpu.cu.h +++ b/tensorflow/core/kernels/unique_op_gpu.cu.h @@ -449,6 +449,6 @@ class UniqueOpGPU : public AsyncOpKernel { } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_ diff --git a/tensorflow/core/kernels/unique_op_gpu_0.cu.cc b/tensorflow/core/kernels/unique_op_gpu_0.cu.cc index baf655d990c843..f45ca85862cec0 100644 --- a/tensorflow/core/kernels/unique_op_gpu_0.cu.cc +++ b/tensorflow/core/kernels/unique_op_gpu_0.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/unique_op_gpu.cu.h" diff --git a/tensorflow/core/kernels/unique_op_gpu_1.cu.cc b/tensorflow/core/kernels/unique_op_gpu_1.cu.cc index 44d649e84b6fef..f645717a2885ee 100644 --- a/tensorflow/core/kernels/unique_op_gpu_1.cu.cc +++ b/tensorflow/core/kernels/unique_op_gpu_1.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/unique_op_gpu.cu.h" @@ -39,4 +39,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_UNIQUE_GPU); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/python/kernel_tests/nn_ops/relu_op_test.py b/tensorflow/python/kernel_tests/nn_ops/relu_op_test.py index e20fa6c9ac1bf8..0a875e94ba3676 100644 --- a/tensorflow/python/kernel_tests/nn_ops/relu_op_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/relu_op_test.py @@ -70,9 +70,9 @@ def testNumbersGPU(self): self.skipTest("No GPU available") for t in [ np.float16, + dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, - dtypes.bfloat16.as_numpy_dtype, ]: self._testRelu( np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) @@ -550,17 +550,25 @@ def testNumbersCPU(self): def testNumbersGPU(self): if not test.is_gpu_available(): self.skipTest("No GPU available") - for t in [np.float16, np.float32, np.float64]: + for t in [np.float16, dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]: self._testGelu(np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) def testGradients(self): - for t in [np.float16, np.float32, np.float64]: + for t in [np.float16, dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]: + + is_f16 = t == np.float16 + is_bf16 = t == dtypes.bfloat16.as_numpy_dtype for gpu in [True, False]: if gpu and not test.is_gpu_available(): continue - delta = 2e-2 if t == np.float16 else 1e-3 - tol = 2e-2 if t == np.float16 else (1e-4 if t == np.float32 else 1e-6) + delta = 2e-2 if is_f16 or is_bf16 else 1e-3 + tol = 3e-2 if is_bf16 else \ + 2e-2 if is_f16 else \ + 1e-4 if t == np.float32 else 1e-6 + if is_bf16 and not gpu: + tol = 0.1 # really bad accuracy on CPU for bf16 + def approx_gelu(x): return nn_ops.gelu(x, approximate=True) with self.session(use_gpu=gpu): @@ -571,7 +579,8 @@ def approx_gelu(x): err = gradient_checker_v2.max_error( e1, e2) print(e1, e2) - print("gelu", t, "GPU" if gpu else "CPU", "gradient err = ", err) + print("gelu", t, "GPU" if gpu else "CPU", \ + "gradient err = ", err, " tol = ", tol) self.assertLess(err, tol) class SeluTest(test.TestCase):