[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)
Rolf Morel
llvmlistbot at llvm.org
Thu Oct 3 02:53:36 PDT 2024
================
@@ -0,0 +1,152 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+/// Determine whether `val` is defined to be zero.
+///
+/// Returns true when `val` matches a constant float or integer zero, or when
+/// it is the result of a linalg.fill / linalg.copy whose single input is
+/// itself defined as zero (checked recursively). Returns false for a null
+/// value and for any value that has no defining operation.
+static bool isDefinedAsZero(Value val) {
+  if (!val)
+    return false;
+
+  // Check whether val is a constant scalar / vector splat / tensor splat float
+  // or integer zero.
+  if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+    return true;
+
+  // A value without a defining op (e.g. a block argument) yields a null
+  // Operation pointer. TypeSwitch dispatches through dyn_cast, which must not
+  // be handed a null pointer, so bail out before switching on it.
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        return op && op.getInputs().size() == 1 &&
+               isDefinedAsZero(op.getInputs()[0]);
+      })
+      .Default([&](auto) { return false; });
+}
+
+/// Replace a linalg.add, one operand of which is produced by a contraction
+/// whose single user is that add, by that contraction with the `other` add
+/// operand as its dest. This requires that the contraction has a zero-filled,
+/// "identity-mapped" destination and is dominated by the `other` operand.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// %cst = arith.constant 0.000000e+00
+/// %empty = tensor.empty()
+/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+/// %empty2 = tensor.empty()
+/// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
+/// %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
+/// %out = linalg.add ins(%C, %F) outs(%empty)
+/// to:
+/// %cst = arith.constant 0.000000e+00
+/// %empty = tensor.empty()
+/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+/// %out = linalg.matmul ins(%D, %E) outs(%C)
+///
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+ using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::AddOp addOp,
+ PatternRewriter &rewriter) const override {
+ Value dominatingOperand = nullptr;
+ linalg::LinalgOp dominatedOp = nullptr;
+ {
+ auto firstOperand = addOp.getOperand(0);
+ auto secondOperand = addOp.getOperand(1);
----------------
rolfmorel wrote:
Agreed - and done.
https://github.com/llvm/llvm-project/pull/110514
More information about the Mlir-commits
mailing list