Added missing gpu support for Gelu and some other ops #2825

Open · wants to merge 1 commit into base: develop-upstream
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/image/image_ops.cc
@@ -15,9 +15,9 @@ limitations under the License.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/image/image_ops.h"

@@ -192,7 +192,7 @@ TF_CALL_bfloat16(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;
typedef generator::Mode Mode;
@@ -266,6 +266,6 @@ TF_CALL_double(REGISTER);

#undef REGISTER

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow
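The changes in both image-ops files are the same one-line idiom: widen the CUDA-only preprocessor guard so the Eigen GPU path also builds under ROCm. A minimal sketch of the idiom, assuming the standard TensorFlow build defines (GOOGLE_CUDA for the CUDA toolchain, TENSORFLOW_USE_ROCM for HIP):

```cpp
// Illustration of the guard idiom this PR applies; not new code.
// EIGEN_USE_GPU must be defined before Eigen's Tensor headers are included
// so that Eigen::GpuDevice is compiled for whichever GPU toolchain is
// active; the DEVICE_GPU kernel registrations later in each file sit behind
// the same combined guard.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
```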
5 changes: 2 additions & 3 deletions tensorflow/core/kernels/image/image_ops_gpu.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

@@ -28,7 +28,6 @@ namespace functor {

// Explicit instantiation of the GPU functor.
typedef Eigen::GpuDevice GPUDevice;

template class FillProjectiveTransform<GPUDevice, uint8>;
template class FillProjectiveTransform<GPUDevice, int32>;
template class FillProjectiveTransform<GPUDevice, int64>;
@@ -40,4 +39,4 @@ template class FillProjectiveTransform<GPUDevice, double>;

} // end namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
8 changes: 0 additions & 8 deletions tensorflow/core/kernels/relu_op.cc
@@ -116,9 +116,7 @@ TF_CALL_half(DECLARE_GPU_NO_MLIR_SPEC);
TF_CALL_float(DECLARE_GPU_NO_MLIR_SPEC);
TF_CALL_double(DECLARE_GPU_NO_MLIR_SPEC);
#endif
#if GOOGLE_CUDA //No Rocm for now
TF_CALL_bfloat16(DECLARE_GPU_NO_MLIR_SPEC);
#endif
#undef DECLARE_GPU_NO_MLIR_SPEC
} // namespace functor

@@ -138,9 +136,7 @@ TF_CALL_half(REGISTER_GPU_NO_MLIR_KERNELS);
TF_CALL_float(REGISTER_GPU_NO_MLIR_KERNELS);
TF_CALL_double(REGISTER_GPU_NO_MLIR_KERNELS);
#endif
#if GOOGLE_CUDA //No Rocm for now
TF_CALL_bfloat16(REGISTER_GPU_NO_MLIR_KERNELS);
#endif
#undef REGISTER_GPU_NO_MLIR_KERNELS

// Forward declarations of the functor specializations for GPU.
@@ -210,9 +206,7 @@ void Relu<GPUDevice, qint8>::operator()(
extern template struct Relu<GPUDevice, qint8>;

TF_CALL_GPU_NUMBER_TYPES_NO_BF16(DECLARE_GPU_SPEC);
#if GOOGLE_CUDA
TF_CALL_bfloat16(DECLARE_GPU_SPEC);
#endif
} // namespace functor

// Registration of the GPU implementations.
@@ -246,9 +240,7 @@ TF_CALL_bfloat16(DECLARE_GPU_SPEC);
SeluGradOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_BF16(REGISTER_GPU_KERNELS);
#if GOOGLE_CUDA
TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
#endif
#undef REGISTER_GPU_KERNELS

template <typename Device>
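The relu_op.cc hunks only remove the `#if GOOGLE_CUDA //No Rocm for now` fences, so the bfloat16 declarations and registrations now take effect under ROCm as well. As a hedged sketch of what a single per-type expansion of such a registration macro amounts to (op and class names are assumed from the surrounding macros, not copied from the file):

```cpp
// Hedged sketch only: one bfloat16 expansion of a GPU registration macro in
// relu_op.cc. The real macros in that file cover Relu, Elu, Selu and Gelu
// plus their gradient ops; the names below are illustrative.
REGISTER_KERNEL_BUILDER(
    Name("Gelu").Device(DEVICE_GPU).TypeConstraint<Eigen::bfloat16>("T"),
    GeluOp<GPUDevice, Eigen::bfloat16>);
REGISTER_KERNEL_BUILDER(
    Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<Eigen::bfloat16>("T"),
    GeluGradOp<GPUDevice, Eigen::bfloat16>);
```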
103 changes: 48 additions & 55 deletions tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -234,68 +234,65 @@ struct Relu<Device, qint8> {
reinterpret_cast<int32*>(output.data())));
}
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <class T>
__global__ void GeluKernel(const T* in, T* out, int32 count) {
__global__ void GeluKernel(const T* __restrict__ in,
T* __restrict__ out, int32 count) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= count) return;
const auto scale = static_cast<T>(0.7978845608028654);
const auto p1 = scale;
const auto p3 = static_cast<T>(0.044715 * 0.7978845608028654);
T x = in[i];
out[i] = 0.5 * x * (1 + tanh(p1 * x + p3 * x * x * x));

constexpr bool is_half = std::is_same_v<T, Eigen::half>;
if constexpr(is_half || std::is_same_v<T, Eigen::bfloat16>) {
using NT = std::conditional_t< is_half, half, bfloat16 >;
auto *xin = reinterpret_cast<const NT*>(in);
auto *xout = reinterpret_cast<NT*>(out);
const float scale = 0.7978845608028654;
const float p1 = scale;
const float p3 = 0.044715 * 0.7978845608028654;
float x = xin[i];
float out = 0.5f * x * (1.f + tanh(p1 * x + p3 * x * x * x));
xout[i] = static_cast<NT>(out);
} else {
const auto scale = static_cast<T>(0.7978845608028654);
const auto p1 = scale;
const auto p3 = static_cast<T>(0.044715 * 0.7978845608028654);
T x = in[i];
out[i] = 0.5 * x * (1. + tanh(p1 * x + p3 * x * x * x));
}
}

template <class T>
__global__ void GeluGradKernel(const T* gradient, const T* feature, T* backprop,
int32 count) {
__global__ void GeluGradKernel(const T* __restrict__ gradient,
const T* __restrict__ feature, T* __restrict__ backprop, int32 count) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= count) return;

const T p1 = static_cast<T>(0.7978845608028654);
const T p3 = static_cast<T>(0.044715 * 0.7978845608028654);
T x = feature[i];
T z = p1 * x + p3 * x * x * x;
T g = gradient[i];
T cz = 1. / cosh(z);
backprop[i] = static_cast<T>(
constexpr bool is_half = std::is_same_v<T, Eigen::half>;
if constexpr(is_half || std::is_same_v<T, Eigen::bfloat16>) {
using NT = std::conditional_t< is_half, half, bfloat16 >;
const float scale = 0.7978845608028654;
const float p1 = scale;
const float p3 = 0.044715 * 0.7978845608028654;
auto *xgrad = reinterpret_cast<const NT*>(gradient);
auto *xfeature = reinterpret_cast<const NT*>(feature);
auto *xbackprop = reinterpret_cast<NT*>(backprop);
float x = xfeature[i];
float z = p1 * x + p3 * x * x * x;
float g = xgrad[i];
float cz = 1.f / cosh(z);
float out = g * 0.5f * (1.f + tanh(z) +
x * (p1 + 3 * p3 * x * x) * cz * cz);
xbackprop[i] = static_cast< NT >(out);
} else {
const T p1 = static_cast<T>(0.7978845608028654);
const T p3 = static_cast<T>(0.044715 * 0.7978845608028654);
T x = feature[i];
T z = p1 * x + p3 * x * x * x;
T g = gradient[i];
T cz = 1. / cosh(z);
backprop[i] = static_cast<T>(
g * 0.5 * (1. + tanh(z) + x * (p1 + 3 * p3 * x * x) * cz * cz));
}

template <>
__global__ void GeluKernel<Eigen::half>(const Eigen::half* _in,
Eigen::half* _out, int32 count) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= count) return;
const half* in = reinterpret_cast<const half*>(_in);
half* out = reinterpret_cast<half*>(_out);
const float scale = 0.7978845608028654;
const float p1 = scale;
const float p3 = 0.044715 * 0.7978845608028654;
float x = in[i];
out[i] = 0.5 * x * (1 + tanh(p1 * x + p3 * x * x * x));
}

template <>
__global__ void GeluGradKernel<Eigen::half>(const Eigen::half* _gradient,
const Eigen::half* _feature,
Eigen::half* _backprop,
int32 count) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= count) return;
const float scale = 0.7978845608028654;
const float p1 = scale;
const float p3 = 0.044715 * 0.7978845608028654;
const half* gradient = reinterpret_cast<const half*>(_gradient);
const half* feature = reinterpret_cast<const half*>(_feature);
half* backprop = reinterpret_cast<half*>(_backprop);
float x = feature[i];
float z = p1 * x + p3 * x * x * x;
float g = gradient[i];
float cz = 1. / cosh(z);
backprop[i] = g * 0.5 * (1. + tanh(z) + x * (p1 + 3 * p3 * x * x) * cz * cz);
}
}
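For reference, both kernels compute the tanh approximation of GELU; the constants are sqrt(2/pi) ≈ 0.7978845608028654 (`scale`/`p1`) and 0.044715·sqrt(2/pi) (`p3`), and the new `if constexpr` branch evaluates the same expressions in float for Eigen::half and Eigen::bfloat16 before casting back to the narrow type. A sketch of the math (not part of the diff):

```latex
% Forward pass (GeluKernel):
z = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr), \qquad
\mathrm{gelu}(x) \approx \tfrac{1}{2}\,x\,\bigl(1 + \tanh z\bigr)

% Backward pass (GeluGradKernel); g is the incoming gradient and
% cz = 1/\cosh z = \operatorname{sech} z in the kernel:
\mathrm{backprop} \approx g \cdot \tfrac{1}{2}\Bigl(1 + \tanh z
    + x\,\sqrt{2/\pi}\,\bigl(1 + 3\cdot 0.044715\,x^{2}\bigr)\,\operatorname{sech}^{2} z\Bigr)
```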

template <typename T>
@@ -338,9 +335,7 @@ TF_CALL_half(DEFINE_GPU_NO_MLIR_KERNELS);
TF_CALL_float(DEFINE_GPU_NO_MLIR_KERNELS);
TF_CALL_double(DEFINE_GPU_NO_MLIR_KERNELS);
#endif
#if GOOGLE_CUDA
TF_CALL_bfloat16(DEFINE_GPU_NO_MLIR_KERNELS);
#endif
#undef DEFINE_GPU_NO_MLIR_KERNELS

// Definition of the GPU implementations declared in relu_op.cc.
@@ -356,9 +351,7 @@ TF_CALL_bfloat16(DEFINE_GPU_NO_MLIR_KERNELS);
template struct functor::GeluGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES_NO_BF16(DEFINE_GPU_KERNELS);
#if GOOGLE_CUDA
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
#endif
template struct functor::Relu<GPUDevice, qint8>;

} // end namespace tensorflow
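The functor bodies that launch these kernels are collapsed in this view. A hedged sketch of a typical launch for the bounds-checked kernels above, assuming the usual GpuLaunchKernel helper from tensorflow/core/util/gpu_kernel_helper.h is available (not text from this PR):

```cpp
// Hedged sketch, not from this diff: launch GeluKernel over `count` elements
// on the Eigen GPU device `d` (a CUDA stream under GOOGLE_CUDA, a HIP stream
// under TENSORFLOW_USE_ROCM).
template <typename T>
void LaunchGeluSketch(const Eigen::GpuDevice& d, const T* in, T* out,
                      int32 count) {
  if (count == 0) return;
  constexpr int kThreadsPerBlock = 256;
  const int blocks = (count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  TF_CHECK_OK(GpuLaunchKernel(GeluKernel<T>, blocks, kThreadsPerBlock,
                              /*shared_memory_size_bytes=*/0, d.stream(),
                              in, out, count));
}
```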
Original file line number Diff line number Diff line change
@@ -183,4 +183,4 @@ TF_CALL_double(REGISTER_GPU_SPEC);

} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/unique_op_gpu.cu.h
@@ -449,6 +449,6 @@ class UniqueOpGPU : public AsyncOpKernel {

} // end namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif // TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/unique_op_gpu_0.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/unique_op_gpu.cu.h"
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/unique_op_gpu_1.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/unique_op_gpu.cu.h"
@@ -39,4 +39,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_UNIQUE_GPU);

} // end namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
21 changes: 15 additions & 6 deletions tensorflow/python/kernel_tests/nn_ops/relu_op_test.py
@@ -70,9 +70,9 @@ def testNumbersGPU(self):
self.skipTest("No GPU available")
for t in [
np.float16,
dtypes.bfloat16.as_numpy_dtype,
np.float32,
np.float64,
dtypes.bfloat16.as_numpy_dtype,
]:
self._testRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t))
@@ -550,17 +550,25 @@ def testNumbersCPU(self):
def testNumbersGPU(self):
if not test.is_gpu_available():
self.skipTest("No GPU available")
for t in [np.float16, np.float32, np.float64]:
for t in [np.float16, dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]:
self._testGelu(np.array([[-9, 7, -5, 3, -1],
[1, -3, 5, -7, 9]]).astype(t))

def testGradients(self):
for t in [np.float16, np.float32, np.float64]:
for t in [np.float16, dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]:

is_f16 = t == np.float16
is_bf16 = t == dtypes.bfloat16.as_numpy_dtype
for gpu in [True, False]:
if gpu and not test.is_gpu_available():
continue
delta = 2e-2 if t == np.float16 else 1e-3
tol = 2e-2 if t == np.float16 else (1e-4 if t == np.float32 else 1e-6)
delta = 2e-2 if is_f16 or is_bf16 else 1e-3
tol = 3e-2 if is_bf16 else \
2e-2 if is_f16 else \
1e-4 if t == np.float32 else 1e-6
if is_bf16 and not gpu:
tol = 0.1 # really bad accuracy on CPU for bf16

def approx_gelu(x):
return nn_ops.gelu(x, approximate=True)
with self.session(use_gpu=gpu):
@@ -571,7 +579,8 @@ def approx_gelu(x):
err = gradient_checker_v2.max_error(
e1, e2)
print(e1, e2)
print("gelu", t, "GPU" if gpu else "CPU", "gradient err = ", err)
print("gelu", t, "GPU" if gpu else "CPU", \
"gradient err = ", err, " tol = ", tol)
self.assertLess(err, tol)
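The per-dtype delta and tol values in testGradients track the machine epsilon of each type, which is why bfloat16 gets the coarsest finite-difference step and the loosest bound; rough reference values (not from the PR):

```latex
\epsilon_{\mathrm{bf16}} = 2^{-8}  \approx 3.9\times10^{-3}, \quad
\epsilon_{\mathrm{fp16}} = 2^{-11} \approx 4.9\times10^{-4}, \quad
\epsilon_{\mathrm{fp32}} = 2^{-24} \approx 6.0\times10^{-8}, \quad
\epsilon_{\mathrm{fp64}} = 2^{-53} \approx 1.1\times10^{-16}
```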

class SeluTest(test.TestCase):