[Mlir-commits] [mlir] [MLIR][Transforms] add eliminate-explicit-rounding pass (PR #93443)
Ivy Zhang
llvmlistbot at llvm.org
Fri Jun 7 11:08:32 PDT 2024
https://github.com/crazydemo updated https://github.com/llvm/llvm-project/pull/93443
>From c4dd5ad49f64f58aa46cd1d241fab0ffa5f3b553 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 9 May 2024 14:36:51 +0800
Subject: [PATCH 01/22] add canonicalize-f32-promotion pass
---
.../mlir/Dialect/Math/Transforms/Passes.h | 1 +
.../mlir/Dialect/Math/Transforms/Passes.td | 43 +++++++++++
.../Dialect/Math/Transforms/CMakeLists.txt | 1 +
.../Transforms/CanonicalizeF32Promotion.cpp | 73 +++++++++++++++++++
.../Math/canonicalize-f32-promotion.mlir | 56 ++++++++++++++
5 files changed, 174 insertions(+)
create mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
create mode 100644 mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..f150ff6f944d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
+#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..538dcbfbe7f77 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,47 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
+def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
+ let summary = "Eliminate redundant truncf/extf pairs";
+ let description = [{
+ `legalize-to-f32` pass does f32 promotion for every op belonging to the
+ illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
+ will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
+ ops.
+
+ This pass is to eliminate the redundant truncf/extf pairs.
+
+ Example:
+
+ ```mlir
+ // the initial func
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = math.absf %arg0 : vector<32xbf16>
+ %1 = math.sin %0 : vector<32xbf16>
+ return %1 : vector<32xbf16>
+ }
+ // after legalize-to-f32
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+ %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+ %4 = math.sin %3 : vector<32xf32>
+ %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+ return %5 : vector<32xbf16>
+ }
+ // after canonicalize-f32-promotion
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = math.sin %1 : vector<32xf32>
+ %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+ return %3 : vector<32xbf16>
+ }
+ ```
+
+ }];
+ let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb5271..0d39d14925d23 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
+ CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
new file mode 100644
index 0000000000000..bfff17df8d7d4
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,73 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+ : OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
+ if (auto truncinput = innertruncop.getOperand()) {
+ auto outter_type = op.getType();
+ auto intermediate_type = innertruncop.getType();
+ auto inner_type = truncinput.getType();
+ if (outter_type.isa<ShapedType>()) {
+ outter_type = op.getType().cast<ShapedType>().getElementType();
+ intermediate_type =
+ innertruncop.getType().cast<ShapedType>().getElementType();
+ inner_type = truncinput.getType().cast<ShapedType>().getElementType();
+ }
+ if (outter_type.isF32() &&
+ (intermediate_type.isF16() || intermediate_type.isBF16()) &&
+ inner_type.isF32()) {
+ rewriter.replaceOp(op, {truncinput});
+ }
+ } else
+ return failure();
+ } else
+ return failure();
+ return success();
+ }
+};
+
+struct MathCanonicalizeF32Promotion final
+ : math::impl::MathCanonicalizeF32PromotionBase<
+ MathCanonicalizeF32Promotion> {
+ using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
+ FrozenRewritePatternSet patternSet(std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+ signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
new file mode 100644
index 0000000000000..7aad7889e2bf5
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+ %0 = math.absf %arg0 : bf16
+ %1 = math.sin %0 : bf16
+ return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to f16
+ %1 = arith.extf %0 : f16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to bf16
+ %1 = arith.extf %0 : bf16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xf16>
+ %1 = math.sin %0 : vector<32x32x32xf16>
+ return %1 : vector<32x32x32xf16>
+}
>From 02be4d6dedc81e9e5ace44829f388e36e52e0278 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 10 May 2024 11:09:31 +0800
Subject: [PATCH 02/22] add branch case
---
.../mlir/Dialect/Math/Transforms/Passes.td | 6 +++++-
.../Transforms/CanonicalizeF32Promotion.cpp | 3 +--
.../Math/canonicalize-f32-promotion.mlir | 18 ++++++++++++++++++
3 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 538dcbfbe7f77..5bf5eb45f921a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,7 +44,11 @@ def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
ops.
- This pass is to eliminate the redundant truncf/extf pairs.
+ This pass is to eliminate the redundant truncf/extf pairs to improve
+ performance.
+
+ However, this pass may introduce numerical difference as the `f32->bf16` rounding
+ is eliminated.
Example:
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
index bfff17df8d7d4..b9b43a0887f14 100644
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -1,5 +1,4 @@
-//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs
-//----------===//
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
index 7aad7889e2bf5..127eece98cf79 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -54,3 +54,21 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
%1 = math.sin %0 : vector<32x32x32xf16>
return %1 : vector<32x32x32xf16>
}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ %2 = math.cos %0 : vector<32x32x32xbf16>
+ %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+ return %3 : vector<32x32x32xbf16>
+}
>From 07ca29dbe48d010a36fdab154687547f26a6ead5 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 17 May 2024 14:21:38 +0800
Subject: [PATCH 03/22] use single walk rather than greedy rewrite
---
.../Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
index b9b43a0887f14..8257ddb5c2efc 100644
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -64,7 +64,12 @@ struct MathCanonicalizeF32Promotion final
RewritePatternSet patterns(&getContext());
patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
- if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+ SmallVector<Operation *> ops;
+ getOperation()->walk([&](Operation *op) {
+ if (isa<arith::ExtFOp>(op))
+ ops.push_back(op);
+ });
+ if (failed(applyOpPatternsAndFold(ops, patternSet)))
signalPassFailure();
}
};
>From 5152f89609f50c3fea755391aec33a2a2bc891da Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 10:32:13 +0800
Subject: [PATCH 04/22] adjust test case
---
.../Dialect/Math/canonicalize-f32-promotion.mlir | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
index 127eece98cf79..5ed189b0033b3 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -math-canonicalize-f32-promotion | FileCheck %s
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
@@ -59,12 +59,11 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
-// CHECK: [[ADDF:%.+]] = arith.addf
-// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
%0 = math.absf %arg0 : vector<32x32x32xbf16>
%1 = math.sin %0 : vector<32x32x32xbf16>
>From f6e310cda6e131843f519363323e60b7bbd18347 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 14:42:14 +0800
Subject: [PATCH 05/22] do cast elimination in transforms with eliminatable
attr
---
.../mlir/Dialect/Math/Transforms/Passes.h | 1 -
.../mlir/Dialect/Math/Transforms/Passes.td | 47 ----------
mlir/include/mlir/Transforms/Passes.h | 5 ++
mlir/include/mlir/Transforms/Passes.td | 48 +++++++++++
.../Transforms/EmulateUnsupportedFloats.cpp | 11 ++-
.../Dialect/Math/Transforms/CMakeLists.txt | 1 -
.../Transforms/CanonicalizeF32Promotion.cpp | 77 -----------------
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 ++-
mlir/lib/Transforms/CMakeLists.txt | 3 +
.../Transforms/EliminateExplicitRounding.cpp | 85 +++++++++++++++++++
.../eliminate-explicit-rounding.mlir} | 2 +-
11 files changed, 158 insertions(+), 133 deletions(-)
delete mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
create mode 100644 mlir/lib/Transforms/EliminateExplicitRounding.cpp
rename mlir/test/{Dialect/Math/canonicalize-f32-promotion.mlir => Transforms/eliminate-explicit-rounding.mlir} (98%)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f150ff6f944d2..e2c513047c77a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,7 +17,6 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
-#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 5bf5eb45f921a..e870e714bfda5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,51 +36,4 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
-def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
- let summary = "Eliminate redundant truncf/extf pairs";
- let description = [{
- `legalize-to-f32` pass does f32 promotion for every op belonging to the
- illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
- will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
- ops.
-
- This pass is to eliminate the redundant truncf/extf pairs to improve
- performance.
-
- However, this pass may introduce numerical difference as the `f32->bf16` rounding
- is eliminated.
-
- Example:
-
- ```mlir
- // the initial func
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = math.absf %arg0 : vector<32xbf16>
- %1 = math.sin %0 : vector<32xbf16>
- return %1 : vector<32xbf16>
- }
- // after legalize-to-f32
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
- %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
- %4 = math.sin %3 : vector<32xf32>
- %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
- return %5 : vector<32xbf16>
- }
- // after canonicalize-f32-promotion
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = math.sin %1 : vector<32xf32>
- %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
- return %3 : vector<32xbf16>
- }
- ```
-
- }];
- let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
-}
-
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..c618fff9a8040 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -44,6 +44,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_ELIMINATEEXPLICITROUNDING
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -137,6 +138,10 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);
+/// Create eliminate-explicit-rounding pass, which eliminates the redundant
+/// truncf/extf pairs to improve performance.
+std::unique_ptr<Pass> createEliminateExplicitRoundingPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..1539bda02ac60 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -569,4 +569,52 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
+def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
+ let summary = "Eliminate redundant truncf/extf pairs";
+ let description = [{
+ `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion for every op belonging to the
+ illegal op list. Once there are some consecutive illegal ops, these passes
+ will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
+ ops.
+
+ This pass is to eliminate the redundant truncf/extf pairs to improve
+ performance.
+
+ However, this pass may introduce numerical difference as the `f32->bf16` rounding
+ is eliminated.
+
+ Example:
+
+ ```mlir
+ // the initial func
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = math.absf %arg0 : vector<32xbf16>
+ %1 = math.sin %0 : vector<32xbf16>
+ return %1 : vector<32xbf16>
+ }
+ // after legalize-to-f32
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+ %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+ %4 = math.sin %3 : vector<32xf32>
+ %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+ return %5 : vector<32xbf16>
+ }
+ // after eliminate-explicit-rounding
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = math.sin %1 : vector<32xf32>
+ %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+ return %3 : vector<32xbf16>
+ }
+ ```
+
+ }];
+ let constructor = "mlir::createEliminateExplicitRoundingPass()";
+}
+
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 4a50da3513f99..9cbb3884659ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType)
- res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ if (oldType != newType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
+ res = truncFOp->getResults().front();
+ }
}
rewriter.replaceOp(op, newResults);
}
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp->setAttr("eliminatable", b.getBoolAttr(true));
+ return extFOp;
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 0d39d14925d23..2a5b4fbcb5271 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
- CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
deleted file mode 100644
index 8257ddb5c2efc..0000000000000
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements removing redundant extf/truncf pairs inserted from
-// LegalizeToF32.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-
-namespace {
-
-struct CanonicalizeF32PromotionRewritePattern final
- : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(arith::ExtFOp op,
- PatternRewriter &rewriter) const final {
- if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
- if (auto truncinput = innertruncop.getOperand()) {
- auto outter_type = op.getType();
- auto intermediate_type = innertruncop.getType();
- auto inner_type = truncinput.getType();
- if (outter_type.isa<ShapedType>()) {
- outter_type = op.getType().cast<ShapedType>().getElementType();
- intermediate_type =
- innertruncop.getType().cast<ShapedType>().getElementType();
- inner_type = truncinput.getType().cast<ShapedType>().getElementType();
- }
- if (outter_type.isF32() &&
- (intermediate_type.isF16() || intermediate_type.isBF16()) &&
- inner_type.isF32()) {
- rewriter.replaceOp(op, {truncinput});
- }
- } else
- return failure();
- } else
- return failure();
- return success();
- }
-};
-
-struct MathCanonicalizeF32Promotion final
- : math::impl::MathCanonicalizeF32PromotionBase<
- MathCanonicalizeF32Promotion> {
- using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
- FrozenRewritePatternSet patternSet(std::move(patterns));
- SmallVector<Operation *> ops;
- getOperation()->walk([&](Operation *op) {
- if (isa<arith::ExtFOp>(op))
- ops.push_back(op);
- });
- if (failed(applyOpPatternsAndFold(ops, patternSet)))
- signalPassFailure();
- }
-};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..da049602bc909 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp->setAttr("eliminatable", b.getBoolAttr(true));
+ return extFOp;
});
}
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType)
- result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ if (newType != origType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
+ result = truncFOp->getResults().front();
+ }
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..131ee00fd7235 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_library(MLIRTransforms
SymbolPrivatize.cpp
TopologicalSort.cpp
ViewOpGraph.cpp
+ EliminateExplicitRounding.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
@@ -38,4 +39,6 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
+ MLIRArithDialect
+ MLIRMathDialect
)
diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
new file mode 100644
index 0000000000000..ae91a1ba0f24a
--- /dev/null
+++ b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
@@ -0,0 +1,85 @@
+//===- EliminateExplicitRounding.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32 and EmulateUnsupportedFloats.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+// #include "mlir/IR/Types.h"
+// #include "mlir/IR/Builders.h"
+// #include "mlir/IR/BuiltinOps.h"
+// #include "mlir/IR/Region.h"
+// #include "mlir/Pass/Pass.h"
+// #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+struct EliminateExplicitRoundingRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  // Folds `extf(truncf(x))` to `x` when both casts are marked eliminatable.
+  LogicalResult matchAndRewrite(arith::ExtFOp extfop,
+                                PatternRewriter &rewriter) const final {
+    auto extfAttr = extfop->getAttrOfType<BoolAttr>("eliminatable");
+    if (!extfAttr || !extfAttr.getValue())
+      return failure();
+    // The extf operand must be produced by an `eliminatable` truncf.
+    auto truncfop = extfop.getOperand().getDefiningOp<arith::TruncFOp>();
+    if (!truncfop)
+      return failure();
+    auto truncfAttr = truncfop->getAttrOfType<BoolAttr>("eliminatable");
+    if (!truncfAttr || !truncfAttr.getValue())
+      return failure();
+    // Only a round trip back to the very same f32 type is foldable. Any other
+    // case must report failure: success() without an IR change breaks the
+    // pattern contract. The truncf is left for the driver to erase once dead.
+    Value input = truncfop.getOperand();
+    if (input.getType() != extfop.getType() ||
+        !getElementTypeOrSelf(input.getType()).isF32())
+      return failure();
+    rewriter.replaceOp(extfop, input);
+    return success();
+  }
+};
+
+struct EliminateExplicitRounding final
+    : impl::EliminateExplicitRoundingBase<EliminateExplicitRounding> {
+  using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase;
+  // Collects every arith.extf, runs the rounding-pair pattern on them, and
+  // signals pass failure when the rewrite driver reports failure.
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    SmallVector<Operation *> extfOps;
+    getOperation()->walk(
+        [&](arith::ExtFOp extf) { extfOps.push_back(extf.getOperation()); });
+    RewritePatternSet rewrites(ctx);
+    rewrites.add<EliminateExplicitRoundingRewritePattern>(ctx);
+    FrozenRewritePatternSet frozen(std::move(rewrites));
+    if (failed(applyOpPatternsAndFold(extfOps, frozen)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+/// Creates the eliminate-explicit-rounding pass (declared in Passes.h).
+std::unique_ptr<Pass> mlir::createEliminateExplicitRoundingPass() {
+ return std::make_unique<EliminateExplicitRounding>();
+}
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Transforms/eliminate-explicit-rounding.mlir
similarity index 98%
rename from mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
rename to mlir/test/Transforms/eliminate-explicit-rounding.mlir
index 5ed189b0033b3..2f7765a8fe270 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Transforms/eliminate-explicit-rounding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -math-canonicalize-f32-promotion | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -eliminate-explicit-rounding | FileCheck %s
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
>From cbc176acfbe9b27661b6031609cc39f7392e52ab Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 15:24:32 +0800
Subject: [PATCH 06/22] fix wording
---
mlir/include/mlir/Transforms/Passes.td | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1539bda02ac60..a99eca2a993cb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -570,14 +570,14 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
}
def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
- let summary = "Eliminate redundant truncf/extf pairs";
+ let summary = "Eliminate the intermidiate truncf/extf pairs";
let description = [{
- `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion for every op belonging to the
- illegal op list. Once there are some consecutive illegal ops, these passes
- will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
- ops.
+ `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion
+ for every op belonging to the illegal op list. Once there are some consecutive
+ illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs
+ between the illegal ops.
- This pass is to eliminate the redundant truncf/extf pairs to improve
+ This pass is to eliminate the intermidiate truncf/extf pairs to improve
performance.
However, this pass may introduce numerical difference as the `f32->bf16` rounding
@@ -602,7 +602,7 @@ def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
%5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
return %5 : vector<32xbf16>
}
- // after canonicalize-f32-promotion
+ // after eliminate-explicit-rounding
func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
%0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
%1 = math.absf %0 : vector<32xf32>
>From 2dcb687d5f95e88fe2380340ce63de225f21e175 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 15:38:09 +0800
Subject: [PATCH 07/22] fix test
---
.../Transforms/EliminateExplicitRounding.cpp | 16 ++++------
.../Arith/emulate-unsupported-floats.mlir | 32 +++++++++----------
2 files changed, 23 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
index ae91a1ba0f24a..4731b5a15f415 100644
--- a/mlir/lib/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Transforms/EliminateExplicitRounding.cpp
@@ -16,12 +16,6 @@
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-// #include "mlir/IR/Types.h"
-// #include "mlir/IR/Builders.h"
-// #include "mlir/IR/BuiltinOps.h"
-// #include "mlir/IR/Region.h"
-// #include "mlir/Pass/Pass.h"
-// #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
@@ -48,12 +42,16 @@ struct EliminateExplicitRoundingRewritePattern final
auto truncfAttr = truncfop->getAttrOfType<BoolAttr>("eliminatable");
if (!truncfAttr || (truncfAttr && !truncfAttr.getValue())) return failure();
- // check whether the the rounding pair's input and output data type are the same
+ // check whether the the rounding pair's input and output data type are the
+ // same Currently only consider to eliminate rounding pairs for (bf16 / f16
+ // <-> f32)
if (auto input = truncfop.getOperand()) {
auto inTy = input.getType();
auto outTy = extfop.getType();
- if (inTy == outTy && getElementTypeOrSelf(inTy).isF32()) {
- rewriter.replaceOp(extfop, {input});
+ auto shortTy = getElementTypeOrSelf(truncfop.getType());
+ if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
+ (shortTy.isF16() || shortTy.isBF16())) {
+ rewriter.replaceOp(extfop, {input});
}
}
return success();
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index a69ef131d8d47..76952297a5452 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {eliminatable = true} : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {eliminatable = true} : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] {eliminatable = true} : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] {eliminatable = true} : bf16 to f32
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] {eliminatable = true} : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] {eliminatable = true} : bf16 to f32
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] {eliminatable = true} : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] {eliminatable = true} : bf16 to f32
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
// CHECK: return [[RES]]
%p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// CHECK-LABEL: @memops
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] {eliminatable = true} : f8E4M3FNUZ to f32
// CHECK: memref.store [[V]]
// CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] {eliminatable = true} : f8E4M3FNUZ to f32
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] {eliminatable = true} : f32 to f8E4M3FNUZ
// CHECK: memref.store [[X]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -63,10 +63,10 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
-// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] {eliminatable = true} : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[RET:%.+]] = arith.extf [[B]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: return [[RET]]
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
%ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
>From 336e0eba4a671a1a51e7738f6031764023f73d55 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 16:53:27 +0800
Subject: [PATCH 08/22] move to arith dialect and add optional filter func
---
.../mlir/Dialect/Arith/Transforms/Passes.td | 46 ++++++++++++++
mlir/include/mlir/Transforms/Passes.h | 5 --
mlir/include/mlir/Transforms/Passes.td | 48 ---------------
.../Dialect/Arith/Transforms/CMakeLists.txt | 1 +
.../Transforms/EliminateExplicitRounding.cpp | 61 ++++++++++---------
.../Transforms/EmulateUnsupportedFloats.cpp | 11 +---
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 +---
mlir/lib/Transforms/CMakeLists.txt | 3 -
.../Arith/emulate-unsupported-floats.mlir | 32 +++++-----
9 files changed, 101 insertions(+), 117 deletions(-)
rename mlir/lib/{ => Dialect/Arith}/Transforms/EliminateExplicitRounding.cpp (55%)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 4096e309199e9..d0d614078619e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -117,4 +117,50 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
];
}
+ def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
+ let summary = "Eliminate the intermidiate truncf/extf pairs";
+ let description = [{
+ `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion
+ for every op belonging to the illegal op list. Once there are some consecutive
+ illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs
+ between the illegal ops.
+
+ This pass is to eliminate the intermidiate truncf/extf pairs to improve
+ performance.
+
+ However, this pass may introduce numerical difference as the `f32->bf16` rounding
+ is eliminated.
+
+ Example:
+
+ ```mlir
+ // the initial func
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = math.absf %arg0 : vector<32xbf16>
+ %1 = math.sin %0 : vector<32xbf16>
+ return %1 : vector<32xbf16>
+ }
+ // after legalize-to-f32
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+ %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+ %4 = math.sin %3 : vector<32xf32>
+ %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+ return %5 : vector<32xbf16>
+ }
+ // after eliminate-explicit-rounding
+ func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+ %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+ %1 = math.absf %0 : vector<32xf32>
+ %2 = math.sin %1 : vector<32xf32>
+ %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+ return %3 : vector<32xbf16>
+ }
+ ```
+
+ }];
+}
+
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index c618fff9a8040..58bd61b2ae8b8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -44,7 +44,6 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
-#define GEN_PASS_DECL_ELIMINATEEXPLICITROUNDING
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -138,10 +137,6 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);
-/// Create eliminate-explicit-rounding pass, which eliminates the redundant
-/// truncf/extf pairs to improve performance.
-std::unique_ptr<Pass> createEliminateExplicitRoundingPass();
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a99eca2a993cb..1b40a87c63f27 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -569,52 +569,4 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
-def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
- let summary = "Eliminate the intermidiate truncf/extf pairs";
- let description = [{
- `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion
- for every op belonging to the illegal op list. Once there are some consecutive
- illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs
- between the illegal ops.
-
- This pass is to eliminate the intermidiate truncf/extf pairs to improve
- performance.
-
- However, this pass may introduce numerical difference as the `f32->bf16` rounding
- is eliminated.
-
- Example:
-
- ```mlir
- // the initial func
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = math.absf %arg0 : vector<32xbf16>
- %1 = math.sin %0 : vector<32xbf16>
- return %1 : vector<32xbf16>
- }
- // after legalize-to-f32
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
- %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
- %4 = math.sin %3 : vector<32xf32>
- %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
- return %5 : vector<32xbf16>
- }
- // after eliminate-explicit-rounding
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = math.sin %1 : vector<32xf32>
- %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
- return %3 : vector<32xbf16>
- }
- ```
-
- }];
- let constructor = "mlir::createEliminateExplicitRoundingPass()";
-}
-
-
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 12659eaba1fa5..a12da70bae9af 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
BufferViewFlowOpInterfaceImpl.cpp
+ EliminateExplicitRounding.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
similarity index 55%
rename from mlir/lib/Transforms/EliminateExplicitRounding.cpp
rename to mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 4731b5a15f415..922531e976252 100644
--- a/mlir/lib/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -1,4 +1,5 @@
-//===- EliminateExplicitRounding.cpp - Remove redundant extf/truncf pairs -===//
+//===- EliminateExplicitRounding.cpp - Remove intermediate extf/truncf pairs
+//-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,21 +7,22 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements removing redundant extf/truncf pairs inserted from
-// LegalizeToF32 and EmulateUnsupportedFloats.
+// This file implements removing intermediate extf/truncf pairs inserted from
+// type conversion.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
+namespace arith {
#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING
-#include "mlir/Transforms/Passes.h.inc"
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace arith
} // namespace mlir
using namespace mlir;
@@ -30,37 +32,42 @@ namespace {
struct EliminateExplicitRoundingRewritePattern final
: OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
+ using FilterFunction = std::function<bool(Operation *)>;
+
+ EliminateExplicitRoundingRewritePattern(MLIRContext *context,
+ FilterFunction filterFunc = nullptr)
+ : OpRewritePattern(context), filterFunc(filterFunc) {}
+
LogicalResult matchAndRewrite(arith::ExtFOp extfop,
PatternRewriter &rewriter) const final {
- // check whether the extfop is eliminatable
- auto extfAttr = extfop->getAttrOfType<BoolAttr>("eliminatable");
- if (!extfAttr || (extfAttr && !extfAttr.getValue())) return failure();
-
- // check whether match `eliminatable truncf->extf` pair
+ if (filterFunc && filterFunc(extfop))
+ return failure();
+ // check whether match `truncf->extf` pair
auto truncfop = extfop.getOperand().getDefiningOp<arith::TruncFOp>();
- if (!truncfop) return failure();
- auto truncfAttr = truncfop->getAttrOfType<BoolAttr>("eliminatable");
- if (!truncfAttr || (truncfAttr && !truncfAttr.getValue())) return failure();
+ if (!truncfop)
+ return failure();
// check whether the the rounding pair's input and output data type are the
- // same Currently only consider to eliminate rounding pairs for (bf16 / f16
+ // same. Currently only consider to eliminate rounding pairs for (bf16 / f16
// <-> f32)
if (auto input = truncfop.getOperand()) {
- auto inTy = input.getType();
- auto outTy = extfop.getType();
- auto shortTy = getElementTypeOrSelf(truncfop.getType());
- if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
- (shortTy.isF16() || shortTy.isBF16())) {
- rewriter.replaceOp(extfop, {input});
- }
+ auto inTy = input.getType();
+ auto outTy = extfop.getType();
+ auto shortTy = getElementTypeOrSelf(truncfop.getType());
+ if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
+ (shortTy.isF16() || shortTy.isBF16())) {
+ rewriter.replaceOp(extfop, {input});
+ }
}
return success();
}
+
+private:
+ FilterFunction filterFunc;
};
struct EliminateExplicitRounding final
- : impl::EliminateExplicitRoundingBase<
- EliminateExplicitRounding> {
+ : arith::impl::EliminateExplicitRoundingBase<EliminateExplicitRounding> {
using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase;
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
@@ -77,7 +84,3 @@ struct EliminateExplicitRounding final
};
} // namespace
-
-std::unique_ptr<Pass> mlir::createEliminateExplicitRoundingPass() {
- return std::make_unique<EliminateExplicitRounding>();
-}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 9cbb3884659ee..4a50da3513f99 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
- truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
- res = truncFOp->getResults().front();
- }
+ if (oldType != newType)
+ res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
}
rewriter.replaceOp(op, newResults);
}
@@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp->setAttr("eliminatable", b.getBoolAttr(true));
- return extFOp;
+ return b.create<arith::ExtFOp>(loc, target, input);
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index da049602bc909..5998133b7eab8 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp->setAttr("eliminatable", b.getBoolAttr(true));
- return extFOp;
+ return b.create<arith::ExtFOp>(loc, target, input);
});
}
@@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
- truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true));
- result = truncFOp->getResults().front();
- }
+ if (newType != origType)
+ result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 131ee00fd7235..90c0298fb5e46 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -20,7 +20,6 @@ add_mlir_library(MLIRTransforms
SymbolPrivatize.cpp
TopologicalSort.cpp
ViewOpGraph.cpp
- EliminateExplicitRounding.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
@@ -39,6 +38,4 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
- MLIRArithDialect
- MLIRMathDialect
)
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index 76952297a5452..a69ef131d8d47 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {eliminatable = true} : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {eliminatable = true} : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] {eliminatable = true} : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] {eliminatable = true} : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] {eliminatable = true} : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] {eliminatable = true} : bf16 to f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] {eliminatable = true} : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] {eliminatable = true} : bf16 to f32
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
// CHECK: return [[RES]]
%p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// CHECK-LABEL: @memops
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] {eliminatable = true} : f8E4M3FNUZ to f32
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
// CHECK: memref.store [[V]]
// CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] {eliminatable = true} : f8E4M3FNUZ to f32
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] {eliminatable = true} : f32 to f8E4M3FNUZ
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
// CHECK: memref.store [[X]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -63,10 +63,10 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] {eliminatable = true} : vector<4xf32> to vector<4xf8E4M3FNUZ>
-// CHECK: [[RET:%.+]] = arith.extf [[B]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: return [[RET]]
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
%ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
>From 92e809b0c2c648a8cb796111b2bc5dbe979f0d1a Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 19:05:36 +0800
Subject: [PATCH 09/22] fix comment
---
.../Transforms/EliminateExplicitRounding.cpp | 26 +++++++++++--------
1 file changed, 15 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 922531e976252..908d358857013 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -38,28 +38,32 @@ struct EliminateExplicitRoundingRewritePattern final
FilterFunction filterFunc = nullptr)
: OpRewritePattern(context), filterFunc(filterFunc) {}
- LogicalResult matchAndRewrite(arith::ExtFOp extfop,
+ LogicalResult matchAndRewrite(arith::ExtFOp extFOp,
PatternRewriter &rewriter) const final {
- if (filterFunc && filterFunc(extfop))
+ // check whether match `truncF->extF` pair
+ auto truncFOp = extFOp.getOperand().getDefiningOp<arith::TruncFOp>();
+ if (!truncFOp)
return failure();
- // check whether match `truncf->extf` pair
- auto truncfop = extfop.getOperand().getDefiningOp<arith::TruncFOp>();
- if (!truncfop)
+
+ // check whether need to filter out
+ if (filterFunc && filterFunc(extFOp))
return failure();
- // check whether the the rounding pair's input and output data type are the
+ // check whether the rounding pair's input and output data type are the
// same. Currently only consider to eliminate rounding pairs for (bf16 / f16
// <-> f32)
- if (auto input = truncfop.getOperand()) {
+ if (auto input = truncFOp.getOperand()) {
auto inTy = input.getType();
- auto outTy = extfop.getType();
- auto shortTy = getElementTypeOrSelf(truncfop.getType());
+ auto outTy = extFOp.getType();
+ auto shortTy = getElementTypeOrSelf(truncFOp.getType());
if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
(shortTy.isF16() || shortTy.isBF16())) {
- rewriter.replaceOp(extfop, {input});
+ rewriter.replaceOp(extFOp, {input});
+ return success();
}
}
- return success();
+
+ return failure();
}
private:
>From 5583436e89a852a5141b403ea1f1ee19dbc88d8e Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 27 May 2024 20:12:41 +0800
Subject: [PATCH 10/22] remove unnecessary if
---
.../Transforms/EliminateExplicitRounding.cpp | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 908d358857013..bf510f9671c01 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -52,15 +52,14 @@ struct EliminateExplicitRoundingRewritePattern final
// check whether the rounding pair's input and output data type are the
// same. Currently only consider to eliminate rounding pairs for (bf16 / f16
// <-> f32)
- if (auto input = truncFOp.getOperand()) {
- auto inTy = input.getType();
- auto outTy = extFOp.getType();
- auto shortTy = getElementTypeOrSelf(truncFOp.getType());
- if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
- (shortTy.isF16() || shortTy.isBF16())) {
- rewriter.replaceOp(extFOp, {input});
- return success();
- }
+ auto input = truncFOp.getOperand();
+ auto inTy = input.getType();
+ auto outTy = extFOp.getType();
+ auto shortTy = getElementTypeOrSelf(truncFOp.getType());
+ if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
+ (shortTy.isF16() || shortTy.isBF16())) {
+ rewriter.replaceOp(extFOp, {input});
+ return success();
}
return failure();
>From c9e0e8bd3fd63a849865ca8517cfb63e4c8f9a81 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Wed, 29 May 2024 13:21:30 +0800
Subject: [PATCH 11/22] add test case
---
.../Arith}/eliminate-explicit-rounding.mlir | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
rename mlir/test/{Transforms => Dialect/Arith}/eliminate-explicit-rounding.mlir (74%)
diff --git a/mlir/test/Transforms/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
similarity index 74%
rename from mlir/test/Transforms/eliminate-explicit-rounding.mlir
rename to mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
index 2f7765a8fe270..70f9570235b56 100644
--- a/mlir/test/Transforms/eliminate-explicit-rounding.mlir
+++ b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
@@ -71,3 +71,22 @@ func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xb
%3 = arith.addf %1, %2 : vector<32x32x32xbf16>
return %3 : vector<32x32x32xbf16>
}
+
+// CHECK-LABEL: @bf16_fma
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
+func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = math.absf %arg0 : vector<32x32x32xbf16>
+ %1 = math.sin %0 : vector<32x32x32xbf16>
+ %2 = math.fma %1, %arg1, %arg2 : vector<32x32x32xbf16>
+ %3 = arith.addf %2, %1 : vector<32x32x32xbf16>
+ return %3 : vector<32x32x32xbf16>
+}
>From 923b4513c13c44cb739a2335f665d8b7fa3ec902 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Wed, 29 May 2024 21:13:10 +0800
Subject: [PATCH 12/22] address review comments: tidy code comments, simplify op walk, use pre-legalized test inputs
---
.../Transforms/EliminateExplicitRounding.cpp | 13 ++--
.../Arith/eliminate-explicit-rounding.mlir | 67 +++++++++++++------
2 files changed, 52 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index bf510f9671c01..5b2d243eac29d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -40,18 +40,18 @@ struct EliminateExplicitRoundingRewritePattern final
LogicalResult matchAndRewrite(arith::ExtFOp extFOp,
PatternRewriter &rewriter) const final {
- // check whether match `truncF->extF` pair
+ // Check whether match `truncF->extF` pair.
auto truncFOp = extFOp.getOperand().getDefiningOp<arith::TruncFOp>();
if (!truncFOp)
return failure();
- // check whether need to filter out
+ // Check whether need to filter out.
if (filterFunc && filterFunc(extFOp))
return failure();
- // check whether the rounding pair's input and output data type are the
+ // Check whether the rounding pair's input and output data type are the
// same. Currently only consider to eliminate rounding pairs for (bf16 / f16
- // <-> f32)
+ // <-> f32).
auto input = truncFOp.getOperand();
auto inTy = input.getType();
auto outTy = extFOp.getType();
@@ -77,10 +77,7 @@ struct EliminateExplicitRounding final
patterns.insert<EliminateExplicitRoundingRewritePattern>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
SmallVector<Operation *> ops;
- getOperation()->walk([&](Operation *op) {
- if (isa<arith::ExtFOp>(op))
- ops.push_back(op);
- });
+ getOperation()->walk([&](arith::ExtFOp op) { ops.push_back(op); });
if (failed(applyOpPatternsAndFold(ops, patternSet)))
signalPassFailure();
}
diff --git a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
index 70f9570235b56..55cf4fdadd922 100644
--- a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
+++ b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -eliminate-explicit-rounding | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --eliminate-explicit-rounding | FileCheck %s
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
@@ -8,9 +8,13 @@
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : bf16
func.func @sequences(%arg0: bf16) -> bf16 {
- %0 = math.absf %arg0 : bf16
- %1 = math.sin %0 : bf16
- return %1 : bf16
+ %0 = arith.extf %arg0 : bf16 to f32
+ %1 = math.absf %0 : f32
+ %2 = arith.truncf %1 : f32 to bf16
+ %3 = arith.extf %2 : bf16 to f32
+ %4 = math.sin %3 : f32
+ %5 = arith.truncf %4 : f32 to bf16
+ return %5 : bf16
}
// CHECK-LABEL: @eliminatecastoncastf16
@@ -37,9 +41,13 @@ func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = math.absf %arg0 : vector<32x32x32xbf16>
- %1 = math.sin %0 : vector<32x32x32xbf16>
- return %1 : vector<32x32x32xbf16>
+ %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %5 : vector<32x32x32xbf16>
}
// CHECK-LABEL: @f16_sin_vector
@@ -50,9 +58,13 @@ func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
- %0 = math.absf %arg0 : vector<32x32x32xf16>
- %1 = math.sin %0 : vector<32x32x32xf16>
- return %1 : vector<32x32x32xf16>
+ %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16>
+ %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16>
+ return %5 : vector<32x32x32xf16>
}
// CHECK-LABEL: @bf16_branch_vector
@@ -65,11 +77,19 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = math.absf %arg0 : vector<32x32x32xbf16>
- %1 = math.sin %0 : vector<32x32x32xbf16>
- %2 = math.cos %0 : vector<32x32x32xbf16>
- %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
- return %3 : vector<32x32x32xbf16>
+ %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.cos %3 : vector<32x32x32xf32>
+ %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %10 = arith.addf %6, %9 : vector<32x32x32xf32>
+ %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %11 : vector<32x32x32xbf16>
}
// CHECK-LABEL: @bf16_fma
@@ -84,9 +104,16 @@ func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xb
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = math.absf %arg0 : vector<32x32x32xbf16>
- %1 = math.sin %0 : vector<32x32x32xbf16>
- %2 = math.fma %1, %arg1, %arg2 : vector<32x32x32xbf16>
- %3 = arith.addf %2, %1 : vector<32x32x32xbf16>
- return %3 : vector<32x32x32xbf16>
+ %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
+ %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %9 = arith.addf %8, %6 : vector<32x32x32xf32>
+ %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %10 : vector<32x32x32xbf16>
}
>From a2c2e012f35bb49be358bd665691b4ddac1bc183 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 30 May 2024 10:10:36 +0800
Subject: [PATCH 13/22] fix typo
---
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d0d614078619e..6fc89cc91b740 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -125,7 +125,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs
between the illegal ops.
- This pass is to eliminate the intermidiate truncf/extf pairs to improve
+ This pass is to eliminate the intermediate truncf/extf pairs to improve
performance.
However, this pass may introduce numerical difference as the `f32->bf16` rounding
>From 345bd9cb8b121e88d2c49f2f22b4fa7da2f312ce Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 30 May 2024 10:40:59 +0800
Subject: [PATCH 14/22] address review comment: use explicit Value/Type declarations and isa<> element-type checks
---
.../Transforms/EliminateExplicitRounding.cpp | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 5b2d243eac29d..8a5f10a6cbbb0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -46,18 +46,20 @@ struct EliminateExplicitRoundingRewritePattern final
return failure();
// Check whether need to filter out.
- if (filterFunc && filterFunc(extFOp))
+ if (filterFunc && filterFunc(extFOp)) {
+ extFOp.emitError("Operation filtered out by filterFunc");
return failure();
+ }
// Check whether the rounding pair's input and output data type are the
// same. Currently only consider to eliminate rounding pairs for (bf16 / f16
// <-> f32).
- auto input = truncFOp.getOperand();
- auto inTy = input.getType();
- auto outTy = extFOp.getType();
- auto shortTy = getElementTypeOrSelf(truncFOp.getType());
- if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() &&
- (shortTy.isF16() || shortTy.isBF16())) {
+ Value input = truncFOp.getOperand();
+ Type inTy = getElementTypeOrSelf(input.getType());
+ Type outTy = getElementTypeOrSelf(extFOp.getType());
+ Type shortTy = getElementTypeOrSelf(truncFOp.getType());
+ if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
+ (isa<Float16Type>(shortTy) || isa<BFloat16Type>(shortTy))) {
rewriter.replaceOp(extFOp, {input});
return success();
}
>From e6fd571eb8f7dde4eee9127b681010a66282f27a Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 30 May 2024 11:55:43 +0800
Subject: [PATCH 15/22] rename pass to arith-eliminate-explicit-rounding and address review comments
---
.../mlir/Dialect/Arith/Transforms/Passes.td | 2 +-
.../Transforms/EliminateExplicitRounding.cpp | 18 ++++++++----------
.../Arith/eliminate-explicit-rounding.mlir | 2 +-
3 files changed, 10 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 6fc89cc91b740..7afec9f752cfa 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -117,7 +117,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
];
}
- def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> {
+ def ArithEliminateExplicitRounding : Pass<"arith-eliminate-explicit-rounding"> {
let summary = "Eliminate the intermidiate truncf/extf pairs";
let description = [{
`legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 8a5f10a6cbbb0..5ab540fa0e9fb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -20,7 +20,7 @@
namespace mlir {
namespace arith {
-#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING
+#define GEN_PASS_DEF_ARITHELIMINATEEXPLICITROUNDING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir
@@ -34,10 +34,6 @@ struct EliminateExplicitRoundingRewritePattern final
using OpRewritePattern::OpRewritePattern;
using FilterFunction = std::function<bool(Operation *)>;
- EliminateExplicitRoundingRewritePattern(MLIRContext *context,
- FilterFunction filterFunc = nullptr)
- : OpRewritePattern(context), filterFunc(filterFunc) {}
-
LogicalResult matchAndRewrite(arith::ExtFOp extFOp,
PatternRewriter &rewriter) const final {
// Check whether match `truncF->extF` pair.
@@ -47,8 +43,9 @@ struct EliminateExplicitRoundingRewritePattern final
// Check whether need to filter out.
if (filterFunc && filterFunc(extFOp)) {
- extFOp.emitError("Operation filtered out by filterFunc");
- return failure();
+ return rewriter.notifyMatchFailure(extFOp, [](Diagnostic &diag) {
+ diag << "Operation filtered out by filterFunc";
+ });
}
// Check whether the rounding pair's input and output data type are the
@@ -59,7 +56,7 @@ struct EliminateExplicitRoundingRewritePattern final
Type outTy = getElementTypeOrSelf(extFOp.getType());
Type shortTy = getElementTypeOrSelf(truncFOp.getType());
if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
- (isa<Float16Type>(shortTy) || isa<BFloat16Type>(shortTy))) {
+ (isa<Float16Type, BFloat16Type>(shortTy))) {
rewriter.replaceOp(extFOp, {input});
return success();
}
@@ -72,8 +69,9 @@ struct EliminateExplicitRoundingRewritePattern final
};
struct EliminateExplicitRounding final
- : arith::impl::EliminateExplicitRoundingBase<EliminateExplicitRounding> {
- using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase;
+ : arith::impl::ArithEliminateExplicitRoundingBase<
+ EliminateExplicitRounding> {
+ using ArithEliminateExplicitRoundingBase::ArithEliminateExplicitRoundingBase;
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.insert<EliminateExplicitRoundingRewritePattern>(&getContext());
diff --git a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
index 55cf4fdadd922..f2ba276a4f7bb 100644
--- a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
+++ b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file --eliminate-explicit-rounding | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --arith-eliminate-explicit-rounding | FileCheck %s
// CHECK-LABEL: @sequences
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
>From ab80bafa7d9936e59d7050beeca3679555da7d4d Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 30 May 2024 13:40:49 +0800
Subject: [PATCH 16/22] remove filter func
---
.../Arith/Transforms/EliminateExplicitRounding.cpp | 11 -----------
1 file changed, 11 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 5ab540fa0e9fb..6b2bdd1404bd6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -32,7 +32,6 @@ namespace {
struct EliminateExplicitRoundingRewritePattern final
: OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
- using FilterFunction = std::function<bool(Operation *)>;
LogicalResult matchAndRewrite(arith::ExtFOp extFOp,
PatternRewriter &rewriter) const final {
@@ -41,13 +40,6 @@ struct EliminateExplicitRoundingRewritePattern final
if (!truncFOp)
return failure();
- // Check whether need to filter out.
- if (filterFunc && filterFunc(extFOp)) {
- return rewriter.notifyMatchFailure(extFOp, [](Diagnostic &diag) {
- diag << "Operation filtered out by filterFunc";
- });
- }
-
// Check whether the rounding pair's input and output data type are the
// same. Currently only consider to eliminate rounding pairs for (bf16 / f16
// <-> f32).
@@ -63,9 +55,6 @@ struct EliminateExplicitRoundingRewritePattern final
return failure();
}
-
-private:
- FilterFunction filterFunc;
};
struct EliminateExplicitRounding final
>From 961f6f8798b4bddde5ef83547b941ca50e95b8b1 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 3 Jun 2024 10:26:18 +0800
Subject: [PATCH 17/22] do not use pattern
---
.../Transforms/EliminateExplicitRounding.cpp | 57 +++++++------------
1 file changed, 22 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index 6b2bdd1404bd6..b341dfd40ed4f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -29,46 +29,33 @@ using namespace mlir;
namespace {
-struct EliminateExplicitRoundingRewritePattern final
- : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::ExtFOp extFOp,
- PatternRewriter &rewriter) const final {
- // Check whether match `truncF->extF` pair.
- auto truncFOp = extFOp.getOperand().getDefiningOp<arith::TruncFOp>();
- if (!truncFOp)
- return failure();
-
- // Check whether the rounding pair's input and output data type are the
- // same. Currently only consider to eliminate rounding pairs for (bf16 / f16
- // <-> f32).
- Value input = truncFOp.getOperand();
- Type inTy = getElementTypeOrSelf(input.getType());
- Type outTy = getElementTypeOrSelf(extFOp.getType());
- Type shortTy = getElementTypeOrSelf(truncFOp.getType());
- if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
- (isa<Float16Type, BFloat16Type>(shortTy))) {
- rewriter.replaceOp(extFOp, {input});
- return success();
- }
-
- return failure();
- }
-};
-
struct EliminateExplicitRounding final
: arith::impl::ArithEliminateExplicitRoundingBase<
EliminateExplicitRounding> {
using ArithEliminateExplicitRoundingBase::ArithEliminateExplicitRoundingBase;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- patterns.insert<EliminateExplicitRoundingRewritePattern>(&getContext());
- FrozenRewritePatternSet patternSet(std::move(patterns));
- SmallVector<Operation *> ops;
- getOperation()->walk([&](arith::ExtFOp op) { ops.push_back(op); });
- if (failed(applyOpPatternsAndFold(ops, patternSet)))
- signalPassFailure();
+ getOperation()->walk([&](arith::ExtFOp extFOp) {
+ // Check whether match `truncF->extF` pair.
+ auto truncFOp = extFOp.getOperand().getDefiningOp<arith::TruncFOp>();
+ if (truncFOp) {
+ // Check whether the rounding pair's input and output data type are the
+ // same. Currently only consider to eliminate rounding pairs for (bf16 /
+ // f16
+ // <-> f32).
+ Value input = truncFOp.getOperand();
+ Type inTy = getElementTypeOrSelf(input.getType());
+ Type outTy = getElementTypeOrSelf(extFOp.getType());
+ Type shortTy = getElementTypeOrSelf(truncFOp.getType());
+ if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
+ (isa<Float16Type, BFloat16Type>(shortTy))) {
+ for (auto &use :
+ llvm::make_early_inc_range(extFOp.getResult().getUses())) {
+ use.set(input);
+ }
+ extFOp.erase();
+ }
+ }
+ });
}
};
>From 8e6de0d46b9393fb01ea82be959ea7cdb80bcfa1 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 3 Jun 2024 14:02:26 +0800
Subject: [PATCH 18/22] address review comment: use early return and erase dead truncf ops
---
.../Transforms/EliminateExplicitRounding.cpp | 32 +++++++++----------
1 file changed, 15 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
index b341dfd40ed4f..7df77629127fa 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
@@ -37,23 +37,21 @@ struct EliminateExplicitRounding final
getOperation()->walk([&](arith::ExtFOp extFOp) {
// Check whether match `truncF->extF` pair.
auto truncFOp = extFOp.getOperand().getDefiningOp<arith::TruncFOp>();
- if (truncFOp) {
- // Check whether the rounding pair's input and output data type are the
- // same. Currently only consider to eliminate rounding pairs for (bf16 /
- // f16
- // <-> f32).
- Value input = truncFOp.getOperand();
- Type inTy = getElementTypeOrSelf(input.getType());
- Type outTy = getElementTypeOrSelf(extFOp.getType());
- Type shortTy = getElementTypeOrSelf(truncFOp.getType());
- if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
- (isa<Float16Type, BFloat16Type>(shortTy))) {
- for (auto &use :
- llvm::make_early_inc_range(extFOp.getResult().getUses())) {
- use.set(input);
- }
- extFOp.erase();
- }
+ if (!truncFOp)
+ return;
+ // Check whether the rounding pair's input and output data type are the
+ // same. Currently only consider to eliminate rounding pairs for (bf16 /
+ // f16 <-> f32).
+ Value input = truncFOp.getOperand();
+ Type inTy = getElementTypeOrSelf(input.getType());
+ Type outTy = getElementTypeOrSelf(extFOp.getType());
+ Type shortTy = getElementTypeOrSelf(truncFOp.getType());
+ if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
+ (isa<Float16Type, BFloat16Type>(shortTy))) {
+ extFOp.replaceAllUsesWith(input);
+ extFOp.erase();
+ if (truncFOp.getResult().getUses().empty())
+ truncFOp.erase();
}
});
}
>From 66fef95ac3d07eebfae4b126bb6dca4359329e16 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 7 Jun 2024 00:48:47 +0800
Subject: [PATCH 19/22] add fastmath flag attrs and use canonicalizer
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 13 +-
.../mlir/Dialect/Arith/Transforms/Passes.td | 46 -------
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 33 +++++
.../Dialect/Arith/Transforms/CMakeLists.txt | 1 -
.../Transforms/EliminateExplicitRounding.cpp | 60 ---------
.../Arith/eliminate-explicit-rounding.mlir | 119 ------------------
mlir/test/Transforms/canonicalize.mlir | 118 +++++++++++++++++
7 files changed, 163 insertions(+), 227 deletions(-)
delete mode 100644 mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
delete mode 100644 mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ead52332e8eec..6fff83dc3df7f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1195,6 +1195,14 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
}];
let hasVerifier = 1;
let hasFolder = 1;
+ let hasCanonicalizer = 1;
+
+ let arguments = (ins FloatLike:$in, DefaultValuedAttr<
+ Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::contract">:$fastmath);
+ let results = (outs FloatLike:$out);
+
+ let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
@@ -1235,7 +1243,10 @@ def Arith_TruncFOp :
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
- OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+ DefaultValuedAttr<
+ Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::contract">:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 7afec9f752cfa..4096e309199e9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -117,50 +117,4 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
];
}
- def ArithEliminateExplicitRounding : Pass<"arith-eliminate-explicit-rounding"> {
- let summary = "Eliminate the intermidiate truncf/extf pairs";
- let description = [{
- `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion
- for every op belonging to the illegal op list. Once there are some consecutive
- illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs
- between the illegal ops.
-
- This pass is to eliminate the intermediate truncf/extf pairs to improve
- performance.
-
- However, this pass may introduce numerical difference as the `f32->bf16` rounding
- is eliminated.
-
- Example:
-
- ```mlir
- // the initial func
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = math.absf %arg0 : vector<32xbf16>
- %1 = math.sin %0 : vector<32xbf16>
- return %1 : vector<32xbf16>
- }
- // after legalize-to-f32
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
- %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
- %4 = math.sin %3 : vector<32xf32>
- %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
- return %5 : vector<32xbf16>
- }
- // after eliminate-explicit-rounding
- func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
- %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
- %1 = math.absf %0 : vector<32xf32>
- %2 = math.sin %1 : vector<32xf32>
- %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
- return %3 : vector<32xbf16>
- }
- ```
-
- }];
-}
-
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..1a135668a23e6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1410,6 +1410,39 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
+struct SimplifyExtFTruncFOpPair : public OpRewritePattern<ExtFOp> {
+ using OpRewritePattern<ExtFOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtFOp extFOp,
+ PatternRewriter &rewriter) const override {
+ if (auto truncFOp = extFOp.getOperand().getDefiningOp<TruncFOp>()) {
+ Value input = truncFOp.getOperand();
+ Type inTy = getElementTypeOrSelf(input.getType());
+ Type outTy = getElementTypeOrSelf(extFOp.getType());
+ Type shortTy = getElementTypeOrSelf(truncFOp.getType());
+ if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
+ (isa<Float16Type, BFloat16Type>(shortTy))) {
+ arith::FastMathFlags truncFMF = truncFOp.getFastmathAttr().getValue();
+ bool isTruncContract =
+ bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
+ arith::FastMathFlags extFMF = extFOp.getFastmathAttr().getValue();
+ bool isExtContract =
+ bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
+ if (isTruncContract && isExtContract) {
+ rewriter.replaceOp(extFOp, truncFOp.getOperand());
+ return success();
+ }
+ }
+ }
+ return failure();
+ }
+};
+
+void arith::ExtFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<SimplifyExtFTruncFOpPair>(context);
+}
+
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index a12da70bae9af..12659eaba1fa5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
BufferViewFlowOpInterfaceImpl.cpp
- EliminateExplicitRounding.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
deleted file mode 100644
index 7df77629127fa..0000000000000
--- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-//===- EliminateExplicitRounding.cpp - Remove intermediate extf/truncf pairs
-//-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements removing intermediate extf/truncf pairs inserted from
-// type conversion.
-//
-//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace arith {
-#define GEN_PASS_DEF_ARITHELIMINATEEXPLICITROUNDING
-#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
-} // namespace arith
-} // namespace mlir
-
-using namespace mlir;
-
-namespace {
-
-struct EliminateExplicitRounding final
- : arith::impl::ArithEliminateExplicitRoundingBase<
- EliminateExplicitRounding> {
- using ArithEliminateExplicitRoundingBase::ArithEliminateExplicitRoundingBase;
- void runOnOperation() override {
- getOperation()->walk([&](arith::ExtFOp extFOp) {
- // Check whether match `truncF->extF` pair.
- auto truncFOp = extFOp.getOperand().getDefiningOp<arith::TruncFOp>();
- if (!truncFOp)
- return;
- // Check whether the rounding pair's input and output data type are the
- // same. Currently only consider to eliminate rounding pairs for (bf16 /
- // f16 <-> f32).
- Value input = truncFOp.getOperand();
- Type inTy = getElementTypeOrSelf(input.getType());
- Type outTy = getElementTypeOrSelf(extFOp.getType());
- Type shortTy = getElementTypeOrSelf(truncFOp.getType());
- if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
- (isa<Float16Type, BFloat16Type>(shortTy))) {
- extFOp.replaceAllUsesWith(input);
- extFOp.erase();
- if (truncFOp.getResult().getUses().empty())
- truncFOp.erase();
- }
- });
- }
-};
-
-} // namespace
diff --git a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
deleted file mode 100644
index f2ba276a4f7bb..0000000000000
--- a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir
+++ /dev/null
@@ -1,119 +0,0 @@
-// RUN: mlir-opt %s --split-input-file --arith-eliminate-explicit-rounding | FileCheck %s
-
-// CHECK-LABEL: @sequences
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences(%arg0: bf16) -> bf16 {
- %0 = arith.extf %arg0 : bf16 to f32
- %1 = math.absf %0 : f32
- %2 = arith.truncf %1 : f32 to bf16
- %3 = arith.extf %2 : bf16 to f32
- %4 = math.sin %3 : f32
- %5 = arith.truncf %4 : f32 to bf16
- return %5 : bf16
-}
-
-// CHECK-LABEL: @eliminatecastoncastf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 : f32 to f16
- %1 = arith.extf %0 : f16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @eliminatecastoncastbf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 : f32 to bf16
- %1 = arith.extf %0 : bf16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %5 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16>
- %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16>
- return %5 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.cos %3 : vector<32x32x32xf32>
- %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %10 = arith.addf %6, %9 : vector<32x32x32xf32>
- %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %11 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @bf16_fma
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
-func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
- %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %9 = arith.addf %8, %6 : vector<32x32x32xf32>
- %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %10 : vector<32x32x32xbf16>
-}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index d2c2c12d32389..cd06cca33c926 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1243,3 +1243,121 @@ func.func @test_materialize_failure() -> i64 {
%u = index.castu %const : index to i64
return %u: i64
}
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+ %0 = arith.extf %arg0 : bf16 to f32
+ %1 = math.absf %0 : f32
+ %2 = arith.truncf %1 : f32 to bf16
+ %3 = arith.extf %2 : bf16 to f32
+ %4 = math.sin %3 : f32
+ %5 = arith.truncf %4 : f32 to bf16
+ return %5 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to f16
+ %1 = arith.extf %0 : f16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 : f32 to bf16
+ %1 = arith.extf %0 : bf16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %5 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+ %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16>
+ %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16>
+ return %5 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.cos %3 : vector<32x32x32xf32>
+ %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %10 = arith.addf %6, %9 : vector<32x32x32xf32>
+ %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %11 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @bf16_fma
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
+func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
+ %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %9 = arith.addf %8, %6 : vector<32x32x32xf32>
+ %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %10 : vector<32x32x32xbf16>
+}
>From 5d73e9dfe4cb82137e9302526f164dbe0fe9e32d Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 7 Jun 2024 09:57:46 +0800
Subject: [PATCH 20/22] strip fastmath attribute from extf/truncf when lowering to LLVM
---
.../include/mlir/Conversion/LLVMCommon/VectorPattern.h | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..a7be4ff0fba7a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -10,6 +10,7 @@
#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -98,6 +99,15 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
+
+ // Check if the operation is remove the fastMathAttr on ExtFOp / TruncFOp.
+ if (isa<arith::ExtFOp>(op.getOperation()) ||
+ isa<arith::TruncFOp>(op.getOperation())) {
+ if (op->hasAttr("fastmath")) {
+ op->removeAttr("fastmath");
+ }
+ }
+
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
>From d53358cac11555c32e2ec533ecf6922989303645 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Sat, 8 Jun 2024 01:39:03 +0800
Subject: [PATCH 21/22] make extf/truncf fastmath optional instead of defaulting to contract
---
.../Conversion/LLVMCommon/VectorPattern.h | 10 --
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 18 +--
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 15 +-
.../Transforms/EmulateUnsupportedFloats.cpp | 11 +-
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 +-
mlir/test/Dialect/Arith/canonicalize.mlir | 137 ++++++++++++++++++
.../Arith/emulate-unsupported-floats.mlir | 137 +++++++++---------
mlir/test/Transforms/canonicalize.mlir | 118 ---------------
8 files changed, 238 insertions(+), 219 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index a7be4ff0fba7a..964281592cc65 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -10,7 +10,6 @@
#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -99,15 +98,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
-
- // Check if the operation is remove the fastMathAttr on ExtFOp / TruncFOp.
- if (isa<arith::ExtFOp>(op.getOperation()) ||
- isa<arith::TruncFOp>(op.getOperation())) {
- if (op->hasAttr("fastmath")) {
- op->removeAttr("fastmath");
- }
- }
-
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6fff83dc3df7f..2e0a1d8d2f678 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1186,7 +1186,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
// ExtFOp
//===----------------------------------------------------------------------===//
-def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "cast from floating-point to wider floating-point";
let description = [{
Cast a floating-point value to a larger floating-point-typed value.
@@ -1197,12 +1197,11 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
let hasFolder = 1;
let hasCanonicalizer = 1;
- let arguments = (ins FloatLike:$in, DefaultValuedAttr<
- Arith_FastMathAttr,
- "::mlir::arith::FastMathFlags::contract">:$fastmath);
+ let arguments = (ins FloatLike:$in, OptionalAttr<Arith_FastMathAttr>:$fastmath);
let results = (outs FloatLike:$out);
- let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
+ let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
@@ -1241,12 +1240,11 @@ def Arith_TruncFOp :
Arith_Op<"truncf",
[Pure, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
- DefaultValuedAttr<
- Arith_FastMathAttr,
- "::mlir::arith::FastMathFlags::contract">:$fastmath)>,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
@@ -1265,7 +1263,9 @@ def Arith_TruncFOp :
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
+ let assemblyFormat = [{ $in ($roundingmode^)?
+ (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1a135668a23e6..895d72c6f9ca9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1416,16 +1416,15 @@ struct SimplifyExtFTruncFOpPair : public OpRewritePattern<ExtFOp> {
LogicalResult matchAndRewrite(ExtFOp extFOp,
PatternRewriter &rewriter) const override {
if (auto truncFOp = extFOp.getOperand().getDefiningOp<TruncFOp>()) {
- Value input = truncFOp.getOperand();
- Type inTy = getElementTypeOrSelf(input.getType());
- Type outTy = getElementTypeOrSelf(extFOp.getType());
- Type shortTy = getElementTypeOrSelf(truncFOp.getType());
- if (isa<Float32Type>(inTy) && isa<Float32Type>(outTy) &&
- (isa<Float16Type, BFloat16Type>(shortTy))) {
- arith::FastMathFlags truncFMF = truncFOp.getFastmathAttr().getValue();
+ if (truncFOp.getOperand().getType() == extFOp.getType()) {
+ // RoundingMode roundingMode =
+ // getRoundingmode().value_or(RoundingMode::to_nearest_even);
+ arith::FastMathFlags truncFMF =
+ truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
bool isTruncContract =
bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
- arith::FastMathFlags extFMF = extFOp.getFastmathAttr().getValue();
+ arith::FastMathFlags extFMF =
+ extFOp.getFastmath().value_or(arith::FastMathFlags::none);
bool isExtContract =
bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
if (isTruncContract && isExtContract) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 4a50da3513f99..8e1cb474feee7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType)
- res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ if (oldType != newType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ truncFOp.setFastmath(arith::FastMathFlags::contract);
+ res = truncFOp.getResult();
+ }
}
rewriter.replaceOp(op, newResults);
}
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..3d99f3033cf56 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
});
}
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType)
- result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ if (newType != origType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ truncFOp.setFastmath(arith::FastMathFlags::contract);
+ result = truncFOp.getResult();
+ }
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1a387c20c4b29..78d12af4c3054 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3039,6 +3039,143 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
}
+// CHECK-LABEL: @sequences_fastmath_contract
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
+ %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
+ %1 = math.absf %0 : f32
+ %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
+ %3 = arith.extf %2 fastmath<contract> : bf16 to f32
+ %4 = math.sin %3 : f32
+ %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
+ return %5 : bf16
+}
+
+// CHECK-LABEL: @sequences_no_fastmath
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
+ %0 = arith.extf %arg0 : bf16 to f32
+ %1 = math.absf %0 : f32
+ %2 = arith.truncf %1 : f32 to bf16
+ %3 = arith.extf %2 : bf16 to f32
+ %4 = math.sin %3 : f32
+ %5 = arith.truncf %4 : f32 to bf16
+ return %5 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
+ %1 = arith.extf %0 fastmath<contract> : f16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
+ %1 = arith.extf %0 fastmath<contract> : bf16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %5 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
+ return %5 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.cos %3 : vector<32x32x32xf32>
+ %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %10 = arith.addf %6, %9 : vector<32x32x32xf32>
+ %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %11 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @bf16_fma
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
+func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
+ %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %9 = arith.addf %8, %6 : vector<32x32x32xf32>
+ %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %10 : vector<32x32x32xbf16>
+}
+
{-#
dialect_resources: {
builtin: {
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index a69ef131d8d47..75ae4168dd1b1 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,84 +4,85 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
func.return %y : bf16
}
-// -----
-
-func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
-// CHECK-LABEL: @chained
-// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
-// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
-// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
-// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
-// CHECK: return [[RES]]
- %p = arith.addf %x, %y : bf16
- %q = arith.mulf %p, %z : bf16
- %res = arith.cmpf ole, %p, %q : bf16
- func.return %res : i1
+ // -----
+
+ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
+ // CHECK-LABEL: @chained
+ // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
+ // CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
+ // CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
+ // CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
+ // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
+ // CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
+ // CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
+ // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
+ // CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
+ // CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
+ // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
+ // CHECK: return [[RES]]
+ %p = arith.addf %x, %y : bf16
+ %q = arith.mulf %p, %z : bf16
+ %res = arith.cmpf ole, %p, %q : bf16
+ func.return %res : i1
}
-// -----
-
-func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
-// CHECK-LABEL: @memops
-// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
-// CHECK: memref.store [[V]]
-// CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
-// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
-// CHECK: memref.store [[X]]
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
- memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
- %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
- %x = arith.addf %v, %w : f8E4M3FNUZ
- memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
- func.return
+ // -----
+
+ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
+ // CHECK-LABEL: @memops
+ // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
+ // CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
+ // CHECK: memref.store [[V]]
+ // CHECK: [[W:%.+]] = memref.load
+ // CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
+ // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
+ // CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
+ // CHECK: memref.store [[X]]
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
+ memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
+ %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
+ %x = arith.addf %v, %w : f8E4M3FNUZ
+ memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
+ func.return
}
-// -----
-
-func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
-// CHECK-LABEL: @vectors
-// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
-// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
-// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
-// CHECK: return [[RET]]
- %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
- %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
- func.return %ret : vector<4xf32>
-}
+ // -----
+
+ func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
+ // CHECK-LABEL: @vectors
+ // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
+ // CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
+ // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
+ // CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
+ // CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+ // CHECK: return [[RET]]
+ %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
+ %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
+ func.return %ret : vector<4xf32>
+ }
-// -----
+ // -----
+
+ func.func @no_expansion(%x: f32) -> f32 {
+ // CHECK-LABEL: @no_expansion
+ // CHECK-SAME: [[X:%.+]]: f32
+ // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32
+ // CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32
+ // CHECK: return [[Y]]
+ %c = arith.constant 1.0 : f32
+ %y = arith.addf %x, %c : f32
+ func.return %y : f32
+ }
-func.func @no_expansion(%x: f32) -> f32 {
-// CHECK-LABEL: @no_expansion
-// CHECK-SAME: [[X:%.+]]: f32
-// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32
-// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32
-// CHECK: return [[Y]]
- %c = arith.constant 1.0 : f32
- %y = arith.addf %x, %c : f32
- func.return %y : f32
-}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index cd06cca33c926..d2c2c12d32389 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1243,121 +1243,3 @@ func.func @test_materialize_failure() -> i64 {
%u = index.castu %const : index to i64
return %u: i64
}
-
-// CHECK-LABEL: @sequences
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences(%arg0: bf16) -> bf16 {
- %0 = arith.extf %arg0 : bf16 to f32
- %1 = math.absf %0 : f32
- %2 = arith.truncf %1 : f32 to bf16
- %3 = arith.extf %2 : bf16 to f32
- %4 = math.sin %3 : f32
- %5 = arith.truncf %4 : f32 to bf16
- return %5 : bf16
-}
-
-// CHECK-LABEL: @eliminatecastoncastf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 : f32 to f16
- %1 = arith.extf %0 : f16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @eliminatecastoncastbf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 : f32 to bf16
- %1 = arith.extf %0 : bf16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %5 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16>
- %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16>
- return %5 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.cos %3 : vector<32x32x32xf32>
- %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %10 = arith.addf %6, %9 : vector<32x32x32xf32>
- %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %11 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @bf16_fma
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
-func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
- %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %9 = arith.addf %8, %6 : vector<32x32x32xf32>
- %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %10 : vector<32x32x32xbf16>
-}
>From 30e3d66ed7813226d3a8eaabf7532cecda7c03ce Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Sat, 8 Jun 2024 02:07:44 +0800
Subject: [PATCH 22/22] fold extf(truncf(x)) in ExtFOp::fold, drop canonicalizer
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 1 -
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 48 +++++++------------
2 files changed, 16 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 2e0a1d8d2f678..29591bab5010e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1195,7 +1195,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
}];
let hasVerifier = 1;
let hasFolder = 1;
- let hasCanonicalizer = 1;
let arguments = (ins FloatLike:$in, OptionalAttr<Arith_FastMathAttr>:$fastmath);
let results = (outs FloatLike:$out);
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 895d72c6f9ca9..d5f352bad0fa4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,6 +1390,22 @@ LogicalResult arith::ExtSIOp::verify() {
/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
+ if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
+ if (truncFOp.getOperand().getType() == getType()) {
+ arith::FastMathFlags truncFMF =
+ truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
+ bool isTruncContract =
+ bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
+ arith::FastMathFlags extFMF =
+ getFastmath().value_or(arith::FastMathFlags::none);
+ bool isExtContract =
+ bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
+ if (isTruncContract && isExtContract) {
+ return truncFOp.getOperand();
+ }
+ }
+ }
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
@@ -1410,38 +1426,6 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
-struct SimplifyExtFTruncFOpPair : public OpRewritePattern<ExtFOp> {
- using OpRewritePattern<ExtFOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtFOp extFOp,
- PatternRewriter &rewriter) const override {
- if (auto truncFOp = extFOp.getOperand().getDefiningOp<TruncFOp>()) {
- if (truncFOp.getOperand().getType() == extFOp.getType()) {
- // RoundingMode roundingMode =
- // getRoundingmode().value_or(RoundingMode::to_nearest_even);
- arith::FastMathFlags truncFMF =
- truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
- bool isTruncContract =
- bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
- arith::FastMathFlags extFMF =
- extFOp.getFastmath().value_or(arith::FastMathFlags::none);
- bool isExtContract =
- bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
- if (isTruncContract && isExtContract) {
- rewriter.replaceOp(extFOp, truncFOp.getOperand());
- return success();
- }
- }
- }
- return failure();
- }
-};
-
-void arith::ExtFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
- MLIRContext *context) {
- patterns.add<SimplifyExtFTruncFOpPair>(context);
-}
-
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list