[Mlir-commits] [mlir] [mlir][spirv] Add atan and atan2 pattern to MathToSPIRV Conversion pass (PR #102633)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 12 06:30:42 PDT 2024


https://github.com/meehatpa updated https://github.com/llvm/llvm-project/pull/102633

>From 9f97687fa9fafd0c0170c7a8930bdd516e5aab9b Mon Sep 17 00:00:00 2001
From: Gune S <gune30 at gmail.com>
Date: Fri, 9 Aug 2024 21:11:16 +0530
Subject: [PATCH] [mlir][spirv] Add atan and atan2 pattern to MathToSPIRV
 Conversion pass

Add the missing lowerings of math.atan to spirv.CL.atan and spirv.GL.Atan,
and of math.atan2 to spirv.CL.atan2, in MathToSPIRV.
---
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    |  3 +
 .../MathToSPIRV/math-to-gl-spirv.mlir         | 78 ++++++++--------
 .../MathToSPIRV/math-to-opencl-spirv.mlir     | 90 +++++++++++--------
 3 files changed, 97 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 0b29c93e2d8909..5b3c2fb15e7026 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -414,6 +414,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
            CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
            CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
+           CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
            CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
            CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
            CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
@@ -431,6 +432,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
                CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
                CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
+               CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
+               CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
                CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
                CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
                CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 4d0ef06d7e92f9..a9397667393429 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -6,65 +6,69 @@ module attributes {
 
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
+  // CHECK: spirv.GL.Atan %{{.*}}: f32
+  %0 = math.atan %arg0 : f32
   // CHECK: spirv.GL.Cos %{{.*}}: f32
-  %0 = math.cos %arg0 : f32
+  %1 = math.cos %arg0 : f32
   // CHECK: spirv.GL.Exp %{{.*}}: f32
-  %1 = math.exp %arg0 : f32
+  %2 = math.exp %arg0 : f32
   // CHECK: %[[EXP:.+]] = spirv.GL.Exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : f32
+  %3 = math.expm1 %arg0 : f32
   // CHECK: spirv.GL.Log %{{.*}}: f32
-  %3 = math.log %arg0 : f32
+  %4 = math.log %arg0 : f32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
-  %4 = math.log1p %arg0 : f32
+  %5 = math.log1p %arg0 : f32
   // CHECK: spirv.GL.RoundEven %{{.*}}: f32
-  %5 = math.roundeven %arg0 : f32
+  %6 = math.roundeven %arg0 : f32
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
-  %6 = math.rsqrt %arg0 : f32
+  %7 = math.rsqrt %arg0 : f32
   // CHECK: spirv.GL.Sqrt %{{.*}}: f32
-  %7 = math.sqrt %arg0 : f32
+  %8 = math.sqrt %arg0 : f32
   // CHECK: spirv.GL.Tanh %{{.*}}: f32
-  %8 = math.tanh %arg0 : f32
+  %9 = math.tanh %arg0 : f32
   // CHECK: spirv.GL.Sin %{{.*}}: f32
-  %9 = math.sin %arg0 : f32
+  %10 = math.sin %arg0 : f32
   // CHECK: spirv.GL.FAbs %{{.*}}: f32
-  %10 = math.absf %arg0 : f32
+  %11 = math.absf %arg0 : f32
   // CHECK: spirv.GL.Ceil %{{.*}}: f32
-  %11 = math.ceil %arg0 : f32
+  %12 = math.ceil %arg0 : f32
   // CHECK: spirv.GL.Floor %{{.*}}: f32
-  %12 = math.floor %arg0 : f32
+  %13 = math.floor %arg0 : f32
   return
 }
 
 // CHECK-LABEL: @float32_unary_vector
 func.func @float32_unary_vector(%arg0: vector<3xf32>) {
+  // CHECK: spirv.GL.Atan %{{.*}}: vector<3xf32>
+  %0 = math.atan %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Cos %{{.*}}: vector<3xf32>
-  %0 = math.cos %arg0 : vector<3xf32>
+  %1 = math.cos %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Exp %{{.*}}: vector<3xf32>
-  %1 = math.exp %arg0 : vector<3xf32>
+  %2 = math.exp %arg0 : vector<3xf32>
   // CHECK: %[[EXP:.+]] = spirv.GL.Exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : vector<3xf32>
+  %3 = math.expm1 %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Log %{{.*}}: vector<3xf32>
-  %3 = math.log %arg0 : vector<3xf32>
+  %4 = math.log %arg0 : vector<3xf32>
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
-  %4 = math.log1p %arg0 : vector<3xf32>
+  %5 = math.log1p %arg0 : vector<3xf32>
   // CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
-  %5 = math.roundeven %arg0 : vector<3xf32>
+  %6 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
-  %6 = math.rsqrt %arg0 : vector<3xf32>
+  %7 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
-  %7 = math.sqrt %arg0 : vector<3xf32>
+  %8 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
-  %8 = math.tanh %arg0 : vector<3xf32>
+  %9 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
-  %9 = math.sin %arg0 : vector<3xf32>
+  %10 = math.sin %arg0 : vector<3xf32>
   return
 }
 
@@ -229,18 +233,20 @@ module attributes {
 
 // CHECK-LABEL: @vector_2d
 func.func @vector_2d(%arg0: vector<2x2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
+  %0 = math.atan %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
-  %0 = math.cos %arg0 : vector<2x2xf32>
+  %1 = math.cos %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
-  %1 = math.exp %arg0 : vector<2x2xf32>
+  %2 = math.exp %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
-  %2 = math.absf %arg0 : vector<2x2xf32>
+  %3 = math.absf %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
-  %3 = math.ceil %arg0 : vector<2x2xf32>
+  %4 = math.ceil %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
-  %4 = math.floor %arg0 : vector<2x2xf32>
+  %5 = math.floor %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
-  %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+  %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
   // CHECK-NEXT: return
   return
 }
@@ -249,18 +255,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
 
 // CHECK-LABEL: @tensor_1d
 func.func @tensor_1d(%arg0: tensor<2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
+  %0 = math.atan %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
-  %0 = math.cos %arg0 : tensor<2xf32>
+  %1 = math.cos %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
-  %1 = math.exp %arg0 : tensor<2xf32>
+  %2 = math.exp %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
-  %2 = math.absf %arg0 : tensor<2xf32>
+  %3 = math.absf %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
-  %3 = math.ceil %arg0 : tensor<2xf32>
+  %4 = math.ceil %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
-  %4 = math.floor %arg0 : tensor<2xf32>
+  %5 = math.floor %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
-  %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+  %6 = math.powf %arg0, %arg0 : tensor<2xf32>
   // CHECK-NEXT: return
   return
 }
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 03a5c2f2bc9b15..e9ca838354c0de 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -4,83 +4,91 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kerne
 
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
+  // CHECK: spirv.CL.atan %{{.*}}: f32
+  %0 = math.atan %arg0 : f32
   // CHECK: spirv.CL.cos %{{.*}}: f32
-  %0 = math.cos %arg0 : f32
+  %1 = math.cos %arg0 : f32
   // CHECK: spirv.CL.exp %{{.*}}: f32
-  %1 = math.exp %arg0 : f32
+  %2 = math.exp %arg0 : f32
   // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : f32
+  %3 = math.expm1 %arg0 : f32
   // CHECK: spirv.CL.log %{{.*}}: f32
-  %3 = math.log %arg0 : f32
+  %4 = math.log %arg0 : f32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
-  %4 = math.log1p %arg0 : f32
+  %5 = math.log1p %arg0 : f32
   // CHECK: spirv.CL.rint %{{.*}}: f32
-  %5 = math.roundeven %arg0 : f32
+  %6 = math.roundeven %arg0 : f32
   // CHECK: spirv.CL.rsqrt %{{.*}}: f32
-  %6 = math.rsqrt %arg0 : f32
+  %7 = math.rsqrt %arg0 : f32
   // CHECK: spirv.CL.sqrt %{{.*}}: f32
-  %7 = math.sqrt %arg0 : f32
+  %8 = math.sqrt %arg0 : f32
   // CHECK: spirv.CL.tanh %{{.*}}: f32
-  %8 = math.tanh %arg0 : f32
+  %9 = math.tanh %arg0 : f32
   // CHECK: spirv.CL.sin %{{.*}}: f32
-  %9 = math.sin %arg0 : f32
+  %10 = math.sin %arg0 : f32
   // CHECK: spirv.CL.fabs %{{.*}}: f32
-  %10 = math.absf %arg0 : f32
+  %11 = math.absf %arg0 : f32
   // CHECK: spirv.CL.ceil %{{.*}}: f32
-  %11 = math.ceil %arg0 : f32
+  %12 = math.ceil %arg0 : f32
   // CHECK: spirv.CL.floor %{{.*}}: f32
-  %12 = math.floor %arg0 : f32
+  %13 = math.floor %arg0 : f32
   // CHECK: spirv.CL.erf %{{.*}}: f32
-  %13 = math.erf %arg0 : f32
+  %14 = math.erf %arg0 : f32
   // CHECK: spirv.CL.round %{{.*}}: f32
-  %14 = math.round %arg0 : f32
+  %15 = math.round %arg0 : f32
   return
 }
 
 // CHECK-LABEL: @float32_unary_vector
 func.func @float32_unary_vector(%arg0: vector<3xf32>) {
+  // CHECK: spirv.CL.atan %{{.*}}: vector<3xf32>
+  %0 = math.atan %arg0 : vector<3xf32>
   // CHECK: spirv.CL.cos %{{.*}}: vector<3xf32>
-  %0 = math.cos %arg0 : vector<3xf32>
+  %1 = math.cos %arg0 : vector<3xf32>
   // CHECK: spirv.CL.exp %{{.*}}: vector<3xf32>
-  %1 = math.exp %arg0 : vector<3xf32>
+  %2 = math.exp %arg0 : vector<3xf32>
   // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : vector<3xf32>
+  %3 = math.expm1 %arg0 : vector<3xf32>
   // CHECK: spirv.CL.log %{{.*}}: vector<3xf32>
-  %3 = math.log %arg0 : vector<3xf32>
+  %4 = math.log %arg0 : vector<3xf32>
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
-  %4 = math.log1p %arg0 : vector<3xf32>
+  %5 = math.log1p %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
-  %5 = math.roundeven %arg0 : vector<3xf32>
+  %6 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
-  %6 = math.rsqrt %arg0 : vector<3xf32>
+  %7 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
-  %7 = math.sqrt %arg0 : vector<3xf32>
+  %8 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
-  %8 = math.tanh %arg0 : vector<3xf32>
+  %9 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
-  %9 = math.sin %arg0 : vector<3xf32>
+  %10 = math.sin %arg0 : vector<3xf32>
   return
 }
 
 // CHECK-LABEL: @float32_binary_scalar
 func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
+  // CHECK: spirv.CL.atan2 %{{.*}}: f32
+  %0 = math.atan2 %lhs, %rhs : f32
   // CHECK: spirv.CL.pow %{{.*}}: f32
-  %0 = math.powf %lhs, %rhs : f32
+  %1 = math.powf %lhs, %rhs : f32
   return
 }
 
 // CHECK-LABEL: @float32_binary_vector
 func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
+  // CHECK: spirv.CL.atan2 %{{.*}}: vector<4xf32>
+  %0 = math.atan2 %lhs, %rhs : vector<4xf32>
   // CHECK: spirv.CL.pow %{{.*}}: vector<4xf32>
-  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  %1 = math.powf %lhs, %rhs : vector<4xf32>
   return
 }
 
@@ -118,18 +126,20 @@ module attributes {
 
 // CHECK-LABEL: @vector_2d
 func.func @vector_2d(%arg0: vector<2x2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
+  %0 = math.atan %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
-  %0 = math.cos %arg0 : vector<2x2xf32>
+  %1 = math.cos %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
-  %1 = math.exp %arg0 : vector<2x2xf32>
+  %2 = math.exp %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
-  %2 = math.absf %arg0 : vector<2x2xf32>
+  %3 = math.absf %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
-  %3 = math.ceil %arg0 : vector<2x2xf32>
+  %4 = math.ceil %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
-  %4 = math.floor %arg0 : vector<2x2xf32>
+  %5 = math.floor %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
-  %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+  %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
   // CHECK-NEXT: return
   return
 }
@@ -138,18 +148,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
 
 // CHECK-LABEL: @tensor_1d
 func.func @tensor_1d(%arg0: tensor<2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
+  %0 = math.atan %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
-  %0 = math.cos %arg0 : tensor<2xf32>
+  %1 = math.cos %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
-  %1 = math.exp %arg0 : tensor<2xf32>
+  %2 = math.exp %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
-  %2 = math.absf %arg0 : tensor<2xf32>
+  %3 = math.absf %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
-  %3 = math.ceil %arg0 : tensor<2xf32>
+  %4 = math.ceil %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
-  %4 = math.floor %arg0 : tensor<2xf32>
+  %5 = math.floor %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
-  %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+  %6 = math.powf %arg0, %arg0 : tensor<2xf32>
   // CHECK-NEXT: return
   return
 }



More information about the Mlir-commits mailing list