[Mlir-commits] [mlir] [mlir][math] Add `clampf` and clean math `ExpandOps` (PR #151153)
Fabian Mora
llvmlistbot at llvm.org
Tue Jul 29 06:54:11 PDT 2025
https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/151153
This patch adds the `clampf` operation to the math dialect. The semantics of the op are defined as:
```
clampf(x, min_v, max_v) = min(max(x, min_v), max_v)
```
The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, [__saturatef](https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__INTRINSIC__SINGLE.html#group__cuda__math__intrinsic__single_1ga2c84f08e0db7117a14509d21c3aec04e) in NVIDIA GPUs, or `__builtin_amdgcn_fmed3f` in AMD GPUs.
This patch also removes `test-expand-math` in favor of `math-expand-ops`.
Finally, it removes individual expansion population API calls like `populateExpandCoshPattern` in favor of:
```C++
void populateExpansionPatterns(RewritePatternSet &patterns,
ArrayRef<StringRef> opMnemonics = {});
```
>From 7191a5908889d3785ded0c69c7d6aea5a7696291 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 29 Jul 2025 13:44:01 +0000
Subject: [PATCH] add clampf
---
mlir/include/mlir/Dialect/Math/IR/MathOps.td | 31 ++++
.../mlir/Dialect/Math/Transforms/Passes.h | 26 ++--
.../mlir/Dialect/Math/Transforms/Passes.td | 20 +++
.../Dialect/Math/Transforms/CMakeLists.txt | 2 +-
.../{ExpandPatterns.cpp => ExpandOps.cpp} | 139 ++++++++++--------
mlir/test/Dialect/Math/expand-math.mlir | 35 ++++-
mlir/test/Dialect/Math/ops.mlir | 15 +-
mlir/test/lib/Dialect/Math/CMakeLists.txt | 1 -
mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 62 --------
.../mlir-runner/test-expand-math-approx.mlir | 2 +-
mlir/tools/mlir-opt/mlir-opt.cpp | 2 -
11 files changed, 188 insertions(+), 147 deletions(-)
rename mlir/lib/Dialect/Math/Transforms/{ExpandPatterns.cpp => ExpandOps.cpp} (89%)
delete mode 100644 mlir/test/lib/Dialect/Math/TestExpandMath.cpp
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 56370388dea87..cfd8c4b8f11f7 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -352,6 +352,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// ClampFOp
+//===----------------------------------------------------------------------===//
+
+def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
+  let summary = "floating point clamping operation";
+  let description = [{
+    The `clampf` operation takes three operands and returns one result, all of
+    which must have the same type. Operands must be of floating point type
+    (i.e., scalar, tensor or vector).
+
+    The operation clamps `value` to the closed interval `[min, max]`; its
+    semantics are described by:
+    ```
+    clampf(value, min, max) = minf(maxf(value, min), max)
+    ```
+
+    Example:
+
+    ```mlir
+    %d = math.clampf %value to [%min, %max] : f64
+    ```
+  }];
+  let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max,
+      DefaultValuedAttr<Arith_FastMathAttr,
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let assemblyFormat = [{
+    $value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)?
+    attr-dict `:` type($result)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// CopySignOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index c0fe5d3be448a..b3abbf728a3c6 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -23,22 +23,16 @@ class ConversionTarget;
class RewritePatternSet;
class TypeConverter;
-void populateExpandCtlzPattern(RewritePatternSet &patterns);
-void populateExpandTanPattern(RewritePatternSet &patterns);
-void populateExpandSinhPattern(RewritePatternSet &patterns);
-void populateExpandCoshPattern(RewritePatternSet &patterns);
-void populateExpandTanhPattern(RewritePatternSet &patterns);
-void populateExpandAsinhPattern(RewritePatternSet &patterns);
-void populateExpandAcoshPattern(RewritePatternSet &patterns);
-void populateExpandAtanhPattern(RewritePatternSet &patterns);
-void populateExpandFmaFPattern(RewritePatternSet &patterns);
-void populateExpandCeilFPattern(RewritePatternSet &patterns);
-void populateExpandExp2FPattern(RewritePatternSet &patterns);
-void populateExpandPowFPattern(RewritePatternSet &patterns);
-void populateExpandFPowIPattern(RewritePatternSet &patterns);
-void populateExpandRoundFPattern(RewritePatternSet &patterns);
-void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
-void populateExpandRsqrtPattern(RewritePatternSet &patterns);
+namespace math {
+/// Adds patterns to expand math operations into other more fundamental
+/// operations. For example, hyperbolic functions are expanded into expressions
+/// using `exp`. If `opMnemonics` is empty then all available patterns will be
+/// added, otherwise only the patterns corresponding to ops in `opMnemonics`
+/// will be added to the set. Mnemonics are given without the dialect prefix
+/// (e.g. `sinh`, not `math.sinh`).
+void populateExpansionPatterns(RewritePatternSet &patterns,
+ ArrayRef<StringRef> opMnemonics = {});
+} // namespace math
+
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index a84c89020d4f3..4d415aeac8f58 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
+def MathExpandOpsPass : Pass<"math-expand-ops"> {
+ let summary = "Expand math operations.";
+ let description = [{
+ Expands some math operations into more fundamental operations, allowing them
+ to be subsequently lowered through these. For example, hyperbolic functions
+ are transformed into their expanded form containing only `exp` functions.
+
+ The `ops` parameter can be used to apply only a subset of all the
+ available expansions, these must correspond to the operation mnemonic.
+ For example, `ops=sinh,acosh` will expand only `math.sinh` and
+ `math.acosh` operations. If the list is empty, then all expansions are
+ applied.
+ }];
+ let dependentDialects = ["arith::ArithDialect"];
+ let options = [
+ ListOption<"opMnemonics", "ops", "std::string",
+ "Operations to expand.">
+ ];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index e1c0c2410c126..d37a056e8e158 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
- ExpandPatterns.cpp
+ ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
similarity index 89%
rename from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
rename to mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4a40a3055ed62..cd68039d0d964 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -13,14 +13,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXPANDOPSPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
/// Create a float constant.
static Value createFloatConst(Location loc, Type type, APFloat value,
OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
return success();
}
-void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
- patterns.add(convertCtlzOp);
-}
-
-void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertSinhOp);
-}
-
-void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertCoshOp);
-}
-
-void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanOp);
-}
-
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanhOp);
-}
-
-void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAsinhOp);
-}
-
-void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertAcoshOp);
-}
-
-void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAtanhOp);
-}
-
-void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
- patterns.add(convertFmaFOp);
-}
-
-void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
- patterns.add(convertCeilOp);
-}
-
-void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
- patterns.add(convertExp2fOp);
-}
-
-void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
- patterns.add(convertPowfOp);
-}
-
-void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowIOp);
-}
-
-void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundOp);
+// Convert `math.clampf` into `arith.maximumf` + `arith.minimumf`, following
+// `clampf(value, min, max) = minimumf(maximumf(value, min), max)`: the value is
+// first bounded below by `min`, then bounded above by `max`. Applying the
+// operations in the opposite order (min first, then max) would always produce
+// `max`. The op's fastmath flags are forwarded to both arith ops.
+static LogicalResult convertClampfOp(math::ClampFOp op,
+                                     PatternRewriter &rewriter) {
+  auto lowerBounded = arith::MaximumFOp::create(
+      rewriter, op.getLoc(), op.getValue(), op.getMin(), op.getFastmath());
+  rewriter.replaceOpWithNewOp<arith::MinimumFOp>(op, lowerBounded, op.getMax(),
+                                                 op.getFastmath());
+  return success();
}
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundEvenOp);
+/// Populates `patterns` with the math op expansions selected by `opMnemonics`.
+/// Mnemonics are matched without the `math.` dialect prefix; an empty list
+/// selects every available expansion.
+void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
+                                           ArrayRef<StringRef> opMnemonics) {
+  // The filter below strips a hard-coded "math." prefix from the full op name,
+  // so it relies on the dialect namespace being exactly "math".
+  // TODO: turn this into a static_assert (and build the prefix from the
+  // namespace) once `getDialectNamespace` is usable in constant expressions
+  // (e.g. by returning `std::string_view`) and `StringRef::consume_front`
+  // accepts a composed prefix such as a `Twine`.
+  assert(MathDialect::getDialectNamespace() == "math" &&
+         "unexpected math dialect namespace");
+  auto filter = [&](StringRef name) {
+    name.consume_front("math.");
+    return opMnemonics.empty() || llvm::is_contained(opMnemonics, name);
+  };
+  if (filter(CountLeadingZerosOp::getOperationName()))
+    patterns.add(convertCtlzOp);
+  if (filter(SinhOp::getOperationName()))
+    patterns.add(convertSinhOp);
+  if (filter(CoshOp::getOperationName()))
+    patterns.add(convertCoshOp);
+  if (filter(TanOp::getOperationName()))
+    patterns.add(convertTanOp);
+  if (filter(TanhOp::getOperationName()))
+    patterns.add(convertTanhOp);
+  if (filter(AsinhOp::getOperationName()))
+    patterns.add(convertAsinhOp);
+  if (filter(AcoshOp::getOperationName()))
+    patterns.add(convertAcoshOp);
+  if (filter(AtanhOp::getOperationName()))
+    patterns.add(convertAtanhOp);
+  if (filter(FmaOp::getOperationName()))
+    patterns.add(convertFmaFOp);
+  if (filter(CeilOp::getOperationName()))
+    patterns.add(convertCeilOp);
+  if (filter(Exp2Op::getOperationName()))
+    patterns.add(convertExp2fOp);
+  if (filter(PowFOp::getOperationName()))
+    patterns.add(convertPowfOp);
+  if (filter(FPowIOp::getOperationName()))
+    patterns.add(convertFPowIOp);
+  if (filter(RoundOp::getOperationName()))
+    patterns.add(convertRoundOp);
+  if (filter(RoundEvenOp::getOperationName()))
+    patterns.add(convertRoundEvenOp);
+  if (filter(RsqrtOp::getOperationName()))
+    patterns.add(convertRsqrtOp);
+  if (filter(ClampFOp::getOperationName()))
+    patterns.add(convertClampfOp);
}
-void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
- patterns.add(convertRsqrtOp);
-}
+//===----------------------------------------------------------------------===//
+// MathExpandOpsPass pass
+//===----------------------------------------------------------------------===//
+namespace {
+/// Greedily applies the math expansion patterns selected by the pass's `ops`
+/// option (all available expansions when the option is empty) to the operation
+/// the pass runs on.
+struct MathExpandOpsPass final
+ : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
+ using MathExpandOpsPassBase::MathExpandOpsPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ // Non-owning views of the option's strings; the strings stay owned by the
+ // pass option member for the duration of this call.
+ SmallVector<StringRef> mnemonics =
+ llvm::to_vector_of<StringRef>(opMnemonics);
+ math::populateExpansionPatterns(patterns, mnemonics);
+ // Report pass failure if the greedy rewrite driver reports failure.
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1420acaa40d35..615c607efc3c3 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,7 +1,9 @@
-// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER
// CHECK-LABEL: func @tanh
func.func @tanh(%arg: f32) -> f32 {
+ // CHECK-FILTER-NOT: math.tanh
%res = math.tanh %arg : f32
return %res : f32
}
@@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK-LABEL: func @vector_tanh
func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
// CHECK-NOT: math.tanh
+ // CHECK-FILTER-NOT: math.tanh
%res = math.tanh %arg : vector<4xf32>
return %res : vector<4xf32>
}
@@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
// CHECK-LABEL: func @tan
func.func @tan(%arg: f32) -> f32 {
+ // CHECK-FILTER-NOT: math.tan
%res = math.tan %arg : f32
return %res : f32
}
@@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 {
// CHECK-LABEL: func @vector_tan
func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
+ // CHECK-FILTER-NOT: math.tan
%res = math.tan %arg : vector<4xf32>
return %res : vector<4xf32>
}
@@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
// -----
func.func @ctlz(%arg: i32) -> i32 {
+ // CHECK-FILTER: math.ctlz
%res = math.ctlz %arg : i32
return %res : i32
}
@@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 {
// -----
func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+ // CHECK-FILTER: math.ctlz
%res = math.ctlz %arg : vector<4xi32>
return %res : vector<4xi32>
}
@@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
+ // CHECK-FILTER: math.ceil
%ret = math.ceil %a : f64
return %ret : f64
}
@@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 {
// CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
// CHECK: [[EXP:%.+]] = math.exp [[MULF]]
// CHECK: return [[EXP]]
+ // CHECK-FILTER: math.exp2
%ret = math.exp2 %a : f64
return %ret : f64
}
@@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
%a = math.rsqrt %arg : tensor<*xf32>
return %a: tensor<*xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @clampf_scalar_op
+// CHECK-SAME: (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
+// CHECK: %[[V0:.*]] = arith.maximumf %[[ARG]], %[[MIN]] : f16
+// CHECK: %[[V1:.*]] = arith.minimumf %[[V0]], %[[MAX]] : f16
+// CHECK: return %[[V1]] : f16
+
+func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
+  %a = math.clampf %arg to [%min, %max] : f16
+  return %a: f16
+}
+
+// CHECK-LABEL: func.func @clampf_vector_op
+// CHECK-SAME: (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
+// CHECK: %[[V0:.*]] = arith.maximumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
+// CHECK: %[[V1:.*]] = arith.minimumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK: return %[[V1]] : vector<3x4xf32>
+
+func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{
+  %a = math.clampf %arg to [%min, %max] fastmath<fast> : vector<3x4xf32>
+  return %a: vector<3x4xf32>
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 8feadedd1860e..cb10fc4397ffc 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
// CHECK-LABEL: func @atan(
@@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>)
math.isnormal %t : tensor<4x?xf32>
return
}
+
+// CHECK-LABEL: func @clampf(
+func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>,
+ %as: f32, %ms: f32, %Ms: f32,
+ %at: tensor<?xf80>, %mt: tensor<?xf80>, %Mt: tensor<?xf80>) {
+ // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath<fast> : vector<3x4xf32>
+ %rv = math.clampf %av to [%mv, %Mv] fastmath<fast> : vector<3x4xf32>
+ // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32
+ %rs = math.clampf %as to [%ms, %Ms] fastmath<none> : f32
+ // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor<?xf80>
+ %rt = math.clampf %at to [%mt, %Mt] : tensor<?xf80>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index 91e70d1785369..900dff3b5e9f1 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMathTestPasses
TestAlgebraicSimplification.cpp
- TestExpandMath.cpp
TestPolynomialApproximation.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
deleted file mode 100644
index efc1acf8bb6cd..0000000000000
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-//===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains test passes for expanding math operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestExpandMathPass
- : public PassWrapper<TestExpandMathPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
-
- void runOnOperation() override;
- StringRef getArgument() const final { return "test-expand-math"; }
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
- }
- StringRef getDescription() const final { return "Test expanding math"; }
-};
-} // namespace
-
-void TestExpandMathPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- populateExpandCtlzPattern(patterns);
- populateExpandExp2FPattern(patterns);
- populateExpandTanPattern(patterns);
- populateExpandSinhPattern(patterns);
- populateExpandCoshPattern(patterns);
- populateExpandTanhPattern(patterns);
- populateExpandAsinhPattern(patterns);
- populateExpandAcoshPattern(patterns);
- populateExpandAtanhPattern(patterns);
- populateExpandFmaFPattern(patterns);
- populateExpandCeilFPattern(patterns);
- populateExpandPowFPattern(patterns);
- populateExpandFPowIPattern(patterns);
- populateExpandRoundFPattern(patterns);
- populateExpandRoundEvenPattern(patterns);
- populateExpandRsqrtPattern(patterns);
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-}
-
-namespace mlir {
-namespace test {
-void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index b599c9d8435d4..3f9d3f2125e1a 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(math-expand-ops),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
// RUN: | mlir-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%mlir_c_runner_utils \
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 14714c452503a..7b992b4ee029b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -98,7 +98,6 @@ void registerTestDiagnosticsMetadataPass();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestEmulateNarrowTypePass();
-void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
@@ -245,7 +244,6 @@ void registerTestPasses() {
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestEmulateNarrowTypePass();
- mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
More information about the Mlir-commits
mailing list