[Mlir-commits] [mlir] [MLIR][Math] add canonicalize-f32-promotion pass (PR #92482)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 16 18:38:32 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
Author: Ivy Zhang (crazydemo)
<details>
<summary>Changes</summary>
The current `math-legalize-to-f32` pass promotes every op in its illegal-op list to f32 independently. When several illegal ops are chained, it therefore inserts a redundant `arith.truncf`/`arith.extf` pair between every two consecutive promoted ops.
This pass eliminates those redundant truncf/extf pairs.
---
Full diff: https://github.com/llvm/llvm-project/pull/92482.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+1)
- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.td (+47)
- (modified) mlir/lib/Dialect/Math/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp (+72)
- (added) mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir (+74)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..f150ff6f944d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace math {
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
#define GEN_PASS_DECL_MATHLEGALIZETOF32
+#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..5bf5eb45f921a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -36,4 +36,51 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
+def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> {
+  let summary = "Eliminate redundant truncf/extf pairs";
+  let description = [{
+    The `math-legalize-to-f32` pass promotes every op in its illegal-op list
+    to f32 independently. When several such ops are chained, it emits a
+    redundant `arith.truncf` (f32 to f16/bf16) followed by an `arith.extf`
+    (back to f32) between every two consecutive promoted ops.
+
+    This pass folds every f32 -> f16/bf16 -> f32 `arith.truncf`/`arith.extf`
+    round trip back to its f32 input, keeping the whole chain in f32.
+
+    Note that this is not numerically neutral: the intermediate f32 -> f16 or
+    f32 -> bf16 rounding is removed, so results may differ slightly.
+
+    Example:
+
+    ```mlir
+    // the initial func
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = math.absf %arg0 : vector<32xbf16>
+      %1 = math.sin %0 : vector<32xbf16>
+      return %1 : vector<32xbf16>
+    }
+    // after math-legalize-to-f32
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+      %1 = math.absf %0 : vector<32xf32>
+      %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16>
+      %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32>
+      %4 = math.sin %3 : vector<32xf32>
+      %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16>
+      return %5 : vector<32xbf16>
+    }
+    // after math-canonicalize-f32-promotion
+    func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> {
+      %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32>
+      %1 = math.absf %0 : vector<32xf32>
+      %2 = math.sin %1 : vector<32xf32>
+      %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16>
+      return %3 : vector<32xbf16>
+    }
+    ```
+
+  }];
+  let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb5271..0d39d14925d23 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
LegalizeToF32.cpp
+ CanonicalizeF32Promotion.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
new file mode 100644
index 0000000000000..b9b43a0887f14
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp
@@ -0,0 +1,72 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  /// Rewrites `extf(truncf(%x))` to `%x` when %x and the result are f32 (or
+  /// shaped f32) and the intermediate element type is f16 or bf16.
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    auto truncOp = op.getOperand().getDefiningOp<arith::TruncFOp>();
+    if (!truncOp)
+      return failure();
+    Value input = truncOp.getOperand();
+    // %x must have exactly the extf result type to be a drop-in replacement.
+    if (input.getType() != op.getType())
+      return failure();
+    // getElementTypeOrSelf handles both scalar and vector/tensor operands.
+    Type outerElemType = getElementTypeOrSelf(op.getType());
+    Type midElemType = getElementTypeOrSelf(truncOp.getType());
+    if (!outerElemType.isF32())
+      return failure();
+    if (!midElemType.isF16() && !midElemType.isBF16())
+      return failure();
+    // Note: this drops the intermediate f16/bf16 rounding of %x.
+    rewriter.replaceOp(op, input);
+    // Success is reported only after the IR was changed; returning success for
+    // a non-matching pair keeps the greedy driver iterating until it gives up.
+    return success();
+  }
+};
+
+struct MathCanonicalizeF32Promotion final
+    : math::impl::MathCanonicalizeF32PromotionBase<
+          MathCanonicalizeF32Promotion> {
+  using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+  /// Greedily applies the extf(truncf) elimination pattern to the root op.
+  void runOnOperation() override {
+    RewritePatternSet set(&getContext());
+    set.add<CanonicalizeF32PromotionRewritePattern>(&getContext());
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(set))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
new file mode 100644
index 0000000000000..127eece98cf79
--- /dev/null
+++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+  %0 = math.absf %arg0 : bf16
+  %1 = math.sin %0 : bf16
+  return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf
+// CHECK: return [[ARG0]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to f16
+  %1 = arith.extf %0 : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf
+// CHECK: return [[ARG0]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  %1 = arith.extf %0 : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xf16>
+  %1 = math.sin %0 : vector<32x32x32xf16>
+  return %1 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  %2 = math.cos %0 : vector<32x32x32xbf16>
+  %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+  return %3 : vector<32x32x32xbf16>
+}
+
+// An f64 -> f32 -> f64 round trip is not an f16/bf16 promotion pair: it must
+// be left untouched, and the pass must still succeed.
+// CHECK-LABEL: @keep_f64_round_trip
+// CHECK-SAME: ([[ARG0:%.+]]: f64)
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ARG0]] : f64 to f32
+// CHECK: [[EXTF:%.+]] = arith.extf [[TRUNCF]] : f32 to f64
+// CHECK: return [[EXTF]] : f64
+func.func @keep_f64_round_trip(%arg0: f64) -> f64 {
+  %0 = arith.truncf %arg0 : f64 to f32
+  %1 = arith.extf %0 : f32 to f64
+  return %1 : f64
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/92482
More information about the Mlir-commits
mailing list