[Mlir-commits] [mlir] 3bf1f0e - [mlir] Add missing patterns for lowering to Libm.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Feb 22 06:50:00 PST 2023
Author: Alexander Belyaev
Date: 2023-02-22T15:46:03+01:00
New Revision: 3bf1f0e7530f3cbbe88179f6718a0aed6fb5ff54
URL: https://github.com/llvm/llvm-project/commit/3bf1f0e7530f3cbbe88179f6718a0aed6fb5ff54
DIFF: https://github.com/llvm/llvm-project/commit/3bf1f0e7530f3cbbe88179f6718a0aed6fb5ff54.diff
LOG: [mlir] Add missing patterns for lowering to Libm.
Differential Revision: https://reviews.llvm.org/D144561
Added:
Modified:
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 35ac2b3c2bdfe..1e5b317caa941 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -60,6 +60,14 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
private:
std::string floatFunc, doubleFunc;
};
+
+template <typename OpTy>
+void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
+ StringRef floatFunc, StringRef doubleFunc) {
+ patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
+ patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+}
+
} // namespace
template <typename Op>
@@ -153,35 +161,23 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
- patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
- VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
- VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
- VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
- VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
- VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
- ctx);
- patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
- PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
- PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
- PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
- PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
- PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(ctx);
- patterns.add<ScalarOpToLibmCall<math::AtanOp>>(ctx, "atanf", "atan");
- patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(ctx, "atan2f", "atan2");
- patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(ctx, "cbrtf", "cbrt");
- patterns.add<ScalarOpToLibmCall<math::ErfOp>>(ctx, "erff", "erf");
- patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(ctx, "expm1f", "expm1");
- patterns.add<ScalarOpToLibmCall<math::TanOp>>(ctx, "tanf", "tan");
- patterns.add<ScalarOpToLibmCall<math::TanhOp>>(ctx, "tanhf", "tanh");
- patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(ctx, "roundevenf",
- "roundeven");
- patterns.add<ScalarOpToLibmCall<math::RoundOp>>(ctx, "roundf", "round");
- patterns.add<ScalarOpToLibmCall<math::CosOp>>(ctx, "cosf", "cos");
- patterns.add<ScalarOpToLibmCall<math::SinOp>>(ctx, "sinf", "sin");
- patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(ctx, "log1pf", "log1p");
- patterns.add<ScalarOpToLibmCall<math::FloorOp>>(ctx, "floorf", "floor");
- patterns.add<ScalarOpToLibmCall<math::CeilOp>>(ctx, "ceilf", "ceil");
- patterns.add<ScalarOpToLibmCall<math::TruncOp>>(ctx, "truncf", "trunc");
+
+ populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
+ populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
+ populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
+ populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
+ populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
+ populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
+ populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
+ populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
+ populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
+ populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+ "roundeven");
+ populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
+ populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
+ populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
+ populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
+ populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
}
namespace {
More information about the Mlir-commits mailing list