[Mlir-commits] [mlir] 4defac9 - [mlir][GPUToNVVM] Add `benefit` to `populate` functions (#128484)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 24 08:27:58 PST 2025


Author: Matthias Springer
Date: 2025-02-24T17:27:55+01:00
New Revision: 4defac91dbdf4d54aa40a47851c48e9c587fb7e9

URL: https://github.com/llvm/llvm-project/commit/4defac91dbdf4d54aa40a47851c48e9c587fb7e9
DIFF: https://github.com/llvm/llvm-project/commit/4defac91dbdf4d54aa40a47851c48e9c587fb7e9.diff

LOG: [mlir][GPUToNVVM] Add `benefit` to `populate` functions (#128484)

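Certain GPU->NVVM patterns (the ones that lower arith and math ops to
libdevice calls) compete with Arith->LLVM patterns. Add an optional
`benefit` parameter to all `populate` functions so that users can give
preference to the GPU->NVVM patterns.

The libdevice patterns are now also exposed separately through the new
`populateLibDeviceConversionPatterns` function. The
`transform.apply_conversion_patterns.gpu.gpu_to_nvvm` op gains a matching
`benefit` attribute, which defaults to 1.

Note: the printf and assert lowerings previously hard-coded a benefit of 10.
Callers that relied on this must now pass a higher benefit explicitly. The
GPU-to-NVVM pass passes 10.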

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index fc7c967f1b62c..4c8abea680b66 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include <memory>
 
 namespace mlir {
@@ -35,18 +36,27 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 /// GPU dialect to NVVM.
 void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
 
+/// Populate patterns that lower certain arith and math dialect ops to
+/// libdevice calls.
+void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to convert from the GPU dialect to NVVM.
 void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
-                                         RewritePatternSet &patterns);
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
 
 /// Populate GpuSubgroupReduce pattern to NVVM. It generates a specific nvvm
 /// op that is not available on every GPU.
 void populateGpuSubgroupReduceOpLoweringPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
 void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter,
-                                             RewritePatternSet &patterns);
+                                             RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 61d4ccec5f0bd..15d6e0a069e3e 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -26,6 +26,7 @@ def ApplyGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,
     Collects patterns that convert GPU dialect ops to NVVM dialect ops. These
     patterns require an "LLVMTypeConverter".
   }];
+  let arguments = (ins DefaultValuedAttr<I16Attr, "1">:$benefit);
   let assemblyFormat = "attr-dict";
 }
 

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd2fd020f684b..e17b06379988c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -43,8 +43,9 @@ struct GPUDynamicSharedMemoryOpLowering
   using ConvertOpToLLVMPattern<
       gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
   GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
-                                   unsigned alignmentBit = 0)
-      : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+                                   unsigned alignmentBit = 0,
+                                   PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter, benefit),
         alignmentBit(alignmentBit) {}
 
   LogicalResult
@@ -81,8 +82,9 @@ struct GPUFuncOpLoweringOptions {
 
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
   GPUFuncOpLowering(const LLVMTypeConverter &converter,
-                    const GPUFuncOpLoweringOptions &options)
-      : ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
+                    const GPUFuncOpLoweringOptions &options,
+                    PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter, benefit),
         allocaAddrSpace(options.allocaAddrSpace),
         workgroupAddrSpace(options.workgroupAddrSpace),
         kernelAttributeName(options.kernelAttributeName),

diff  --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index c9ddc942bd682..1f158b271e5c6 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -36,14 +36,16 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
   IntrType intrType;
 
 public:
-  explicit OpLowering(const LLVMTypeConverter &typeConverter)
-      : ConvertOpToLLVMPattern<Op>(typeConverter),
+  explicit OpLowering(const LLVMTypeConverter &typeConverter,
+                      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
         indexBitwidth(typeConverter.getIndexTypeBitwidth()),
         indexKind(IndexKind::Other), intrType(IntrType::None) {}
 
   explicit OpLowering(const LLVMTypeConverter &typeConverter,
-                      IndexKind indexKind, IntrType intrType)
-      : ConvertOpToLLVMPattern<Op>(typeConverter),
+                      IndexKind indexKind, IntrType intrType,
+                      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
         indexBitwidth(typeConverter.getIndexTypeBitwidth()),
         indexKind(indexKind), intrType(intrType) {}
 

diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 0bc2f697a7662..34150c4d13085 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -57,8 +57,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
                                 StringRef f32Func, StringRef f64Func,
                                 StringRef f32ApproxFunc, StringRef f16Func,
-                                StringRef i32Func = "")
-      : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
+                                StringRef i32Func = "",
+                                PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
         f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
         i32Func(i32Func) {}
 

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9290279112715..61b73f546b5da 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -378,7 +378,9 @@ struct LowerGpuOpsToNVVMOpsPass final
     RewritePatternSet llvmPatterns(m.getContext());
     LLVMConversionTarget target(getContext());
 
-    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
+    // Set higher benefit, so patterns will run before generic LLVM lowering.
+    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
+                                        /*benefit=*/10);
 
     llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                       allowedDialects.end());
@@ -464,78 +466,173 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
 
 template <typename OpTy>
 static void populateOpPatterns(const LLVMTypeConverter &converter,
-                               RewritePatternSet &patterns, StringRef f32Func,
+                               RewritePatternSet &patterns,
+                               PatternBenefit benefit, StringRef f32Func,
                                StringRef f64Func, StringRef f32ApproxFunc = "",
                                StringRef f16Func = "") {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc, f16Func);
+                                           f32ApproxFunc, f16Func,
+                                           /*i32Func=*/"", benefit);
 }
 
 template <typename OpTy>
 static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                   RewritePatternSet &patterns,
-                                  StringRef i32Func) {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func);
+                                  PatternBenefit benefit, StringRef i32Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
+                                           benefit);
 }
 
 template <typename OpTy>
 static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                        RewritePatternSet &patterns,
+                                       PatternBenefit benefit,
                                        StringRef f32Func, StringRef f64Func) {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "");
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
+                                           /*i32Func=*/"", benefit);
 }
 
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<GPUSubgroupReduceOpLowering>(converter);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
+}
+
+void mlir::populateLibDeviceConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
+                                    "__nv_fmod");
+  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
+                                       "__nv_fmaxf", "__nv_fmax");
+  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
+                                       "__nv_fminf", "__nv_fmin");
+
+  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
+  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
+                                   "__nv_fabs");
+  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
+                                   "__nv_acos");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
+                                    "__nv_acosh");
+  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
+                                   "__nv_asin");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
+                                    "__nv_asinh");
+  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
+                                   "__nv_atan");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
+                                    "__nv_atan2");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
+                                    "__nv_atanh");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
+                                   "__nv_cbrt");
+  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
+                                   "__nv_ceil");
+  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+                                       "__nv_copysignf", "__nv_copysign");
+  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
+                                  "__nv_cos", "__nv_fast_cosf");
+  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
+                                   "__nv_cosh");
+  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
+                                  "__nv_erf");
+  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
+                                  "__nv_exp", "__nv_fast_expf");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
+                                   "__nv_exp2");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
+                                    "__nv_expm1");
+  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
+                                    "__nv_floor");
+  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
+                                  "__nv_fma");
+  // Note: libdevice does not provide `__nv_isfinitef` as of moment of writing.
+  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit, "",
+                                       "__nv_isfinited");
+  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
+                                    "__nv_isinfd");
+  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
+                                    "__nv_isnand");
+  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
+                                  "__nv_log", "__nv_fast_logf");
+  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
+                                    "__nv_log10", "__nv_fast_log10f");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
+                                    "__nv_log1p");
+  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
+                                   "__nv_log2", "__nv_fast_log2f");
+  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
+                                   "__nv_pow", "__nv_fast_powf");
+  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
+                                            "__nv_powif", "__nv_powi");
+  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
+                                    "__nv_round");
+  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
+                                        "__nv_rintf", "__nv_rint");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
+                                    "__nv_rsqrt");
+  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
+                                  "__nv_sin", "__nv_fast_sinf");
+  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
+                                   "__nv_sinh");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
+                                   "__nv_sqrt");
+  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
+                                  "__nv_tan", "__nv_fast_tanf");
+  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
+                                   "__nv_tanh");
 }
 
 void mlir::populateGpuToNVVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
+
+  // TODO: Pass benefit to generated patterns.
   populateWithGenerated(patterns);
 
-  // Set higher benefit, so patterns will run before generic LLVM lowering.
   patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
-      converter, /*benefit*/ 10);
+      converter, benefit);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
-      converter, IndexKind::Block, IntrType::Id);
+      converter, IndexKind::Block, IntrType::Id, benefit);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                       NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
-      converter, IndexKind::Block, IntrType::Dim);
+      converter, IndexKind::Block, IntrType::Dim, benefit);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                       NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
-      converter, IndexKind::Other, IntrType::Id);
+      converter, IndexKind::Other, IntrType::Id, benefit);
   patterns.add<gpu::index_lowering::OpLowering<
       gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
-      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
+      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
+                            benefit);
   patterns.add<gpu::index_lowering::OpLowering<
       gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
       NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
-      converter, IndexKind::Other, IntrType::Id);
+      converter, IndexKind::Other, IntrType::Id, benefit);
   patterns.add<gpu::index_lowering::OpLowering<
       gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
       NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
-      converter, IndexKind::Other, IntrType::Dim);
+      converter, IndexKind::Other, IntrType::Dim, benefit);
   patterns.add<gpu::index_lowering::OpLowering<
       gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
-      converter, IndexKind::Grid, IntrType::Id);
+      converter, IndexKind::Grid, IntrType::Id, benefit);
   patterns.add<gpu::index_lowering::OpLowering<
       gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
-      converter, IndexKind::Grid, IntrType::Dim);
+      converter, IndexKind::Grid, IntrType::Dim, benefit);
   patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
-      converter);
+      converter, benefit);
 
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
-      converter, NVVM::kSharedMemoryAlignmentBit);
+      converter, NVVM::kSharedMemoryAlignmentBit, benefit);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
@@ -549,87 +646,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
           StringAttr::get(&converter.getContext(),
                           NVVM::NVVMDialect::getKernelFuncAttrName()),
           StringAttr::get(&converter.getContext(),
-                          NVVM::NVVMDialect::getMaxntidAttrName())});
-
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
-                                    "__nv_fmod");
-  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, "__nv_fmaxf",
-                                       "__nv_fmax");
-  populateOpPatterns<arith::MinNumFOp>(converter, patterns, "__nv_fminf",
-                                       "__nv_fmin");
+                          NVVM::NVVMDialect::getMaxntidAttrName())},
+      benefit);
 
-  populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
-                                   "__nv_fabs");
-  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
-                                   "__nv_acos");
-  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
-                                    "__nv_acosh");
-  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
-                                   "__nv_asin");
-  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
-                                    "__nv_asinh");
-  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
-                                   "__nv_atan");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
-                                    "__nv_atan2");
-  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
-                                    "__nv_atanh");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
-                                   "__nv_cbrt");
-  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
-                                   "__nv_ceil");
-  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
-                                       "__nv_copysign");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
-                                  "__nv_fast_cosf");
-  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
-                                   "__nv_cosh");
-  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
-                                  "__nv_fast_expf");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
-                                   "__nv_exp2");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
-                                    "__nv_expm1");
-  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
-                                    "__nv_floor");
-  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
-  // Note: libdevice does not provide `__nv_isfinitef` as of moment of writing.
-  populateOpPatterns<math::IsFiniteOp>(converter, patterns, "",
-                                       "__nv_isfinited");
-  populateOpPatterns<math::IsInfOp>(converter, patterns, "__nv_isinff",
-                                    "__nv_isinfd");
-  populateOpPatterns<math::IsNaNOp>(converter, patterns, "__nv_isnanf",
-                                    "__nv_isnand");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
-                                  "__nv_fast_logf");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
-                                    "__nv_log10", "__nv_fast_log10f");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
-                                    "__nv_log1p");
-  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
-                                   "__nv_log2", "__nv_fast_log2f");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
-                                   "__nv_fast_powf");
-  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
-                                            "__nv_powi");
-  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
-                                    "__nv_round");
-  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
-                                        "__nv_rint");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
-                                    "__nv_rsqrt");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
-                                  "__nv_fast_sinf");
-  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
-                                   "__nv_sinh");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
-                                   "__nv_sqrt");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
-                                  "__nv_fast_tanf");
-  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
-                                   "__nv_tanh");
+  populateLibDeviceConversionPatterns(converter, patterns, benefit);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4b50b9187b25b..4bd94bcebf290 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -388,8 +388,9 @@ LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
 }
 
 void mlir::populateGpuWMMAToNVVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
   patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
                WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
-               WmmaElementwiseOpToNVVMLowering>(converter);
+               WmmaElementwiseOpToNVVMLowering>(converter, benefit);
 }

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 1528da914d546..0737b2e22ebb8 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -86,7 +86,9 @@ void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
   // TODO: We should have a single to_nvvm_type_converter.
   llvmTypeConverter.addConversion(
       [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
-  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
+  // Set higher benefit, so patterns will run before generic LLVM lowering.
+  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns,
+                                      getBenefit());
 }
 
 LogicalResult

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
index 1a22ba662cbf7..8bcbdad1437ab 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
@@ -58,7 +58,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_conversion_patterns.vector.vector_to_llvm
       transform.apply_conversion_patterns.func.func_to_llvm
       transform.apply_conversion_patterns.dialect_to_llvm "memref"
-      transform.apply_conversion_patterns.gpu.gpu_to_nvvm
+      transform.apply_conversion_patterns.gpu.gpu_to_nvvm {benefit = 10 : i16}
       transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
       transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm {has_redux = true}
       transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 664a0bb0c0d5b..7b5b11ec02724 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1014,7 +1014,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_conversion_patterns.vector.vector_to_llvm
       transform.apply_conversion_patterns.func.func_to_llvm
       transform.apply_conversion_patterns.dialect_to_llvm "memref"
-      transform.apply_conversion_patterns.gpu.gpu_to_nvvm
+      transform.apply_conversion_patterns.gpu.gpu_to_nvvm {benefit = 10 : i16}
       transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
       transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm
       transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm


        


More information about the Mlir-commits mailing list