[Mlir-commits] [mlir] [MLIR] Separate the scalarization part of MathToROCDL (PR #128203)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 09:13:46 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benoit Jacob (bjacob)
<details>
<summary>Changes</summary>
MathToROCDL was lumping together scalarization and lowering to calls. The latter may legitimately fail if an op does not have a lowering to a function call. In that case, we still want the scalarization, because that is necessary to keep the ops in sync with the type conversion.
---
Full diff: https://github.com/llvm/llvm-project/pull/128203.diff
3 Files Affected:
- (modified) mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h (+11-2)
- (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+125-65)
- (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+17)
``````````diff
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e7966ccc..7d5c487a9dbff 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -18,9 +18,18 @@ class Pass;
#define GEN_PASS_DECL_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"
+enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings };
+
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+///
+/// Note that the default parameter value MathToROCDLConversionPatternKind::All
+/// is only for compatibility but is not recommended, because lumping together
+/// multiple conversion patterns in the same pattern application can result in
+/// type conversion failures when one of the patterns fails.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ MathToROCDLConversionPatternKind patternKind =
+ MathToROCDLConversionPatternKind::All);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 838eef30a938f..bd8578d70c260 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -37,16 +37,25 @@ using namespace mlir;
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func, StringRef f16Func,
+ RewritePatternSet &patterns,
+ MathToROCDLConversionPatternKind patternKind,
+ StringRef f32Func, StringRef f64Func,
+ StringRef f16Func,
StringRef f32ApproxFunc = "") {
- patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc, f16Func);
+ if (patternKind == MathToROCDLConversionPatternKind::All ||
+ patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ }
+ if (patternKind == MathToROCDLConversionPatternKind::All ||
+ patternKind == MathToROCDLConversionPatternKind::Lowerings) {
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ f32ApproxFunc, f16Func);
+ }
}
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ MathToROCDLConversionPatternKind patternKind) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
// Handled by mathToLLVM: math::RoundOp
// Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
- populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
- "__ocml_acos_f64", "__ocml_acos_f16");
- populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
- "__ocml_acosh_f64", "__ocml_acosh_f16");
- populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
- "__ocml_asin_f64", "__ocml_asin_f16");
- populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
- "__ocml_asinh_f64", "__ocml_asinh_f16");
- populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64", "__ocml_atan_f16");
- populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
- "__ocml_atanh_f64", "__ocml_atanh_f16");
- populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64", "__ocml_atan2_f16");
- populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64", "__ocml_cbrt_f16");
- populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64", "__ocml_ceil_f16");
- populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64", "__ocml_cos_f16");
- populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
- "__ocml_cosh_f64", "__ocml_cosh_f16");
- populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
- "__ocml_sinh_f64", "__ocml_sinh_f16");
- populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
- "__ocml_exp_f16");
- populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64", "__ocml_exp2_f16");
- populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64", "__ocml_expm1_f16");
- populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64", "__ocml_floor_f16");
- populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
- "__ocml_log_f16");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64", "__ocml_log10_f16");
- populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64", "__ocml_log1p_f16");
- populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64", "__ocml_log2_f16");
- populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64", "__ocml_pow_f16");
- populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
- populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64", "__ocml_sin_f16");
- populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64", "__ocml_tanh_f16");
- populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64", "__ocml_tan_f16");
- populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64", "__ocml_erf_f16");
- populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
- "__ocml_pown_f64", "__ocml_pown_f16");
+ populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
+ "__ocml_acos_f32", "__ocml_acos_f64",
+ "__ocml_acos_f16");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
+ "__ocml_acosh_f32", "__ocml_acosh_f64",
+ "__ocml_acosh_f16");
+ populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
+ "__ocml_asin_f32", "__ocml_asin_f64",
+ "__ocml_asin_f16");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
+ "__ocml_asinh_f32", "__ocml_asinh_f64",
+ "__ocml_asinh_f16");
+ populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
+ "__ocml_atan_f32", "__ocml_atan_f64",
+ "__ocml_atan_f16");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
+ "__ocml_atanh_f32", "__ocml_atanh_f64",
+ "__ocml_atanh_f16");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
+ "__ocml_atan2_f32", "__ocml_atan2_f64",
+ "__ocml_atan2_f16");
+ populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
+ "__ocml_cbrt_f32", "__ocml_cbrt_f64",
+ "__ocml_cbrt_f16");
+ populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
+ "__ocml_ceil_f32", "__ocml_ceil_f64",
+ "__ocml_ceil_f16");
+ populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
+ "__ocml_cos_f32", "__ocml_cos_f64",
+ "__ocml_cos_f16");
+ populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
+ "__ocml_cosh_f32", "__ocml_cosh_f64",
+ "__ocml_cosh_f16");
+ populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
+ "__ocml_sinh_f32", "__ocml_sinh_f64",
+ "__ocml_sinh_f16");
+ populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, "",
+ "__ocml_exp_f64", "__ocml_exp_f16");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
+ "__ocml_exp2_f32", "__ocml_exp2_f64",
+ "__ocml_exp2_f16");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
+ "__ocml_expm1_f32", "__ocml_expm1_f64",
+ "__ocml_expm1_f16");
+ populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
+ "__ocml_floor_f32", "__ocml_floor_f64",
+ "__ocml_floor_f16");
+ populateOpPatterns<math::LogOp>(converter, patterns, patternKind, "",
+ "__ocml_log_f64", "__ocml_log_f16");
+ populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
+ "__ocml_log10_f32", "__ocml_log10_f64",
+ "__ocml_log10_f16");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
+ "__ocml_log1p_f32", "__ocml_log1p_f64",
+ "__ocml_log1p_f16");
+ populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
+ "__ocml_log2_f32", "__ocml_log2_f64",
+ "__ocml_log2_f16");
+ populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
+ "__ocml_pow_f32", "__ocml_pow_f64",
+ "__ocml_pow_f16");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
+ "__ocml_rsqrt_f32", "__ocml_rsqrt_f64",
+ "__ocml_rsqrt_f16");
+ populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
+ "__ocml_sin_f32", "__ocml_sin_f64",
+ "__ocml_sin_f16");
+ populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
+ "__ocml_tanh_f32", "__ocml_tanh_f64",
+ "__ocml_tanh_f16");
+ populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
+ "__ocml_tan_f32", "__ocml_tan_f64",
+ "__ocml_tan_f16");
+ populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
+ "__ocml_erf_f32", "__ocml_erf_f64",
+ "__ocml_erf_f16");
+ populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
+ "__ocml_pown_f32", "__ocml_pown_f64",
+ "__ocml_pown_f16");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64", "__ocml_fmod_f16");
+ populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
+ "__ocml_fmod_f32", "__ocml_fmod_f64",
+ "__ocml_fmod_f16");
}
namespace {
@@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
- RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+
+ // The two pattern applications below will use distinct ConversionTargets,
+ // but this is the common denominator.
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
+
+ // Perform the scalarizations. This is done in a separate pattern application
+ // to ensure that scalarizations are done regardless of lowerings. It is
+ // normal for some lowerings to fail to apply when we purposely do not lower
+ // a math op to a function call.
+ RewritePatternSet scalarizationPatterns(&getContext());
+ ConversionTarget scalarizationTarget(target);
+ // Math ops are legal if their operands are not vectors.
+ scalarizationTarget.addDynamicallyLegalDialect<math::MathDialect>(
+ [&](Operation *op) {
+ return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
+ });
+ populateMathToROCDLConversionPatterns(
+ converter, scalarizationPatterns,
+ MathToROCDLConversionPatternKind::Scalarizations);
+ if (failed(applyPartialConversion(m, scalarizationTarget,
+ std::move(scalarizationPatterns))))
+ signalPassFailure();
+
+ // Perform the lowerings. The ops that must lower to function calls become
+ // illegal.
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();
- if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ RewritePatternSet loweringPatterns(&getContext());
+ populateMathToROCDLConversionPatterns(
+ converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
+ if (failed(applyPartialConversion(m, target, std::move(loweringPatterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 313d7b086731e..44ee2fcbcb7f8 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -578,3 +578,20 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}
+
+// -----
+
+module @test_module {
+ // This test case covers the case of math ops that do not have a lowering to
+ // a function call. When lowerings to calls were lumped together with
+ // scalarization in the same pattern application, they were preventing
+ // scalarization.
+ // CHECK-LABEL: func @math_log_f32_vector_0d
+ func.func @math_log_f32_vector_0d(%arg : vector<f32>) -> vector<f32> {
+ // CHECK: llvm.extractelement {{.*}} : vector<1xf32>
+ // CHECK: math.log {{.*}} : f32
+ // CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+ %result = math.log %arg : vector<f32>
+ func.return %result : vector<f32>
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/128203
More information about the Mlir-commits
mailing list