[Mlir-commits] [mlir] [mlir][spirv] Add atan and atan2 pattern to MathToSPIRV Conversion pass (PR #102633)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 9 08:56:53 PDT 2024
https://github.com/meehatpa created https://github.com/llvm/llvm-project/pull/102633
Add the missing conversion patterns that lower math.atan to spirv.CL.atan and math.atan2 to spirv.CL.atan2 in the MathToSPIRV conversion pass.
From bcbf003aed51727fa13635932a2cad65f460dc74 Mon Sep 17 00:00:00 2001
From: Gune S <gune30 at gmail.com>
Date: Fri, 9 Aug 2024 21:11:16 +0530
Subject: [PATCH] [mlir][spirv] Add atan and atan2 pattern to MathToSPIRV
Conversion pass
Add missing math.atan to spirv.CL.atan and math.atan2 to spirv.CL.atan2
in MathToSPIRV.
---
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 2 +
.../MathToSPIRV/math-to-opencl-spirv.mlir | 90 +++++++++++--------
2 files changed, 53 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 0b29c93e2d8909..1260b27a0a751f 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -431,6 +431,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
+ CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
+ CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 03a5c2f2bc9b15..e9ca838354c0de 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -4,83 +4,91 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kerne
// CHECK-LABEL: @float32_unary_scalar
func.func @float32_unary_scalar(%arg0: f32) {
+ // CHECK: spirv.CL.atan %{{.*}}: f32
+ %0 = math.atan %arg0 : f32
// CHECK: spirv.CL.cos %{{.*}}: f32
- %0 = math.cos %arg0 : f32
+ %1 = math.cos %arg0 : f32
// CHECK: spirv.CL.exp %{{.*}}: f32
- %1 = math.exp %arg0 : f32
+ %2 = math.exp %arg0 : f32
// CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
// CHECK: spirv.FSub %[[EXP]], %[[ONE]]
- %2 = math.expm1 %arg0 : f32
+ %3 = math.expm1 %arg0 : f32
// CHECK: spirv.CL.log %{{.*}}: f32
- %3 = math.log %arg0 : f32
+ %4 = math.log %arg0 : f32
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
- %4 = math.log1p %arg0 : f32
+ %5 = math.log1p %arg0 : f32
// CHECK: spirv.CL.rint %{{.*}}: f32
- %5 = math.roundeven %arg0 : f32
+ %6 = math.roundeven %arg0 : f32
// CHECK: spirv.CL.rsqrt %{{.*}}: f32
- %6 = math.rsqrt %arg0 : f32
+ %7 = math.rsqrt %arg0 : f32
// CHECK: spirv.CL.sqrt %{{.*}}: f32
- %7 = math.sqrt %arg0 : f32
+ %8 = math.sqrt %arg0 : f32
// CHECK: spirv.CL.tanh %{{.*}}: f32
- %8 = math.tanh %arg0 : f32
+ %9 = math.tanh %arg0 : f32
// CHECK: spirv.CL.sin %{{.*}}: f32
- %9 = math.sin %arg0 : f32
+ %10 = math.sin %arg0 : f32
// CHECK: spirv.CL.fabs %{{.*}}: f32
- %10 = math.absf %arg0 : f32
+ %11 = math.absf %arg0 : f32
// CHECK: spirv.CL.ceil %{{.*}}: f32
- %11 = math.ceil %arg0 : f32
+ %12 = math.ceil %arg0 : f32
// CHECK: spirv.CL.floor %{{.*}}: f32
- %12 = math.floor %arg0 : f32
+ %13 = math.floor %arg0 : f32
// CHECK: spirv.CL.erf %{{.*}}: f32
- %13 = math.erf %arg0 : f32
+ %14 = math.erf %arg0 : f32
// CHECK: spirv.CL.round %{{.*}}: f32
- %14 = math.round %arg0 : f32
+ %15 = math.round %arg0 : f32
return
}
// CHECK-LABEL: @float32_unary_vector
func.func @float32_unary_vector(%arg0: vector<3xf32>) {
+ // CHECK: spirv.CL.atan %{{.*}}: vector<3xf32>
+ %0 = math.atan %arg0 : vector<3xf32>
// CHECK: spirv.CL.cos %{{.*}}: vector<3xf32>
- %0 = math.cos %arg0 : vector<3xf32>
+ %1 = math.cos %arg0 : vector<3xf32>
// CHECK: spirv.CL.exp %{{.*}}: vector<3xf32>
- %1 = math.exp %arg0 : vector<3xf32>
+ %2 = math.exp %arg0 : vector<3xf32>
// CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
// CHECK: spirv.FSub %[[EXP]], %[[ONE]]
- %2 = math.expm1 %arg0 : vector<3xf32>
+ %3 = math.expm1 %arg0 : vector<3xf32>
// CHECK: spirv.CL.log %{{.*}}: vector<3xf32>
- %3 = math.log %arg0 : vector<3xf32>
+ %4 = math.log %arg0 : vector<3xf32>
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
- %4 = math.log1p %arg0 : vector<3xf32>
+ %5 = math.log1p %arg0 : vector<3xf32>
// CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
- %5 = math.roundeven %arg0 : vector<3xf32>
+ %6 = math.roundeven %arg0 : vector<3xf32>
// CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
- %6 = math.rsqrt %arg0 : vector<3xf32>
+ %7 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
- %7 = math.sqrt %arg0 : vector<3xf32>
+ %8 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
- %8 = math.tanh %arg0 : vector<3xf32>
+ %9 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
- %9 = math.sin %arg0 : vector<3xf32>
+ %10 = math.sin %arg0 : vector<3xf32>
return
}
// CHECK-LABEL: @float32_binary_scalar
func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
+ // CHECK: spirv.CL.atan2 %{{.*}}: f32
+ %0 = math.atan2 %lhs, %rhs : f32
// CHECK: spirv.CL.pow %{{.*}}: f32
- %0 = math.powf %lhs, %rhs : f32
+ %1 = math.powf %lhs, %rhs : f32
return
}
// CHECK-LABEL: @float32_binary_vector
func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
+ // CHECK: spirv.CL.atan2 %{{.*}}: vector<4xf32>
+ %0 = math.atan2 %lhs, %rhs : vector<4xf32>
// CHECK: spirv.CL.pow %{{.*}}: vector<4xf32>
- %0 = math.powf %lhs, %rhs : vector<4xf32>
+ %1 = math.powf %lhs, %rhs : vector<4xf32>
return
}
@@ -118,18 +126,20 @@ module attributes {
// CHECK-LABEL: @vector_2d
func.func @vector_2d(%arg0: vector<2x2xf32>) {
+ // CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
+ %0 = math.atan %arg0 : vector<2x2xf32>
// CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
- %0 = math.cos %arg0 : vector<2x2xf32>
+ %1 = math.cos %arg0 : vector<2x2xf32>
// CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
- %1 = math.exp %arg0 : vector<2x2xf32>
+ %2 = math.exp %arg0 : vector<2x2xf32>
// CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
- %2 = math.absf %arg0 : vector<2x2xf32>
+ %3 = math.absf %arg0 : vector<2x2xf32>
// CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
- %3 = math.ceil %arg0 : vector<2x2xf32>
+ %4 = math.ceil %arg0 : vector<2x2xf32>
// CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
- %4 = math.floor %arg0 : vector<2x2xf32>
+ %5 = math.floor %arg0 : vector<2x2xf32>
// CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
- %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+ %6 = math.powf %arg0, %arg0 : vector<2x2xf32>
// CHECK-NEXT: return
return
}
@@ -138,18 +148,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
// CHECK-LABEL: @tensor_1d
func.func @tensor_1d(%arg0: tensor<2xf32>) {
+ // CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
+ %0 = math.atan %arg0 : tensor<2xf32>
// CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
- %0 = math.cos %arg0 : tensor<2xf32>
+ %1 = math.cos %arg0 : tensor<2xf32>
// CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
- %1 = math.exp %arg0 : tensor<2xf32>
+ %2 = math.exp %arg0 : tensor<2xf32>
// CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
- %2 = math.absf %arg0 : tensor<2xf32>
+ %3 = math.absf %arg0 : tensor<2xf32>
// CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
- %3 = math.ceil %arg0 : tensor<2xf32>
+ %4 = math.ceil %arg0 : tensor<2xf32>
// CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
- %4 = math.floor %arg0 : tensor<2xf32>
+ %5 = math.floor %arg0 : tensor<2xf32>
// CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
- %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+ %6 = math.powf %arg0, %arg0 : tensor<2xf32>
// CHECK-NEXT: return
return
}
More information about the Mlir-commits
mailing list