[Mlir-commits] [mlir] [MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend (PR #102971)

Jan Leyonberg llvmlistbot at llvm.org
Mon Aug 12 13:12:41 PDT 2024


https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/102971

This patch removes the patterns for a few operations, which allows the mathToLLVM conversion to lower those operations to LLVM intrinsics instead, since they are supported directly by the AMDGPU backend.

>From 8fe4b0e0ceb2003935bdf3f9e8197ef6ff90c279 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Mon, 12 Aug 2024 16:04:39 -0400
Subject: [PATCH] [MLIR][ROCDL] Remove patterns for ops supported as intrinsics
 in the AMDGPU backend

This patch removes patterns for a few operations which allows mathToLLVM
conversion to convert the operations into LLVM intrinsics instead since they
are supported directly by the AMDGPU backend.
---
 .../Conversion/MathToROCDL/MathToROCDL.cpp    |  12 +--
 .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir   | 102 +++---------------
 .../Conversion/MathToROCDL/math-to-rocdl.mlir |  60 -----------
 3 files changed, 21 insertions(+), 153 deletions(-)

diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 7de6971ba2ee72..fd4eab0e10d67e 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   // Handled by mathToLLVM: math::AbsIOp
+  // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
   // Handled by mathToLLVM: math::CountLeadingZerosOp
   // Handled by mathToLLVM: math::CountTrailingZerosOp
   // Handled by mathToLLVM: math::CgPopOp
+  // Handled by mathToLLVM: math::ExpOp
   // Handled by mathToLLVM: math::FmaOp
+  // Handled by mathToLLVM: math::LogOp
   // FIXME: math::IPowIOp
   // FIXME: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
+  // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
-                                   "__ocml_fabs_f64");
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
                                    "__ocml_acos_f64");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                    "__ocml_cosh_f64");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                    "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
-                                  "__ocml_exp_f64");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                    "__ocml_exp2_f64");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                     "__ocml_expm1_f64");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                     "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
-                                  "__ocml_log_f64");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
                                     "__ocml_log10_f64");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +104,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                     "__ocml_rsqrt_f64");
   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                   "__ocml_sin_f64");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
-                                   "__ocml_sqrt_f64");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
                                    "__ocml_tanh_f64");
   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index bf49a42a115775..4f1f26e8794d9e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -131,21 +131,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_fabs
-  func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.absf %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.absf %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -206,23 +191,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_exp
-  func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %exp_f32 = math.exp %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result32 = math.exp %exp_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -239,21 +207,20 @@ gpu.module @test_module {
 }
 
 // -----
-
 // Test that we handled properly operation with SymbolTable other than module op
 gpu.module @test_module {
   "test.symbol_scope"() ({
     // CHECK: test.symbol_scope
-    // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-    // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-    // CHECK-LABEL: func @gpu_exp
-    func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-      %exp_f32 = math.exp %arg_f32 : f32
-      // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-      %result32 = math.exp %exp_f32 : f32
-      // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-      %result64 = math.exp %arg_f64 : f64
-      // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+    // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+    // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+    // CHECK-LABEL: func @gpu_sin
+    func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+      %sin_f32 = math.sin %arg_f32 : f32
+      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+      %result32 = math.sin %sin_f32 : f32
+      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+      %result64 = math.sin %arg_f64 : f64
+      // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
       func.return %result32, %result64 : f32, f64
     }
     "test.finish" () : () -> ()
@@ -279,21 +246,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_log
-  func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.log %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.log %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
@@ -359,26 +311,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_sqrt
-  func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
-      -> (f16, f32, f64) {
-    %result16 = math.sqrt %arg_f16 : f16
-    // CHECK: llvm.fpext %{{.*}} : f16 to f32
-    // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
-    %result32 = math.sqrt %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.sqrt %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result16, %result32, %result64 : f16, f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +404,15 @@ gpu.module @test_module {
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_unroll
   func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
-    %result = math.exp %arg0 : vector<4xf32>
+    %result = math.sin %arg0 : vector<4xf32>
     // CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
     // CHECK: return %[[V4]]
     func.return %result : vector<4xf32>
@@ -526,9 +458,9 @@ gpu.module @test_module {
 
 gpu.module @module {
 // CHECK-LABEL: @spirv_exp
-// CHECK: llvm.call @__ocml_exp_f32
+// CHECK: llvm.call @__ocml_sin_f32
   spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
-    %0 = math.exp %arg0 : vector<4xf32>
+    %0 = math.sin %arg0 : vector<4xf32>
     spirv.ReturnValue %0 : vector<4xf32>
   }
 }
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index a406ec45a7f109..9a05a94f9f1ac7 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -15,21 +15,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
-  // CHECK-LABEL: func @math_absf
-  func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.absf %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.absf %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -210,21 +195,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @math_exp
-  func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.exp %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -270,21 +240,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
-  // CHECK-LABEL: func @math_log
-  func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.log %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.log %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
@@ -360,21 +315,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
-  // CHECK-LABEL: func @math_sqrt
-  func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.sqrt %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.sqrt %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64



More information about the Mlir-commits mailing list