[Mlir-commits] [mlir] [MLIR][Math] add canonicalize-f32-promotion pass (PR #92482)

Ivy Zhang llvmlistbot at llvm.org
Mon May 20 05:28:52 PDT 2024


https://github.com/crazydemo updated https://github.com/llvm/llvm-project/pull/92482

>From c4dd5ad49f64f58aa46cd1d241fab0ffa5f3b553 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 9 May 2024 14:36:51 +0800
Subject: [PATCH 1/5] add canonicalize-f32-promotion pass

---
 .../mlir/Dialect/Math/Transforms/Passes.h     |  1 +
 .../mlir/Dialect/Math/Transforms/Passes.td    | 43 +++++++++++
 .../Dialect/Math/Transforms/CMakeLists.txt    |  1 +
 .../Transforms/CanonicalizeF32Promotion.cpp   | 73 +++++++++++++++++++
 .../Math/canonicalize-f32-promotion.mlir      | 56 ++++++++++++++
 5 files changed, 174 insertions(+)
 create mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
 create mode 100644 mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..f150ff6f944d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
 #define GEN_PASS_DECL_MATHUPLIFTTOFMA
 #define GEN_PASS_DECL_MATHLEGALIZETOF32
+#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..538dcbfbe7f77 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,47 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 
+def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
+  let summary = "Eliminate redundant truncf/extf pairs";
+  let description = [{
+    `legalize-to-f32` pass does f32 promotion for every op belonging to the
+    illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
+    will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
+    ops.
+    
+    This pass is to eliminate the redundant truncf/extf pairs.
+
+    Example:
+
+    ```mlir
+    // the initial func
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+        %0 = math.absf %arg0 : vector<32xbf16>
+        %1 = math.sin %0 : vector<32xbf16>
+        return %1 : vector<32xbf16>
+      }
+    // after legalize-to-f32
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+        %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+        %1 = math.absf %0 : vector<32xf32>
+        %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+        %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+        %4 = math.sin %3 : vector<32xf32>
+        %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+        return %5 : vector<32xbf16>
+      }
+    // after canonicalize-f32-promotion
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+        %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+        %1 = math.absf %0 : vector<32xf32>
+        %2 = math.sin %1 : vector<32xf32>
+        %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+        return %3 : vector<32xbf16>
+      }
+    ```
+
+  }];
+  let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb5271..0d39d14925d23 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
   LegalizeToF32.cpp
+  CanonicalizeF32Promotion.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
 
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
new file mode 100644
index 0000000000000..bfff17df8d7d4
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,73 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
+      if (auto truncinput = innertruncop.getOperand()) {
+        auto outter_type = op.getType();
+        auto intermediate_type = innertruncop.getType();
+        auto inner_type = truncinput.getType();
+        if (outter_type.isa<ShapedType>()) {
+          outter_type = op.getType().cast<ShapedType>().getElementType();
+          intermediate_type =
+              innertruncop.getType().cast<ShapedType>().getElementType();
+          inner_type = truncinput.getType().cast<ShapedType>().getElementType();
+        }
+        if (outter_type.isF32() &&
+            (intermediate_type.isF16() || intermediate_type.isBF16()) &&
+            inner_type.isF32()) {
+          rewriter.replaceOp(op, {truncinput});
+        }
+      } else
+        return failure();
+    } else
+      return failure();
+    return success();
+  }
+};
+
+struct MathCanonicalizeF32Promotion final
+    : math::impl::MathCanonicalizeF32PromotionBase<
+          MathCanonicalizeF32Promotion> {
+  using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
new file mode 100644
index 0000000000000..7aad7889e2bf5
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+  %0 = math.absf %arg0 : bf16
+  %1 = math.sin %0 : bf16
+  return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to f16
+  %1 = arith.extf %0 : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  %1 = arith.extf %0 : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xf16>
+  %1 = math.sin %0 : vector<32x32x32xf16>
+  return %1 : vector<32x32x32xf16>
+}

>From 02be4d6dedc81e9e5ace44829f388e36e52e0278 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 10 May 2024 11:09:31 +0800
Subject: [PATCH 2/5] add branch case

---
 .../mlir/Dialect/Math/Transforms/Passes.td     |  6 +++++-
 .../Transforms/CanonicalizeF32Promotion.cpp    |  3 +--
 .../Math/canonicalize-f32-promotion.mlir       | 18 ++++++++++++++++++
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 538dcbfbe7f77..5bf5eb45f921a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,7 +44,11 @@ def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
     will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
     ops.
     
-    This pass is to eliminate the redundant truncf/extf pairs.
+    This pass is to eliminate the redundant truncf/extf pairs to improve
+    performance.
+
+    However, this pass may introduce numerical difference as the `f32->bf16` rounding
+    is eliminated.
 
     Example:
 
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
index bfff17df8d7d4..b9b43a0887f14 100644
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -1,5 +1,4 @@
-//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs
-//----------===//
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
index 7aad7889e2bf5..127eece98cf79 100644
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -54,3 +54,21 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
   %1 = math.sin %0 : vector<32x32x32xf16>
   return %1 : vector<32x32x32xf16>
 }
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+	%1 = math.sin %0 : vector<32x32x32xbf16>
+	%2 = math.cos %0 : vector<32x32x32xbf16>
+	%3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+  return %3 : vector<32x32x32xbf16>
+}

>From 07ca29dbe48d010a36fdab154687547f26a6ead5 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 17 May 2024 14:21:38 +0800
Subject: [PATCH 3/5] use single walk rather than greedy rewrite

---
 .../Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp   | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
index b9b43a0887f14..8257ddb5c2efc 100644
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -64,7 +64,12 @@ struct MathCanonicalizeF32Promotion final
     RewritePatternSet patterns(&getContext());
     patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+    SmallVector<Operation *> ops;
+    getOperation()->walk([&](Operation *op) {
+      if (isa<arith::ExtFOp>(op))
+        ops.push_back(op);
+    });
+    if (failed(applyOpPatternsAndFold(ops, patternSet)))
       signalPassFailure();
   }
 };

>From 224e71483f8f0341b2a6aaf53313ea712958f1df Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 20 May 2024 20:11:23 +0800
Subject: [PATCH 4/5] add canonical option in legalize-to-f32

---
 .../mlir/Dialect/Math/Transforms/Passes.h     |  4 -
 .../mlir/Dialect/Math/Transforms/Passes.td    |  8 ++
 .../Dialect/Math/Transforms/LegalizeToF32.cpp | 36 ++++++++
 mlir/test/Dialect/Math/legalize-to-f32.mlir   | 84 ++++++++++++++++---
 4 files changed, 115 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f150ff6f944d2..0e9a420b4c5dd 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -15,10 +15,6 @@ namespace mlir {
 namespace math {
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-#define GEN_PASS_DECL_MATHUPLIFTTOFMA
-#define GEN_PASS_DECL_MATHLEGALIZETOF32
-#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
 } // namespace math
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 5bf5eb45f921a..687429518010b 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -34,6 +34,14 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
     that is an operation frequently implemented at low precisions.
   }];
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+  let options = [
+    Option<"useCanonicalizeF32Promotion", "use-canonicalize-f32-promotion", "bool",
+            /*default=*/"true",
+            "Eliminate redundant truncf/extf pairs to improve performance. "
+            "This may change numerical results because the intermediate "
+            "f16/bf16 rounding is skipped.">
+  ];
+
 }
 
 def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..883238fba9fbf 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -37,6 +38,8 @@ struct LegalizeToF32RewritePattern final : ConversionPattern {
 
 struct LegalizeToF32Pass final
     : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
+  // Inherit base ctors so pass options are forwarded, not silently dropped.
+  using MathLegalizeToF32Base::MathLegalizeToF32Base;
   void runOnOperation() override;
 };
 } // namespace
@@ -97,6 +100,29 @@ void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
                                             patterns.getContext());
 }
 
+namespace {
+/// Folds extf(truncf(x)) -> x for an f32 -> f16/bf16 -> f32 round trip. This
+/// skips the intermediate rounding, so numerical results may change.
+struct CanonicalizeF32PromotionRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    auto truncOp = op.getOperand().getDefiningOp<arith::TruncFOp>();
+    if (!truncOp)
+      return failure();
+    Value input = truncOp.getOperand();
+    Type midTy = getElementTypeOrSelf(truncOp.getType());
+    if (!getElementTypeOrSelf(op.getType()).isF32() ||
+        !getElementTypeOrSelf(input.getType()).isF32() ||
+        !(midTy.isF16() || midTy.isBF16()))
+      return failure(); // Nothing rewritten: must not report success.
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+} // namespace
+
 void LegalizeToF32Pass::runOnOperation() {
   Operation *op = getOperation();
   MLIRContext &ctx = getContext();
@@ -109,4 +135,14 @@ void LegalizeToF32Pass::runOnOperation() {
   math::populateLegalizeToF32Patterns(patterns, typeConverter);
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     return signalPassFailure();
+
+  if (useCanonicalizeF32Promotion) {
+    // Collect first: rewriting while walking could erase unvisited ops.
+    SmallVector<Operation *> extOps;
+    op->walk([&](arith::ExtFOp extOp) { extOps.push_back(extOp); });
+    RewritePatternSet canoPatterns(&ctx);
+    canoPatterns.add<CanonicalizeF32PromotionRewritePattern>(&ctx);
+    if (failed(applyOpPatternsAndFold(extOps, std::move(canoPatterns))))
+      return signalPassFailure();
+  }
 }
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b4..1b7bb51e771fb 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32=use-canonicalize-f32-promotion=true | FileCheck %s
 
 // CHECK-LABEL: @sin
 // CHECK-SAME: ([[ARG0:%.+]]: f16)
@@ -70,16 +70,74 @@ func.func @fastmath(%arg0: f16) -> f16 {
 }
 
 // CHECK-LABEL: @sequences
-// CHECK-SAME: ([[ARG0:%.+]]: f16)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF1]] : f16
-func.func @sequences(%arg0: f16) -> f16 {
-  %0 = math.absf %arg0 : f16
-  %1 = math.sin %0 : f16
-  return %1 : f16
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+  %0 = math.absf %arg0 : bf16
+  %1 = math.sin %0 : bf16
+  return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK: return %arg0 : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to f16
+  %1 = arith.extf %0 : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK: return %arg0 : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  %1 = arith.extf %0 : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xf16>
+  %1 = math.sin %0 : vector<32x32x32xf16>
+  return %1 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[TRUNCF0]], [[TRUNCF1]]
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  %2 = math.cos %0 : vector<32x32x32xbf16>
+  %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+  return %3 : vector<32x32x32xbf16>
 }

>From bf4d20207191cbb88b679da16bab7414951b9fe9 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 20 May 2024 20:22:26 +0800
Subject: [PATCH 5/5] remove single canonicalize pass

---
 .../mlir/Dialect/Math/Transforms/Passes.td    | 47 -----------
 .../Dialect/Math/Transforms/CMakeLists.txt    |  1 -
 .../Transforms/CanonicalizeF32Promotion.cpp   | 77 -------------------
 .../Math/canonicalize-f32-promotion.mlir      | 74 ------------------
 4 files changed, 199 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
 delete mode 100644 mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 687429518010b..234b4f43f08c0 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,51 +44,4 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
 
 }
 
-def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
-  let summary = "Eliminate redundant truncf/extf pairs";
-  let description = [{
-    `legalize-to-f32` pass does f32 promotion for every op belonging to the
-    illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32`
-    will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal
-    ops.
-    
-    This pass is to eliminate the redundant truncf/extf pairs to improve
-    performance.
-
-    However, this pass may introduce numerical difference as the `f32->bf16` rounding
-    is eliminated.
-
-    Example:
-
-    ```mlir
-    // the initial func
-    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
-        %0 = math.absf %arg0 : vector<32xbf16>
-        %1 = math.sin %0 : vector<32xbf16>
-        return %1 : vector<32xbf16>
-      }
-    // after legalize-to-f32
-    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
-        %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
-        %1 = math.absf %0 : vector<32xf32>
-        %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
-        %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
-        %4 = math.sin %3 : vector<32xf32>
-        %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
-        return %5 : vector<32xbf16>
-      }
-    // after canonicalize-f32-promotion
-    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
-        %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
-        %1 = math.absf %0 : vector<32xf32>
-        %2 = math.sin %1 : vector<32xf32>
-        %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
-        return %3 : vector<32xbf16>
-      }
-    ```
-
-  }];
-  let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
-}
-
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 0d39d14925d23..2a5b4fbcb5271 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
   LegalizeToF32.cpp
-  CanonicalizeF32Promotion.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
 
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
deleted file mode 100644
index 8257ddb5c2efc..0000000000000
--- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements removing redundant extf/truncf pairs inserted from
-// LegalizeToF32.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-
-namespace {
-
-struct CanonicalizeF32PromotionRewritePattern final
-    : OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::ExtFOp op,
-                                PatternRewriter &rewriter) const final {
-    if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
-      if (auto truncinput = innertruncop.getOperand()) {
-        auto outter_type = op.getType();
-        auto intermediate_type = innertruncop.getType();
-        auto inner_type = truncinput.getType();
-        if (outter_type.isa<ShapedType>()) {
-          outter_type = op.getType().cast<ShapedType>().getElementType();
-          intermediate_type =
-              innertruncop.getType().cast<ShapedType>().getElementType();
-          inner_type = truncinput.getType().cast<ShapedType>().getElementType();
-        }
-        if (outter_type.isF32() &&
-            (intermediate_type.isF16() || intermediate_type.isBF16()) &&
-            inner_type.isF32()) {
-          rewriter.replaceOp(op, {truncinput});
-        }
-      } else
-        return failure();
-    } else
-      return failure();
-    return success();
-  }
-};
-
-struct MathCanonicalizeF32Promotion final
-    : math::impl::MathCanonicalizeF32PromotionBase<
-          MathCanonicalizeF32Promotion> {
-  using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
-    FrozenRewritePatternSet patternSet(std::move(patterns));
-    SmallVector<Operation *> ops;
-    getOperation()->walk([&](Operation *op) {
-      if (isa<arith::ExtFOp>(op))
-        ops.push_back(op);
-    });
-    if (failed(applyOpPatternsAndFold(ops, patternSet)))
-      signalPassFailure();
-  }
-};
-
-} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
deleted file mode 100644
index 127eece98cf79..0000000000000
--- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
+++ /dev/null
@@ -1,74 +0,0 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
-
-// CHECK-LABEL: @sequences
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences(%arg0: bf16) -> bf16 {
-  %0 = math.absf %arg0 : bf16
-  %1 = math.sin %0 : bf16
-  return %1 : bf16
-}
-
-// CHECK-LABEL: @eliminatecastoncastf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
-  %0 = arith.truncf %arg0 : f32 to f16
-  %1 = arith.extf %0 : f16 to f32
-  return %1 : f32
-}
-
-// CHECK-LABEL: @eliminatecastoncastbf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
-  %0 = arith.truncf %arg0 : f32 to bf16
-  %1 = arith.extf %0 : bf16 to f32
-  return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = math.absf %arg0 : vector<32x32x32xbf16>
-  %1 = math.sin %0 : vector<32x32x32xbf16>
-  return %1 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
-  %0 = math.absf %arg0 : vector<32x32x32xf16>
-  %1 = math.sin %0 : vector<32x32x32xf16>
-  return %1 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
-// CHECK: [[ADDF:%.+]] = arith.addf
-// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = math.absf %arg0 : vector<32x32x32xbf16>
-	%1 = math.sin %0 : vector<32x32x32xbf16>
-	%2 = math.cos %0 : vector<32x32x32xbf16>
-	%3 = arith.addf %1, %2 : vector<32x32x32xbf16>
-  return %3 : vector<32x32x32xbf16>
-}



More information about the Mlir-commits mailing list