[Mlir-commits] [mlir] [MLIR] Separate the scalarization part of MathToROCDL (PR #128203)

Benoit Jacob llvmlistbot at llvm.org
Fri Feb 21 09:13:11 PST 2025


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/128203

MathToROCDL was lumping together scalarization and lowering to calls. The latter may legitimately fail if an op does not have a lowering to a function call. In that case, we still want the scalarization, because that is necessary to keep the ops in sync with the type conversion.

From 68409b900f56a373168952bb9dcfb2ea0ad7cb00 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 21 Feb 2025 10:10:51 -0600
Subject: [PATCH] foo

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../mlir/Conversion/MathToROCDL/MathToROCDL.h |  13 +-
 .../Conversion/MathToROCDL/MathToROCDL.cpp    | 190 ++++++++++++------
 .../Conversion/MathToROCDL/math-to-rocdl.mlir |  17 ++
 3 files changed, 153 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e7966ccc..7d5c487a9dbff 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -18,9 +18,18 @@ class Pass;
 #define GEN_PASS_DECL_CONVERTMATHTOROCDL
 #include "mlir/Conversion/Passes.h.inc"
 
+enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings };
+
 /// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
-                                           RewritePatternSet &patterns);
+///
+/// Note that the default parameter value MathToROCDLConversionPatternKind::All
+/// is only for compatibility but is not recommended, because lumping together
+/// multiple conversion patterns in the same pattern application can result in
+/// type conversion failures when one of the patterns fails.
+void populateMathToROCDLConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    MathToROCDLConversionPatternKind patternKind =
+        MathToROCDLConversionPatternKind::All);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 838eef30a938f..bd8578d70c260 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -37,16 +37,25 @@ using namespace mlir;
 
 template <typename OpTy>
 static void populateOpPatterns(const LLVMTypeConverter &converter,
-                               RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func, StringRef f16Func,
+                               RewritePatternSet &patterns,
+                               MathToROCDLConversionPatternKind patternKind,
+                               StringRef f32Func, StringRef f64Func,
+                               StringRef f16Func,
                                StringRef f32ApproxFunc = "") {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc, f16Func);
+  if (patternKind == MathToROCDLConversionPatternKind::All ||
+      patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
+    patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  }
+  if (patternKind == MathToROCDLConversionPatternKind::All ||
+      patternKind == MathToROCDLConversionPatternKind::Lowerings) {
+    patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                             f32ApproxFunc, f16Func);
+  }
 }
 
 void mlir::populateMathToROCDLConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    MathToROCDLConversionPatternKind patternKind) {
   // Handled by mathToLLVM: math::AbsIOp
   // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
@@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
-  populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
-                                   "__ocml_acos_f64", "__ocml_acos_f16");
-  populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
-                                    "__ocml_acosh_f64", "__ocml_acosh_f16");
-  populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
-                                   "__ocml_asin_f64", "__ocml_asin_f16");
-  populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
-                                    "__ocml_asinh_f64", "__ocml_asinh_f16");
-  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64", "__ocml_atan_f16");
-  populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
-                                    "__ocml_atanh_f64", "__ocml_atanh_f16");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64", "__ocml_atan2_f16");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64", "__ocml_cbrt_f16");
-  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64", "__ocml_ceil_f16");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64", "__ocml_cos_f16");
-  populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
-                                   "__ocml_cosh_f64", "__ocml_cosh_f16");
-  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
-                                   "__ocml_sinh_f64", "__ocml_sinh_f16");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
-                                  "__ocml_exp_f16");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64", "__ocml_exp2_f16");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64", "__ocml_expm1_f16");
-  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64", "__ocml_floor_f16");
-  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
-                                  "__ocml_log_f16");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64", "__ocml_log10_f16");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64", "__ocml_log1p_f16");
-  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64", "__ocml_log2_f16");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64", "__ocml_pow_f16");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64", "__ocml_sin_f16");
-  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64", "__ocml_tanh_f16");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64", "__ocml_tan_f16");
-  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64", "__ocml_erf_f16");
-  populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
-                                    "__ocml_pown_f64", "__ocml_pown_f16");
+  populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
+                                   "__ocml_acos_f32", "__ocml_acos_f64",
+                                   "__ocml_acos_f16");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
+                                    "__ocml_acosh_f32", "__ocml_acosh_f64",
+                                    "__ocml_acosh_f16");
+  populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
+                                   "__ocml_asin_f32", "__ocml_asin_f64",
+                                   "__ocml_asin_f16");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
+                                    "__ocml_asinh_f32", "__ocml_asinh_f64",
+                                    "__ocml_asinh_f16");
+  populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
+                                   "__ocml_atan_f32", "__ocml_atan_f64",
+                                   "__ocml_atan_f16");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
+                                    "__ocml_atanh_f32", "__ocml_atanh_f64",
+                                    "__ocml_atanh_f16");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
+                                    "__ocml_atan2_f32", "__ocml_atan2_f64",
+                                    "__ocml_atan2_f16");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
+                                   "__ocml_cbrt_f32", "__ocml_cbrt_f64",
+                                   "__ocml_cbrt_f16");
+  populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
+                                   "__ocml_ceil_f32", "__ocml_ceil_f64",
+                                   "__ocml_ceil_f16");
+  populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
+                                  "__ocml_cos_f32", "__ocml_cos_f64",
+                                  "__ocml_cos_f16");
+  populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
+                                   "__ocml_cosh_f32", "__ocml_cosh_f64",
+                                   "__ocml_cosh_f16");
+  populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
+                                   "__ocml_sinh_f32", "__ocml_sinh_f64",
+                                   "__ocml_sinh_f16");
+  populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, "",
+                                  "__ocml_exp_f64", "__ocml_exp_f16");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
+                                   "__ocml_exp2_f32", "__ocml_exp2_f64",
+                                   "__ocml_exp2_f16");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
+                                    "__ocml_expm1_f32", "__ocml_expm1_f64",
+                                    "__ocml_expm1_f16");
+  populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
+                                    "__ocml_floor_f32", "__ocml_floor_f64",
+                                    "__ocml_floor_f16");
+  populateOpPatterns<math::LogOp>(converter, patterns, patternKind, "",
+                                  "__ocml_log_f64", "__ocml_log_f16");
+  populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
+                                    "__ocml_log10_f32", "__ocml_log10_f64",
+                                    "__ocml_log10_f16");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
+                                    "__ocml_log1p_f32", "__ocml_log1p_f64",
+                                    "__ocml_log1p_f16");
+  populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
+                                   "__ocml_log2_f32", "__ocml_log2_f64",
+                                   "__ocml_log2_f16");
+  populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
+                                   "__ocml_pow_f32", "__ocml_pow_f64",
+                                   "__ocml_pow_f16");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
+                                    "__ocml_rsqrt_f32", "__ocml_rsqrt_f64",
+                                    "__ocml_rsqrt_f16");
+  populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
+                                  "__ocml_sin_f32", "__ocml_sin_f64",
+                                  "__ocml_sin_f16");
+  populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
+                                   "__ocml_tanh_f32", "__ocml_tanh_f64",
+                                   "__ocml_tanh_f16");
+  populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
+                                  "__ocml_tan_f32", "__ocml_tan_f64",
+                                  "__ocml_tan_f16");
+  populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
+                                  "__ocml_erf_f32", "__ocml_erf_f64",
+                                  "__ocml_erf_f16");
+  populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
+                                    "__ocml_pown_f32", "__ocml_pown_f64",
+                                    "__ocml_pown_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64", "__ocml_fmod_f16");
+  populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
+                                    "__ocml_fmod_f32", "__ocml_fmod_f64",
+                                    "__ocml_fmod_f16");
 }
 
 namespace {
@@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
   auto m = getOperation();
   MLIRContext *ctx = m.getContext();
 
-  RewritePatternSet patterns(&getContext());
   LowerToLLVMOptions options(ctx, DataLayout(m));
   LLVMTypeConverter converter(ctx, options);
-  populateMathToROCDLConversionPatterns(converter, patterns);
+
+  // The two pattern applications below will use distinct ConversionTargets,
+  // but this is the common denominator.
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect,
                          vector::VectorDialect, LLVM::LLVMDialect>();
+
+  // Perform the scalarizations. This is done in a separate pattern application
+  // to ensure that scalarizations are done regardless of lowerings. It is
+  // normal for some lowerings to fail to apply, when we purposely do not lower
+  // a math op to a function call.
+  RewritePatternSet scalarizationPatterns(&getContext());
+  ConversionTarget scalarizationTarget(target);
+  // Math ops are legal if their operands are not vectors.
+  scalarizationTarget.addDynamicallyLegalDialect<math::MathDialect>(
+      [&](Operation *op) {
+        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
+      });
+  populateMathToROCDLConversionPatterns(
+      converter, scalarizationPatterns,
+      MathToROCDLConversionPatternKind::Scalarizations);
+  if (failed(applyPartialConversion(m, scalarizationTarget,
+                                    std::move(scalarizationPatterns))))
+    signalPassFailure();
+
+  // Perform the lowerings. The ops that must lower to function calls become
+  // illegal.
   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                       LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                       LLVM::SqrtOp>();
-  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+  RewritePatternSet loweringPatterns(&getContext());
+  populateMathToROCDLConversionPatterns(
+      converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
+  if (failed(applyPartialConversion(m, target, std::move(loweringPatterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 313d7b086731e..44ee2fcbcb7f8 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -578,3 +578,20 @@ module @test_module {
     func.return %result : vector<2x2xf16>
   }
 }
+
+// -----
+
+module @test_module {
+  // This test case covers the case of math ops that do not have a lowering to
+  // a function call. When lowerings to calls were lumped together with
+  // scalarization in the same pattern application, they were preventing
+  // scalarization.
+  // CHECK-LABEL: func @math_log_f32_vector_0d
+  func.func @math_log_f32_vector_0d(%arg : vector<f32>) -> vector<f32> {
+    // CHECK: llvm.extractelement {{.*}} : vector<1xf32>
+    // CHECK: math.log {{.*}} : f32
+    // CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+    %result = math.log %arg : vector<f32>
+    func.return %result : vector<f32>
+  }
+}



More information about the Mlir-commits mailing list