[Mlir-commits] [mlir] 49777d7 - [mlir][spirv] Add atan and atan2 pattern to MathToSPIRV Conversion pass (#102633)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 12 07:58:52 PDT 2024


Author: meehatpa
Date: 2024-08-12T10:58:48-04:00
New Revision: 49777d7ffe82f1dcace318e51c9d785994f8c32a

URL: https://github.com/llvm/llvm-project/commit/49777d7ffe82f1dcace318e51c9d785994f8c32a
DIFF: https://github.com/llvm/llvm-project/commit/49777d7ffe82f1dcace318e51c9d785994f8c32a.diff

LOG: [mlir][spirv] Add atan and atan2 pattern to MathToSPIRV Conversion pass (#102633)

Add the missing MathToSPIRV lowerings for the OpenCL path:
math.atan -> spirv.CL.atan and math.atan2 -> spirv.CL.atan2.
Also add math.atan -> spirv.GL.Atan for the GL path (this change adds no
GL lowering for math.atan2).

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 0b29c93e2d8909..5b3c2fb15e7026 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -414,6 +414,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
            CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
            CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
+           CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
            CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
            CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
            CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
@@ -431,6 +432,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
                CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
                CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
+               CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
+               CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
                CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
                CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
                CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 4d0ef06d7e92f9..a9397667393429 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -6,65 +6,69 @@ module attributes {
 
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
+  // CHECK: spirv.GL.Atan %{{.*}}: f32
+  %0 = math.atan %arg0 : f32
   // CHECK: spirv.GL.Cos %{{.*}}: f32
-  %0 = math.cos %arg0 : f32
+  %1 = math.cos %arg0 : f32
   // CHECK: spirv.GL.Exp %{{.*}}: f32
-  %1 = math.exp %arg0 : f32
+  %2 = math.exp %arg0 : f32
   // CHECK: %[[EXP:.+]] = spirv.GL.Exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : f32
+  %3 = math.expm1 %arg0 : f32
   // CHECK: spirv.GL.Log %{{.*}}: f32
-  %3 = math.log %arg0 : f32
+  %4 = math.log %arg0 : f32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
-  %4 = math.log1p %arg0 : f32
+  %5 = math.log1p %arg0 : f32
   // CHECK: spirv.GL.RoundEven %{{.*}}: f32
-  %5 = math.roundeven %arg0 : f32
+  %6 = math.roundeven %arg0 : f32
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
-  %6 = math.rsqrt %arg0 : f32
+  %7 = math.rsqrt %arg0 : f32
   // CHECK: spirv.GL.Sqrt %{{.*}}: f32
-  %7 = math.sqrt %arg0 : f32
+  %8 = math.sqrt %arg0 : f32
   // CHECK: spirv.GL.Tanh %{{.*}}: f32
-  %8 = math.tanh %arg0 : f32
+  %9 = math.tanh %arg0 : f32
   // CHECK: spirv.GL.Sin %{{.*}}: f32
-  %9 = math.sin %arg0 : f32
+  %10 = math.sin %arg0 : f32
   // CHECK: spirv.GL.FAbs %{{.*}}: f32
-  %10 = math.absf %arg0 : f32
+  %11 = math.absf %arg0 : f32
   // CHECK: spirv.GL.Ceil %{{.*}}: f32
-  %11 = math.ceil %arg0 : f32
+  %12 = math.ceil %arg0 : f32
   // CHECK: spirv.GL.Floor %{{.*}}: f32
-  %12 = math.floor %arg0 : f32
+  %13 = math.floor %arg0 : f32
   return
 }
 
 // CHECK-LABEL: @float32_unary_vector
 func.func @float32_unary_vector(%arg0: vector<3xf32>) {
+  // CHECK: spirv.GL.Atan %{{.*}}: vector<3xf32>
+  %0 = math.atan %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Cos %{{.*}}: vector<3xf32>
-  %0 = math.cos %arg0 : vector<3xf32>
+  %1 = math.cos %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Exp %{{.*}}: vector<3xf32>
-  %1 = math.exp %arg0 : vector<3xf32>
+  %2 = math.exp %arg0 : vector<3xf32>
   // CHECK: %[[EXP:.+]] = spirv.GL.Exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : vector<3xf32>
+  %3 = math.expm1 %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Log %{{.*}}: vector<3xf32>
-  %3 = math.log %arg0 : vector<3xf32>
+  %4 = math.log %arg0 : vector<3xf32>
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
-  %4 = math.log1p %arg0 : vector<3xf32>
+  %5 = math.log1p %arg0 : vector<3xf32>
   // CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
-  %5 = math.roundeven %arg0 : vector<3xf32>
+  %6 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
-  %6 = math.rsqrt %arg0 : vector<3xf32>
+  %7 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
-  %7 = math.sqrt %arg0 : vector<3xf32>
+  %8 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
-  %8 = math.tanh %arg0 : vector<3xf32>
+  %9 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
-  %9 = math.sin %arg0 : vector<3xf32>
+  %10 = math.sin %arg0 : vector<3xf32>
   return
 }
 
@@ -229,18 +233,20 @@ module attributes {
 
 // CHECK-LABEL: @vector_2d
 func.func @vector_2d(%arg0: vector<2x2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
+  %0 = math.atan %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
-  %0 = math.cos %arg0 : vector<2x2xf32>
+  %1 = math.cos %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
-  %1 = math.exp %arg0 : vector<2x2xf32>
+  %2 = math.exp %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
-  %2 = math.absf %arg0 : vector<2x2xf32>
+  %3 = math.absf %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
-  %3 = math.ceil %arg0 : vector<2x2xf32>
+  %4 = math.ceil %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
-  %4 = math.floor %arg0 : vector<2x2xf32>
+  %5 = math.floor %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
-  %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+  %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
   // CHECK-NEXT: return
   return
 }
@@ -249,18 +255,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
 
 // CHECK-LABEL: @tensor_1d
 func.func @tensor_1d(%arg0: tensor<2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
+  %0 = math.atan %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
-  %0 = math.cos %arg0 : tensor<2xf32>
+  %1 = math.cos %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
-  %1 = math.exp %arg0 : tensor<2xf32>
+  %2 = math.exp %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
-  %2 = math.absf %arg0 : tensor<2xf32>
+  %3 = math.absf %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
-  %3 = math.ceil %arg0 : tensor<2xf32>
+  %4 = math.ceil %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
-  %4 = math.floor %arg0 : tensor<2xf32>
+  %5 = math.floor %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
-  %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+  %6 = math.powf %arg0, %arg0 : tensor<2xf32>
   // CHECK-NEXT: return
   return
 }

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 03a5c2f2bc9b15..e9ca838354c0de 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -4,83 +4,91 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kerne
 
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
+  // CHECK: spirv.CL.atan %{{.*}}: f32
+  %0 = math.atan %arg0 : f32
   // CHECK: spirv.CL.cos %{{.*}}: f32
-  %0 = math.cos %arg0 : f32
+  %1 = math.cos %arg0 : f32
   // CHECK: spirv.CL.exp %{{.*}}: f32
-  %1 = math.exp %arg0 : f32
+  %2 = math.exp %arg0 : f32
   // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : f32
+  %3 = math.expm1 %arg0 : f32
   // CHECK: spirv.CL.log %{{.*}}: f32
-  %3 = math.log %arg0 : f32
+  %4 = math.log %arg0 : f32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
-  %4 = math.log1p %arg0 : f32
+  %5 = math.log1p %arg0 : f32
   // CHECK: spirv.CL.rint %{{.*}}: f32
-  %5 = math.roundeven %arg0 : f32
+  %6 = math.roundeven %arg0 : f32
   // CHECK: spirv.CL.rsqrt %{{.*}}: f32
-  %6 = math.rsqrt %arg0 : f32
+  %7 = math.rsqrt %arg0 : f32
   // CHECK: spirv.CL.sqrt %{{.*}}: f32
-  %7 = math.sqrt %arg0 : f32
+  %8 = math.sqrt %arg0 : f32
   // CHECK: spirv.CL.tanh %{{.*}}: f32
-  %8 = math.tanh %arg0 : f32
+  %9 = math.tanh %arg0 : f32
   // CHECK: spirv.CL.sin %{{.*}}: f32
-  %9 = math.sin %arg0 : f32
+  %10 = math.sin %arg0 : f32
   // CHECK: spirv.CL.fabs %{{.*}}: f32
-  %10 = math.absf %arg0 : f32
+  %11 = math.absf %arg0 : f32
   // CHECK: spirv.CL.ceil %{{.*}}: f32
-  %11 = math.ceil %arg0 : f32
+  %12 = math.ceil %arg0 : f32
   // CHECK: spirv.CL.floor %{{.*}}: f32
-  %12 = math.floor %arg0 : f32
+  %13 = math.floor %arg0 : f32
   // CHECK: spirv.CL.erf %{{.*}}: f32
-  %13 = math.erf %arg0 : f32
+  %14 = math.erf %arg0 : f32
   // CHECK: spirv.CL.round %{{.*}}: f32
-  %14 = math.round %arg0 : f32
+  %15 = math.round %arg0 : f32
   return
 }
 
 // CHECK-LABEL: @float32_unary_vector
 func.func @float32_unary_vector(%arg0: vector<3xf32>) {
+  // CHECK: spirv.CL.atan %{{.*}}: vector<3xf32>
+  %0 = math.atan %arg0 : vector<3xf32>
   // CHECK: spirv.CL.cos %{{.*}}: vector<3xf32>
-  %0 = math.cos %arg0 : vector<3xf32>
+  %1 = math.cos %arg0 : vector<3xf32>
   // CHECK: spirv.CL.exp %{{.*}}: vector<3xf32>
-  %1 = math.exp %arg0 : vector<3xf32>
+  %2 = math.exp %arg0 : vector<3xf32>
   // CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: spirv.FSub %[[EXP]], %[[ONE]]
-  %2 = math.expm1 %arg0 : vector<3xf32>
+  %3 = math.expm1 %arg0 : vector<3xf32>
   // CHECK: spirv.CL.log %{{.*}}: vector<3xf32>
-  %3 = math.log %arg0 : vector<3xf32>
+  %4 = math.log %arg0 : vector<3xf32>
   // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
-  %4 = math.log1p %arg0 : vector<3xf32>
+  %5 = math.log1p %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
-  %5 = math.roundeven %arg0 : vector<3xf32>
+  %6 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
-  %6 = math.rsqrt %arg0 : vector<3xf32>
+  %7 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
-  %7 = math.sqrt %arg0 : vector<3xf32>
+  %8 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
-  %8 = math.tanh %arg0 : vector<3xf32>
+  %9 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
-  %9 = math.sin %arg0 : vector<3xf32>
+  %10 = math.sin %arg0 : vector<3xf32>
   return
 }
 
 // CHECK-LABEL: @float32_binary_scalar
 func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
+  // CHECK: spirv.CL.atan2 %{{.*}}: f32
+  %0 = math.atan2 %lhs, %rhs : f32
   // CHECK: spirv.CL.pow %{{.*}}: f32
-  %0 = math.powf %lhs, %rhs : f32
+  %1 = math.powf %lhs, %rhs : f32
   return
 }
 
 // CHECK-LABEL: @float32_binary_vector
 func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
+  // CHECK: spirv.CL.atan2 %{{.*}}: vector<4xf32>
+  %0 = math.atan2 %lhs, %rhs : vector<4xf32>
   // CHECK: spirv.CL.pow %{{.*}}: vector<4xf32>
-  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  %1 = math.powf %lhs, %rhs : vector<4xf32>
   return
 }
 
@@ -118,18 +126,20 @@ module attributes {
 
 // CHECK-LABEL: @vector_2d
 func.func @vector_2d(%arg0: vector<2x2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
+  %0 = math.atan %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
-  %0 = math.cos %arg0 : vector<2x2xf32>
+  %1 = math.cos %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
-  %1 = math.exp %arg0 : vector<2x2xf32>
+  %2 = math.exp %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
-  %2 = math.absf %arg0 : vector<2x2xf32>
+  %3 = math.absf %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
-  %3 = math.ceil %arg0 : vector<2x2xf32>
+  %4 = math.ceil %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
-  %4 = math.floor %arg0 : vector<2x2xf32>
+  %5 = math.floor %arg0 : vector<2x2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
-  %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+  %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
   // CHECK-NEXT: return
   return
 }
@@ -138,18 +148,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
 
 // CHECK-LABEL: @tensor_1d
 func.func @tensor_1d(%arg0: tensor<2xf32>) {
+  // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
+  %0 = math.atan %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
-  %0 = math.cos %arg0 : tensor<2xf32>
+  %1 = math.cos %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
-  %1 = math.exp %arg0 : tensor<2xf32>
+  %2 = math.exp %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
-  %2 = math.absf %arg0 : tensor<2xf32>
+  %3 = math.absf %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
-  %3 = math.ceil %arg0 : tensor<2xf32>
+  %4 = math.ceil %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
-  %4 = math.floor %arg0 : tensor<2xf32>
+  %5 = math.floor %arg0 : tensor<2xf32>
   // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
-  %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+  %6 = math.powf %arg0, %arg0 : tensor<2xf32>
   // CHECK-NEXT: return
   return
 }


        


More information about the Mlir-commits mailing list