[Mlir-commits] [mlir] [MLIR][Transforms] add eliminate-explicit-rounding pass (PR #93443)
Ivy Zhang
llvmlistbot at llvm.org
Mon May 27 00:31:12 PDT 2024
https://github.com/crazydemo updated https://github.com/llvm/llvm-project/pull/93443
>From c4dd5ad49f64f58aa46cd1d241fab0ffa5f3b553 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 9 May 2024 14:36:51 +0800
Subject: [PATCH 1/6] add canonicalize-f32-promotion pass
---
.../mlir/Dialect/Math/Transforms/Passes.h | 1 +
.../mlir/Dialect/Math/Transforms/Passes.td | 43 +++++++++++
.../Dialect/Math/Transforms/CMakeLists.txt | 1 +
.../Transforms/CanonicalizeF32Promotion.cpp | 73 +++++++++++++++++++
.../Math/canonicalize-f32-promotion.mlir | 56 ++++++++++++++
5 files changed, 174 insertions(+)
create mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
create mode 100644 mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..f150ff6f944d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
+#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..538dcbfbe7f77 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,47 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
+def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
+ let summary = "Eliminate redundant truncf/extf pairs";
+ let description = [{
+ `legalize-to-f32` pass does f32 promotion for every op belonging to the
+ illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
+ will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
+ ops.
+
+ This pass is to eliminate the redundant truncf/extf pairs.
+
+ Example:
+
+ ```mlir
+ // the initial func
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = math.absf %arg0 : vector<32xbf16>
+ %1 = math.sin %0 : vector<32xbf16>
+ return %1 : vector<32xbf16>
+ }
+ // after legalize-to-f32
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+ %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+ %4 = math.sin %3 : vector<32xf32>
+ %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+ return %5 : vector<32xbf16>
+ }
+ // after canonicalize-f32-promotion
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = math.sin %1 : vector<32xf32>
+ %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+ return %3 : vector<32xbf16>
+ }
+ ```
+
+ }];
+ let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb5271..0d39d14925d23 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
+ CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
new file mode 100644
index 0000000000000..bfff17df8d7d4
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,73 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+ : OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
+ if (auto truncinput = innertruncop.getOperand()) {
+ auto outter_type = op.getType();
+ auto intermediate_type = innertruncop.getType();
+ auto inner_type = truncinput.getType();
+ if (outter_type.isa<ShapedType>()) {
+ outter_type = op.getType().cast<ShapedType>().getElementType();
+ intermediate_type =
+ innertruncop.getType().cast<ShapedType>().getElementType();
+ inner_type = truncinput.getType().cast<ShapedType>().getElementType();
+ }
+ if (outter_type.isF32() &&
+ (intermediate_type.isF16() || intermediate_type.isBF16()) &&
+ inner_type.isF32()) {
+ rewriter.replaceOp(op, {truncinput});
+ }
+ } else
+ return failure();
+ } else
+ return failure();
+ return success();
+ }
+};
+
+struct MathCanonicalizeF32Promotion final
+ : math::impl::MathCanonicalizeF32PromotionBase<
+ MathCanonicalizeF32Promotion> {
+ using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
+ FrozenRewritePatternSet patternSet(std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+ signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
new file mode 100644
index 0000000000000..7aad7889e2bf5
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+ %0 = math.absf %arg0 : bf16
+ %1 = math.sin %0 : bf16
+ return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to f16
+ %1 = arith.extf %0 : f16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to bf16
+ %1 = arith.extf %0 : bf16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xf16>
+ %1 = math.sin %0 : vector<32x32x32xf16>
+ return %1 : vector<32x32x32xf16>
+}
>From 02be4d6dedc81e9e5ace44829f388e36e52e0278 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 10 May 2024 11:09:31 +0800
Subject: [PATCH 2/6] add branch case
---
.../mlir/Dialect/Math/Transforms/Passes.td | 6 +++++-
.../Transforms/CanonicalizeF32Promotion.cpp | 3 +--
.../Math/canonicalize-f32-promotion.mlir | 18 ++++++++++++++++++
3 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 538dcbfbe7f77..5bf5eb45f921a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,7 +44,11 @@ def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
ops.
- This pass is to eliminate the redundant truncf/extf pairs.
+ This pass is to eliminate the redundant truncf/extf pairs to improve
+ performance.
+
+ However, this pass may introduce numerical difference as the `f32->bf16` rounding
+ is eliminated.
Example:
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
index bfff17df8d7d4..b9b43a0887f14 100644
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -1,5 +1,4 @@
-//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs
-//----------===//
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
index 7aad7889e2bf5..127eece98cf79 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -54,3 +54,21 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
%1 = math.sin %0 : vector<32x32x32xf16>
return %1 : vector<32x32x32xf16>
}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ %2 = math.cos %0 : vector<32x32x32xbf16>
+ %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+ return %3 : vector<32x32x32xbf16>
+}
>From 07ca29dbe48d010a36fdab154687547f26a6ead5 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 17 May 2024 14:21:38 +0800
Subject: [PATCH 3/6] use single walk rather than greedy rewrite
---
.../Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
index b9b43a0887f14..8257ddb5c2efc 100644
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -64,7 +64,12 @@ struct MathCanonicalizeF32Promotion final
RewritePatternSet patterns(&getContext());
patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
- if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+ SmallVector<Operation *> ops;
+ getOperation()->walk([&](Operation *op) {
+ if (isa<arith::ExtFOp>(op))
+ ops.push_back(op);
+ });
+ if (failed(applyOpPatternsAndFold(ops, patternSet)))
signalPassFailure();
}
};
>From 5152f89609f50c3fea755391aec33a2a2bc891da Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 10:32:13 +0800
Subject: [PATCH 4/6] adjust test case
---
.../Dialect/Math/canonicalize-f32-promotion.mlir | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
index 127eece98cf79..5ed189b0033b3 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -math-canonicalize-f32-promotion | FileCheck %s
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
@@ -59,12 +59,11 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
-// CHECK: [[ADDF:%.+]] = arith.addf
-// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
%0 = math.absf %arg0 : vector<32x32x32xbf16>
%1 = math.sin %0 : vector<32x32x32xbf16>
>From f6e310cda6e131843f519363323e60b7bbd18347 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 14:42:14 +0800
Subject: [PATCH 5/6] do cast elimination in transforms with eliminatable attr
---
.../mlir/Dialect/Math/Transforms/Passes.h | 1 -
.../mlir/Dialect/Math/Transforms/Passes.td | 47 ----------
mlir/include/mlir/Transforms/Passes.h | 5 ++
mlir/include/mlir/Transforms/Passes.td | 48 +++++++++++
.../Transforms/EmulateUnsupportedFloats.cpp | 11 ++-
.../Dialect/Math/Transforms/CMakeLists.txt | 1 -
.../Transforms/CanonicalizeF32Promotion.cpp | 77 -----------------
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 ++-
mlir/lib/Transforms/CMakeLists.txt | 3 +
.../Transforms/EliminateExplicitRounding.cpp | 85 +++++++++++++++++++
.../eliminate-explicit-rounding.mlir} | 2 +-
11 files changed, 158 insertions(+), 133 deletions(-)
delete mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
create mode 100644 mlir/lib/Transforms/EliminateExplicitRounding.cpp
rename mlir/test/{Dialect/Math/canonicalize-f32-promotion.mlir => Transforms/eliminate-explicit-rounding.mlir} (98%)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f150ff6f944d2..e2c513047c77a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,7 +17,6 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
-#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 5bf5eb45f921a..e870e714bfda5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,51 +36,4 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
-def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
- let summary = "Eliminate redundant truncf/extf pairs";
- let description = [{
- `legalize-to-f32` pass does f32 promotion for every op belonging to the
- illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
- will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
- ops.
-
- This pass is to eliminate the redundant truncf/extf pairs to improve
- performance.
-
- However, this pass may introduce numerical difference as the `f32->bf16` rounding
- is eliminated.
-
- Example:
-
- ```mlir
- // the initial func
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = math.absf %arg0 : vector<32xbf16>
- %1 = math.sin %0 : vector<32xbf16>
- return %1 : vector<32xbf16>
- }
- // after legalize-to-f32
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
- %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
- %4 = math.sin %3 : vector<32xf32>
- %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
- return %5 : vector<32xbf16>
- }
- // after canonicalize-f32-promotion
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = math.sin %1 : vector<32xf32>
- %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
- return %3 : vector<32xbf16>
- }
- ```
-
- }];
- let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
-}
-
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..c618fff9a8040 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -44,6 +44,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_ELIMINATEEXPLICITROUNDING
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -137,6 +138,10 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);
+/// Create eliminate-explicit-rounding pass, which eliminates the redundant
+/// truncf/extf pairs to improve performance.
+std::unique_ptr<Pass> createEliminateExplicitRoundingPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..1539bda02ac60 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -569,4 +569,52 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
+def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
+ let summary = "Eliminate redundant truncf/extf pairs";
+ let description = [{
+ `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion for every op belonging to the
+ illegal op list. Once there are some consecutive illegal ops, these passes
+ will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
+ ops.
+
+ This pass is to eliminate the redundant truncf/extf pairs to improve
+ performance.
+
+ However, this pass may introduce numerical difference as the `f32->bf16` rounding
+ is eliminated.
+
+ Example:
+
+ ```mlir
+ // the initial func
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = math.absf %arg0 : vector<32xbf16>
+ %1 = math.sin %0 : vector<32xbf16>
+ return %1 : vector<32xbf16>
+ }
+ // after legalize-to-f32
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+ %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+ %4 = math.sin %3 : vector<32xf32>
+ %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+ return %5 : vector<32xbf16>
+ }
+ // after canonicalize-f32-promotion
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = math.sin %1 : vector<32xf32>
+ %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+ return %3 : vector<32xbf16>
+ }
+ ```
+
+ }];
+ let constructor = "mlir::createEliminateExplicitRoundingPass()";
+}
+
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 4a50da3513f99..9cbb3884659ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType)
- res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ if (oldType != newType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
+ res = truncFOp->getResults().front();
+ }
}
rewriter.replaceOp(op, newResults);
}
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp->setAttr("eliminatable", b.getBoolAttr(true));
+ return extFOp;
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 0d39d14925d23..2a5b4fbcb5271 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
- CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
deleted file mode 100644
index 8257ddb5c2efc..0000000000000
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements removing redundant extf/truncf pairs inserted from
-// LegalizeToF32.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-
-namespace {
-
-struct CanonicalizeF32PromotionRewritePattern final
- : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(arith::ExtFOp op,
- PatternRewriter &rewriter) const final {
- if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
- if (auto truncinput = innertruncop.getOperand()) {
- auto outter_type = op.getType();
- auto intermediate_type = innertruncop.getType();
- auto inner_type = truncinput.getType();
- if (outter_type.isa<ShapedType>()) {
- outter_type = op.getType().cast<ShapedType>().getElementType();
- intermediate_type =
- innertruncop.getType().cast<ShapedType>().getElementType();
- inner_type = truncinput.getType().cast<ShapedType>().getElementType();
- }
- if (outter_type.isF32() &&
- (intermediate_type.isF16() || intermediate_type.isBF16()) &&
- inner_type.isF32()) {
- rewriter.replaceOp(op, {truncinput});
- }
- } else
- return failure();
- } else
- return failure();
- return success();
- }
-};
-
-struct MathCanonicalizeF32Promotion final
- : math::impl::MathCanonicalizeF32PromotionBase<
- MathCanonicalizeF32Promotion> {
- using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
- FrozenRewritePatternSet patternSet(std::move(patterns));
- SmallVector<Operation *> ops;
- getOperation()->walk([&](Operation *op) {
- if (isa<arith::ExtFOp>(op))
- ops.push_back(op);
- });
- if (failed(applyOpPatternsAndFold(ops, patternSet)))
- signalPassFailure();
- }
-};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..da049602bc909 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp->setAttr("eliminatable", b.getBoolAttr(true));
+ return extFOp;
});
}
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType)
- result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ if (newType != origType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
+ result = truncFOp->getResults().front();
+ }
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..131ee00fd7235 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_library(MLIRTransforms
SymbolPrivatize.cpp
TopologicalSort.cpp
ViewOpGraph.cpp
+ EliminateExplicitRounding.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
@@ -38,4 +39,6 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
+ MLIRArithDialect
+ MLIRMathDialect
)
diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
new file mode 100644
index 0000000000000..ae91a1ba0f24a
--- /dev/null
+++ b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
@@ -0,0 +1,85 @@
+//===- EliminateExplicitRounding.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32 and EmulateUnsupportedFloats.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+// #include "mlir/IR/Types.h"
+// #include "mlir/IR/Builders.h"
+// #include "mlir/IR/BuiltinOps.h"
+// #include "mlir/IR/Region.h"
+// #include "mlir/Pass/Pass.h"
+// #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+struct EliminateExplicitRoundingRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  /// Replaces `extf(truncf(x))` with `x` when both casts carry a true
+  /// `eliminatable` attribute and `x` already has the extf's f32 result type.
+  LogicalResult matchAndRewrite(arith::ExtFOp extfop,
+                                PatternRewriter &rewriter) const final {
+    auto isEliminatable = [](Operation *op) {
+      auto attr = op->getAttrOfType<BoolAttr>("eliminatable");
+      return attr && attr.getValue();
+    };
+    // The extf must consume a truncf result, and both must be tagged.
+    auto truncfop = extfop.getOperand().getDefiningOp<arith::TruncFOp>();
+    if (!truncfop || !isEliminatable(extfop) || !isEliminatable(truncfop))
+      return failure();
+    // The pair is a round trip only if it starts and ends at the same
+    // f32-based type. Otherwise report failure: returning success() without
+    // changing the IR breaks the rewrite-pattern contract.
+    Value input = truncfop.getOperand();
+    if (input.getType() != extfop.getType() ||
+        !getElementTypeOrSelf(input.getType()).isF32())
+      return failure();
+    rewriter.replaceOp(extfop, input);
+    return success();
+  }
+};
+
+struct EliminateExplicitRounding final
+    : impl::EliminateExplicitRoundingBase<EliminateExplicitRounding> {
+  using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase;
+  /// Gathers every `arith.extf` nested under the pass root and hands the list
+  /// to `applyOpPatternsAndFold` together with the elimination pattern.
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet elimination(context);
+    elimination.insert<EliminateExplicitRoundingRewritePattern>(context);
+    FrozenRewritePatternSet frozen(std::move(elimination));
+    SmallVector<Operation *> extfOps;
+    getOperation()->walk(
+        [&extfOps](arith::ExtFOp extf) { extfOps.push_back(extf); });
+    if (failed(applyOpPatternsAndFold(extfOps, frozen)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createEliminateExplicitRoundingPass() {
+ return std::make_unique<EliminateExplicitRounding>(); // named as the pass `constructor` in Passes.td
+}
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Transforms/eliminate-explicit-rounding.mlir
similarity index 98%
rename from mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
rename to mlir/test/Transforms/eliminate-explicit-rounding.mlir
index 5ed189b0033b3..2f7765a8fe270 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Transforms/eliminate-explicit-rounding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -math-canonicalize-f32-promotion | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -eliminate-explicit-rounding | FileCheck %s
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
>From cbc176acfbe9b27661b6031609cc39f7392e52ab Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 15:24:32 +0800
Subject: [PATCH 6/6] fix wording
---
mlir/include/mlir/Transforms/Passes.td | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1539bda02ac60..a99eca2a993cb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -570,14 +570,14 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
}
def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
- let summary = "Eliminate redundant truncf/extf pairs";
+ let summary = "Eliminate the intermediate truncf/extf pairs";
let description = [{
- `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion for every op belonging to the
- illegal op list. Once there are some consecutive illegal ops, these passes
- will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
- ops.
+ `legalize-to-f32` and `arith-emulate-unsupported-floats` passes do f32 promotion
+ for every op belonging to the illegal op list. Once there are some consecutive
+ illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs
+ between the illegal ops.
- This pass is to eliminate the redundant truncf/extf pairs to improve
+ This pass is to eliminate the intermediate truncf/extf pairs to improve
performance.
However, this pass may introduce numerical difference as the `f32->bf16` rounding
@@ -602,7 +602,7 @@ def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
%5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
return %5 : vector<32xbf16>
}
- // after canonicalize-f32-promotion
+ // after eliminate-explicit-rounding
func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
%0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
%1 = math.absf %0 : vector<32xf32>
More information about the Mlir-commits
mailing list