[Mlir-commits] [mlir] [MLIR][Transforms] add eliminate-explicit-rounding pass (PR #93443)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 27 00:16:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Ivy Zhang (crazydemo)

<details>
<summary>Changes</summary>

This PR adds an `eliminate-explicit-rounding` pass to `Transforms`. The pass eliminates redundant `arith.truncf`/`arith.extf` pairs to improve performance. However, it may introduce numerical differences, because the intermediate f32->bf16 (or f32->f16) rounding is no longer performed.

In addition, the `LegalizeToF32` and `EmulateUnsupportedFloats` passes now attach an `eliminatable = true` attribute to every `truncf`/`extf` op they insert. The `eliminate-explicit-rounding` pass only eliminates rounding pairs in which both ops carry `eliminatable = true`, which means:
1. user-defined `truncf`/`extf` op pairs are not removed unless both ops are labeled with `eliminatable = true`;
2. `truncf`/`extf` op pairs created by these type conversions are removed by default.




---
Full diff: https://github.com/llvm/llvm-project/pull/93443.diff


7 Files Affected:

- (modified) mlir/include/mlir/Transforms/Passes.h (+5) 
- (modified) mlir/include/mlir/Transforms/Passes.td (+48) 
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+8-3) 
- (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+8-3) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+3) 
- (added) mlir/lib/Transforms/EliminateExplicitRounding.cpp (+85) 
- (added) mlir/test/Transforms/eliminate-explicit-rounding.mlir (+73) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..c618fff9a8040 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -44,6 +44,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_ELIMINATEEXPLICITROUNDING
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
@@ -137,6 +138,10 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
     int maxIterations = 10);
 
+/// Creates the eliminate-explicit-rounding pass, which removes redundant
+/// `arith.truncf`/`arith.extf` pairs (both marked `eliminatable = true`) to
+/// improve performance.
+std::unique_ptr<Pass> createEliminateExplicitRoundingPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..1539bda02ac60 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -569,4 +569,52 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
+  let summary = "Eliminate redundant truncf/extf pairs";
+  let description = [{
+    The `math-legalize-to-f32` and `arith-emulate-unsupported-floats` passes
+    promote every op in their illegal-op list to f32. When several such ops are
+    chained, these passes insert redundant `arith.truncf`/`arith.extf` pairs
+    between them.
+
+    This pass removes those redundant pairs to improve performance. For a pair
+    `arith.extf(arith.truncf(%x))`, the `arith.extf` result is replaced by `%x`
+    only when both ops carry the attribute `eliminatable = true` (which the two
+    passes above attach to the ops they create) and `%x` has exactly the
+    result type of the `arith.extf`, with an f32 element type. User-written
+    pairs without the attribute are left untouched.
+
+    Note that this pass may change numerical results, because the intermediate
+    `f32->bf16` (or `f32->f16`) rounding is no longer performed.
+
+    Example:
+
+    ```mlir
+    // the initial func
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = math.absf %arg0 : vector<32xbf16>
+      %1 = math.sin %0 : vector<32xbf16>
+      return %1 : vector<32xbf16>
+    }
+    // after math-legalize-to-f32
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = arith.extf %arg0 {eliminatable = true} : vector<32xbf16> to vector<32xf32>
+      %1 = math.absf %0 : vector<32xf32>
+      %2 = arith.truncf %1 {eliminatable = true} : vector<32xf32> to vector<32xbf16>
+      %3 = arith.extf %2 {eliminatable = true} : vector<32xbf16> to vector<32xf32>
+      %4 = math.sin %3 : vector<32xf32>
+      %5 = arith.truncf %4 {eliminatable = true} : vector<32xf32> to vector<32xbf16>
+      return %5 : vector<32xbf16>
+    }
+    // after eliminate-explicit-rounding
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = arith.extf %arg0 {eliminatable = true} : vector<32xbf16> to vector<32xf32>
+      %1 = math.absf %0 : vector<32xf32>
+      %2 = math.sin %1 : vector<32xf32>
+      %3 = arith.truncf %2 {eliminatable = true} : vector<32xf32> to vector<32xbf16>
+      return %3 : vector<32xbf16>
+    }
+    ```
+  }];
+  let constructor = "mlir::createEliminateExplicitRoundingPass()";
+}
+
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 4a50da3513f99..9cbb3884659ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
   SmallVector<Value> newResults(expandedOp->getResults());
   for (auto [res, oldType, newType] : llvm::zip_equal(
            MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
-    if (oldType != newType)
-      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+    if (oldType != newType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+      truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
+      res = truncFOp->getResults().front();
+    }
   }
   rewriter.replaceOp(op, newResults);
 }
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
   });
   converter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        return b.create<arith::ExtFOp>(loc, target, input);
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp->setAttr("eliminatable", b.getBoolAttr(true));
+        return extFOp;
       });
 }
 
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..da049602bc909 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
   });
   typeConverter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        return b.create<arith::ExtFOp>(loc, target, input);
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp->setAttr("eliminatable", b.getBoolAttr(true));
+        return extFOp;
       });
 }
 
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
   SmallVector<Value> results = (*legalized)->getResults();
   for (auto [result, newType, origType] : llvm::zip_equal(
            results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType)
-      result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+    if (newType != origType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+      truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
+      result = truncFOp->getResults().front();
+    }
   }
   rewriter.replaceOp(op, results);
   return success();
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..131ee00fd7235 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_library(MLIRTransforms
   SymbolPrivatize.cpp
   TopologicalSort.cpp
   ViewOpGraph.cpp
+  EliminateExplicitRounding.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
@@ -38,4 +39,6 @@ add_mlir_library(MLIRTransforms
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
+  MLIRArithDialect
+  MLIRMathDialect
   )
diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
new file mode 100644
index 0000000000000..ae91a1ba0f24a
--- /dev/null
+++ b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
@@ -0,0 +1,85 @@
+//===- EliminateExplicitRounding.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32 and EmulateUnsupportedFloats.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+// #include "mlir/IR/Types.h"
+// #include "mlir/IR/Builders.h"
+// #include "mlir/IR/BuiltinOps.h"
+// #include "mlir/IR/Region.h"
+// #include "mlir/Pass/Pass.h"
+// #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Folds an `arith.extf(arith.truncf(%x))` round trip back to `%x` when both
+/// ops carry `eliminatable = true` and `%x` already has the f32-element type
+/// that the extf produces. Dropping the pair skips the intermediate rounding,
+/// so results may differ numerically from the unoptimized program.
+struct EliminateExplicitRoundingRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ExtFOp extfop,
+                                PatternRewriter &rewriter) const final {
+    // Only touch extf ops explicitly marked as removable.
+    if (!isEliminatable(extfop))
+      return failure();
+
+    // The extf must consume the result of an eliminatable truncf.
+    auto truncfop = extfop.getOperand().getDefiningOp<arith::TruncFOp>();
+    if (!truncfop || !isEliminatable(truncfop))
+      return failure();
+
+    // Only fold when the round trip starts and ends at the same f32(-element)
+    // type; otherwise the truncf input cannot stand in for the extf result.
+    // Returning failure() here (rather than success()) is required: the
+    // greedy driver treats success() as "IR changed".
+    Value input = truncfop.getOperand();
+    Type inTy = input.getType();
+    if (inTy != extfop.getType() || !getElementTypeOrSelf(inTy).isF32())
+      return failure();
+
+    rewriter.replaceOp(extfop, input);
+    return success();
+  }
+
+private:
+  /// Returns true if `op` has a BoolAttr named `eliminatable` set to true.
+  static bool isEliminatable(Operation *op) {
+    auto attr = op->getAttrOfType<BoolAttr>("eliminatable");
+    return attr && attr.getValue();
+  }
+};
+
+/// Pass driver: gathers every `arith.extf` nested under the root operation and
+/// applies EliminateExplicitRoundingRewritePattern to exactly those ops.
+struct EliminateExplicitRounding final
+    : impl::EliminateExplicitRoundingBase<EliminateExplicitRounding> {
+  using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Seed the worklist with all extf ops; the pattern only anchors on them.
+    SmallVector<Operation *> extfOps;
+    getOperation()->walk([&](arith::ExtFOp extf) {
+      extfOps.push_back(extf.getOperation());
+    });
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<EliminateExplicitRoundingRewritePattern>(ctx);
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    // A non-converging rewrite is reported as a pass failure.
+    if (failed(applyOpPatternsAndFold(extfOps, frozenPatterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Returns a freshly constructed eliminate-explicit-rounding pass.
+std::unique_ptr<Pass> mlir::createEliminateExplicitRoundingPass() {
+  auto pass = std::make_unique<EliminateExplicitRounding>();
+  return pass;
+}
diff --git a/mlir/test/Transforms/eliminate-explicit-rounding.mlir b/mlir/test/Transforms/eliminate-explicit-rounding.mlir
new file mode 100644
index 0000000000000..2f7765a8fe270
--- /dev/null
+++ b/mlir/test/Transforms/eliminate-explicit-rounding.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -eliminate-explicit-rounding | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+  %0 = math.absf %arg0 : bf16
+  %1 = math.sin %0 : bf16
+  return %1 : bf16
+}
+
+// User-written truncf/extf pairs carry no `eliminatable` attribute and must
+// be preserved.
+// CHECK-LABEL: @keepcastoncastf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ARG0]] : f32 to f16
+// CHECK: [[EXTF:%.+]] = arith.extf [[TRUNCF]] : f16 to f32
+// CHECK: return [[EXTF]] : f32
+func.func @keepcastoncastf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to f16
+  %1 = arith.extf %0 : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @keepcastoncastbf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ARG0]] : f32 to bf16
+// CHECK: [[EXTF:%.+]] = arith.extf [[TRUNCF]] : bf16 to f32
+// CHECK: return [[EXTF]] : f32
+func.func @keepcastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  %1 = arith.extf %0 : bf16 to f32
+  return %1 : f32
+}
+
+// User-written pairs explicitly labeled `eliminatable = true` are removed.
+// CHECK-LABEL: @eliminatelabeledcastoncastbf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf
+// CHECK: return [[ARG0]] : f32
+func.func @eliminatelabeledcastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 {eliminatable = true} : f32 to bf16
+  %1 = arith.extf %0 {eliminatable = true} : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xf16>
+  %1 = math.sin %0 : vector<32x32x32xf16>
+  return %1 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  %2 = math.cos %0 : vector<32x32x32xbf16>
+  %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+  return %3 : vector<32x32x32xbf16>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/93443


More information about the Mlir-commits mailing list