[Mlir-commits] [mlir] [MLIR][Math] add canonicalize-f32-promotion pass (PR #92482)

Mehdi Amini llvmlistbot at llvm.org
Thu May 16 21:58:29 PDT 2024


================
@@ -0,0 +1,72 @@
+//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements removing redundant extf/truncf pairs inserted from
+// LegalizeToF32.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+
+struct CanonicalizeF32PromotionRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    if (auto innertruncop = op.getOperand().getDefiningOp<arith::TruncFOp>()) {
+      if (auto truncinput = innertruncop.getOperand()) {
+        auto outter_type = op.getType();
+        auto intermediate_type = innertruncop.getType();
+        auto inner_type = truncinput.getType();
+        if (outter_type.isa<ShapedType>()) {
+          outter_type = op.getType().cast<ShapedType>().getElementType();
+          intermediate_type =
+              innertruncop.getType().cast<ShapedType>().getElementType();
+          inner_type = truncinput.getType().cast<ShapedType>().getElementType();
+        }
+        if (outter_type.isF32() &&
+            (intermediate_type.isF16() || intermediate_type.isBF16()) &&
+            inner_type.isF32()) {
+          rewriter.replaceOp(op, {truncinput});
+        }
+      } else
+        return failure();
+    } else
+      return failure();
+    return success();
+  }
+};
+
+struct MathCanonicalizeF32Promotion final
+    : math::impl::MathCanonicalizeF32PromotionBase<
+          MathCanonicalizeF32Promotion> {
+  using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase;
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<CanonicalizeF32PromotionRewritePattern>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
----------------
joker-eph wrote:

We should not use the greedy rewriter when we don't need to iterate until fixed point over a set of patterns.
I believe a single walk here should be enough.


https://github.com/llvm/llvm-project/pull/92482


More information about the Mlir-commits mailing list