[Mlir-commits] [mlir] [mlir][math] Add `clampf` and clean math `ExpandOps` (PR #151153)

Fabian Mora llvmlistbot at llvm.org
Tue Jul 29 06:54:11 PDT 2025


https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/151153

This patch adds the `clampf` operation to the math dialect. The semantics of the op are defined as:
```
clampf(x, min_v, max_v) = min(max(x, min_v), max_v)
```

The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, [__saturatef](https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__INTRINSIC__SINGLE.html#group__cuda__math__intrinsic__single_1ga2c84f08e0db7117a14509d21c3aec04e) in NVIDIA GPUs, or `__builtin_amdgcn_fmed3f` in AMD GPUs.

This patch also removes `test-expand-math` in favor of `math-expand-ops`.
Finally, it removes individual expansion population API calls like `populateExpandCoshPattern` in favor of:
```C++
void populateExpansionPatterns(RewritePatternSet &patterns,
                               ArrayRef<StringRef> opMnemonics = {});
```

>From 7191a5908889d3785ded0c69c7d6aea5a7696291 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 29 Jul 2025 13:44:01 +0000
Subject: [PATCH] add clampf

---
 mlir/include/mlir/Dialect/Math/IR/MathOps.td  |  31 ++++
 .../mlir/Dialect/Math/Transforms/Passes.h     |  26 ++--
 .../mlir/Dialect/Math/Transforms/Passes.td    |  20 +++
 .../Dialect/Math/Transforms/CMakeLists.txt    |   2 +-
 .../{ExpandPatterns.cpp => ExpandOps.cpp}     | 139 ++++++++++--------
 mlir/test/Dialect/Math/expand-math.mlir       |  35 ++++-
 mlir/test/Dialect/Math/ops.mlir               |  15 +-
 mlir/test/lib/Dialect/Math/CMakeLists.txt     |   1 -
 mlir/test/lib/Dialect/Math/TestExpandMath.cpp |  62 --------
 .../mlir-runner/test-expand-math-approx.mlir  |   2 +-
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 -
 11 files changed, 188 insertions(+), 147 deletions(-)
 rename mlir/lib/Dialect/Math/Transforms/{ExpandPatterns.cpp => ExpandOps.cpp} (89%)
 delete mode 100644 mlir/test/lib/Dialect/Math/TestExpandMath.cpp

diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 56370388dea87..cfd8c4b8f11f7 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -352,6 +352,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ClampFOp
+//===----------------------------------------------------------------------===//
+
+def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
+  let summary = "floating point clamping operation";
+  let description = [{
+    The `clampf` operation takes three operands and returns one result, all of
+    which must have the same type. Operands must be of floating point type
+    (i.e., scalar, tensor or vector).
+
+    The semantics of the operation are described by:
+    ```
+      clampf(value, min, max) = minf(maxf(value, min), max)
+    ```
+
+    Example:
+
+    ```mlir
+    %d = math.clampf %value to [%min, %max] : f64
+    ```
+  }];
+  let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max,
+      DefaultValuedAttr<Arith_FastMathAttr,
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let assemblyFormat = [{
+    $value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)?
+    attr-dict `:` type($result)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CopySignOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index c0fe5d3be448a..b3abbf728a3c6 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -23,22 +23,16 @@ class ConversionTarget;
 class RewritePatternSet;
 class TypeConverter;
 
-void populateExpandCtlzPattern(RewritePatternSet &patterns);
-void populateExpandTanPattern(RewritePatternSet &patterns);
-void populateExpandSinhPattern(RewritePatternSet &patterns);
-void populateExpandCoshPattern(RewritePatternSet &patterns);
-void populateExpandTanhPattern(RewritePatternSet &patterns);
-void populateExpandAsinhPattern(RewritePatternSet &patterns);
-void populateExpandAcoshPattern(RewritePatternSet &patterns);
-void populateExpandAtanhPattern(RewritePatternSet &patterns);
-void populateExpandFmaFPattern(RewritePatternSet &patterns);
-void populateExpandCeilFPattern(RewritePatternSet &patterns);
-void populateExpandExp2FPattern(RewritePatternSet &patterns);
-void populateExpandPowFPattern(RewritePatternSet &patterns);
-void populateExpandFPowIPattern(RewritePatternSet &patterns);
-void populateExpandRoundFPattern(RewritePatternSet &patterns);
-void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
-void populateExpandRsqrtPattern(RewritePatternSet &patterns);
+namespace math {
+/// Adds patterns to expand math ops into more fundamental operations (e.g.
+/// hyperbolic functions are expanded into expressions using `exp`). If
+/// `opMnemonics` is empty, all available patterns are added; otherwise only
+/// those for ops whose mnemonic (op name without the `math.` prefix, e.g.
+/// `tanh`) appears in `opMnemonics` are added. Unknown mnemonics are ignored.
+void populateExpansionPatterns(RewritePatternSet &patterns,
+                               ArrayRef<StringRef> opMnemonics = {});
+} // namespace math
+
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index a84c89020d4f3..4d415aeac8f58 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 
+def MathExpandOpsPass : Pass<"math-expand-ops"> {
+  let summary = "Expand math operations.";
+  let description = [{
+    Expands some math operations into more fundamental operations so that
+    they can subsequently be lowered through those. For example, hyperbolic
+    functions are rewritten as expressions in terms of `exp`.
+
+    The `ops` parameter can be used to apply only a subset of all the
+    available expansions; entries must be op mnemonics without the `math.`
+    prefix. For example, `ops=sinh,acosh` will expand only `math.sinh` and
+    `math.acosh` operations. If the list is empty, then all expansions are
+    applied; unknown mnemonics are ignored.
+  }];
+  let dependentDialects = ["arith::ArithDialect"];
+  let options = [
+    ListOption<"opMnemonics", "ops", "std::string",
+               "Operations to expand.">
+  ];
+}
+
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index e1c0c2410c126..d37a056e8e158 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
-  ExpandPatterns.cpp
+  ExpandOps.cpp
   ExtendToSupportedTypes.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
similarity index 89%
rename from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
rename to mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4a40a3055ed62..cd68039d0d964 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -13,14 +13,18 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXPANDOPSPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
 /// Create a float constant.
 static Value createFloatConst(Location loc, Type type, APFloat value,
                               OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
   return success();
 }
 
-void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
-  patterns.add(convertCtlzOp);
-}
-
-void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertSinhOp);
-}
-
-void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
-  patterns.add(convertCoshOp);
-}
-
-void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
-  patterns.add(convertTanOp);
-}
-
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertTanhOp);
-}
-
-void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertAsinhOp);
-}
-
-void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
-  patterns.add(convertAcoshOp);
-}
-
-void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertAtanhOp);
-}
-
-void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertFmaFOp);
-}
-
-void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertCeilOp);
-}
-
-void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
-  patterns.add(convertExp2fOp);
-}
-
-void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertPowfOp);
-}
-
-void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
-  patterns.add(convertFPowIOp);
-}
-
-void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertRoundOp);
+// Expand `clampf(v, lo, hi)` into `minimumf(maximumf(v, lo), hi)`.
+static LogicalResult convertClampfOp(math::ClampFOp op,
+                                     PatternRewriter &rewriter) {
+  auto maxOp = arith::MaximumFOp::create(rewriter, op.getLoc(), op.getValue(),
+                                         op.getMin(), op.getFastmath());
+  rewriter.replaceOpWithNewOp<arith::MinimumFOp>(op, maxOp, op.getMax(),
+                                                 op.getFastmath());
+  return success();
 }
 
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
-  patterns.add(convertRoundEvenOp);
+void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
+                                           ArrayRef<StringRef> opMnemonics) {
+  auto filter = [&](StringRef name) {
+    // Ideally this would be a static_assert and `consume_front` would take a
+    // twine, but neither is possible yet. TODO: augment `consume_front` and
+    // make `getDialectNamespace` use `std::string_view`.
+    assert("math" == MathDialect::getDialectNamespace());
+    name.consume_front("math.");
+    return opMnemonics.empty() || llvm::is_contained(opMnemonics, name);
+  };
+  if (filter(CountLeadingZerosOp::getOperationName()))
+    patterns.add(convertCtlzOp);
+  if (filter(SinhOp::getOperationName()))
+    patterns.add(convertSinhOp);
+  if (filter(CoshOp::getOperationName()))
+    patterns.add(convertCoshOp);
+  if (filter(TanOp::getOperationName()))
+    patterns.add(convertTanOp);
+  if (filter(TanhOp::getOperationName()))
+    patterns.add(convertTanhOp);
+  if (filter(AsinhOp::getOperationName()))
+    patterns.add(convertAsinhOp);
+  if (filter(AcoshOp::getOperationName()))
+    patterns.add(convertAcoshOp);
+  if (filter(AtanhOp::getOperationName()))
+    patterns.add(convertAtanhOp);
+  if (filter(FmaOp::getOperationName()))
+    patterns.add(convertFmaFOp);
+  if (filter(CeilOp::getOperationName()))
+    patterns.add(convertCeilOp);
+  if (filter(Exp2Op::getOperationName()))
+    patterns.add(convertExp2fOp);
+  if (filter(PowFOp::getOperationName()))
+    patterns.add(convertPowfOp);
+  if (filter(FPowIOp::getOperationName()))
+    patterns.add(convertFPowIOp);
+  if (filter(RoundOp::getOperationName()))
+    patterns.add(convertRoundOp);
+  if (filter(RoundEvenOp::getOperationName()))
+    patterns.add(convertRoundEvenOp);
+  if (filter(RsqrtOp::getOperationName()))
+    patterns.add(convertRsqrtOp);
+  if (filter(ClampFOp::getOperationName()))
+    patterns.add(convertClampfOp);
 }
 
-void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
-  patterns.add(convertRsqrtOp);
-}
+//===----------------------------------------------------------------------===//
+// MathExpandOpsPass pass
+//===----------------------------------------------------------------------===//
+namespace {
+struct MathExpandOpsPass final
+    : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
+  using MathExpandOpsPassBase::MathExpandOpsPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SmallVector<StringRef> mnemonics =
+        llvm::to_vector_of<StringRef>(opMnemonics);
+    math::populateExpansionPatterns(patterns, mnemonics);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1420acaa40d35..615c607efc3c3 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,7 +1,9 @@
-// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER
 
 // CHECK-LABEL: func @tanh
 func.func @tanh(%arg: f32) -> f32 {
+  // CHECK-FILTER-NOT: math.tanh
   %res = math.tanh %arg : f32
   return %res : f32
 }
@@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 {
 // CHECK-LABEL: func @vector_tanh
 func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
   // CHECK-NOT: math.tanh
+  // CHECK-FILTER-NOT: math.tanh
   %res = math.tanh %arg : vector<4xf32>
   return %res : vector<4xf32>
 }
@@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
 
 // CHECK-LABEL: func @tan
 func.func @tan(%arg: f32) -> f32 {
+  // CHECK-FILTER-NOT: math.tan
   %res = math.tan %arg : f32
   return %res : f32
 }
@@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 {
 
 // CHECK-LABEL: func @vector_tan
 func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
+  // CHECK-FILTER-NOT: math.tan
   %res = math.tan %arg : vector<4xf32>
   return %res : vector<4xf32>
 }
@@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
 // -----
 
 func.func @ctlz(%arg: i32) -> i32 {
+  // CHECK-FILTER: math.ctlz
   %res = math.ctlz %arg : i32
   return %res : i32
 }
@@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 {
 // -----
 
 func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  // CHECK-FILTER: math.ctlz
   %res = math.ctlz %arg : vector<4xi32>
   return %res : vector<4xi32>
 }
@@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 {
   // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
   // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
   // CHECK-NEXT:   return [[ADDF]]
+  // CHECK-FILTER: math.ceil
   %ret = math.ceil %a : f64
   return %ret : f64
 }
@@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 {
   // CHECK:         [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
   // CHECK:         [[EXP:%.+]]  = math.exp [[MULF]]
   // CHECK:         return [[EXP]]
+  // CHECK-FILTER: math.exp2
   %ret = math.exp2 %a : f64
   return %ret : f64
 }
@@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
   %a = math.rsqrt %arg : tensor<*xf32>
   return %a: tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL:    func.func @clampf_scalar_op
+// CHECK-SAME:     (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
+// CHECK:          %[[V0:.*]] = arith.maximumf %[[ARG]], %[[MIN]] : f16
+// CHECK:          %[[V1:.*]] = arith.minimumf %[[V0]], %[[MAX]] : f16
+// CHECK:          return %[[V1]] : f16
+
+func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
+  %a = math.clampf %arg to [%min, %max] : f16
+  return %a: f16
+}
+
+// CHECK-LABEL:    func.func @clampf_vector_op
+// CHECK-SAME:     (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
+// CHECK:          %[[V0:.*]] = arith.maximumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
+// CHECK:          %[[V1:.*]] = arith.minimumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK:          return %[[V1]] : vector<3x4xf32>
+
+func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{
+  %a = math.clampf %arg to [%min, %max] fastmath<fast> : vector<3x4xf32>
+  return %a: vector<3x4xf32>
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 8feadedd1860e..cb10fc4397ffc 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
 // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: func @atan(
@@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>)
   math.isnormal %t : tensor<4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @clampf(
+func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>,
+                  %as: f32, %ms: f32, %Ms: f32,
+                  %at: tensor<?xf80>, %mt: tensor<?xf80>, %Mt: tensor<?xf80>) {
+  // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath<fast> : vector<3x4xf32>
+  %rv = math.clampf %av to [%mv, %Mv] fastmath<fast> : vector<3x4xf32>
+  // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32
+  %rs = math.clampf %as to [%ms, %Ms] fastmath<none> : f32
+  // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor<?xf80>
+  %rt = math.clampf %at to [%mt, %Mt] : tensor<?xf80>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index 91e70d1785369..900dff3b5e9f1 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMathTestPasses
   TestAlgebraicSimplification.cpp
-  TestExpandMath.cpp
   TestPolynomialApproximation.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
deleted file mode 100644
index efc1acf8bb6cd..0000000000000
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-//===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains test passes for expanding math operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestExpandMathPass
-    : public PassWrapper<TestExpandMathPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
-
-  void runOnOperation() override;
-  StringRef getArgument() const final { return "test-expand-math"; }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
-  }
-  StringRef getDescription() const final { return "Test expanding math"; }
-};
-} // namespace
-
-void TestExpandMathPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  populateExpandCtlzPattern(patterns);
-  populateExpandExp2FPattern(patterns);
-  populateExpandTanPattern(patterns);
-  populateExpandSinhPattern(patterns);
-  populateExpandCoshPattern(patterns);
-  populateExpandTanhPattern(patterns);
-  populateExpandAsinhPattern(patterns);
-  populateExpandAcoshPattern(patterns);
-  populateExpandAtanhPattern(patterns);
-  populateExpandFmaFPattern(patterns);
-  populateExpandCeilFPattern(patterns);
-  populateExpandPowFPattern(patterns);
-  populateExpandFPowIPattern(patterns);
-  populateExpandRoundFPattern(patterns);
-  populateExpandRoundEvenPattern(patterns);
-  populateExpandRsqrtPattern(patterns);
-  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-}
-
-namespace mlir {
-namespace test {
-void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index b599c9d8435d4..3f9d3f2125e1a 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -1,4 +1,4 @@
-// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(math-expand-ops),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
 // RUN: | mlir-runner                                                      \
 // RUN:     -e main -entry-point-result=void -O0                               \
 // RUN:     -shared-libs=%mlir_c_runner_utils  \
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 14714c452503a..7b992b4ee029b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -98,7 +98,6 @@ void registerTestDiagnosticsMetadataPass();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestEmulateNarrowTypePass();
-void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
@@ -245,7 +244,6 @@ void registerTestPasses() {
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
   mlir::test::registerTestEmulateNarrowTypePass();
-  mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();



More information about the Mlir-commits mailing list