[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Oct 2 14:25:02 PDT 2024


================
@@ -0,0 +1,152 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+// Determine whether the value is defined to be zero.
+//
+// Returns true if `val` is a constant scalar / vector splat / tensor splat
+// float or integer zero, or if it is the result of a linalg.fill or
+// linalg.copy whose single input is itself (recursively) defined as zero.
+// Returns false for a null value and for values with no defining op
+// (e.g. block arguments).
+static bool isDefinedAsZero(Value val) {
+  if (!val)
+    return false;
+
+  // Check whether val is a constant scalar / vector splat / tensor splat float
+  // or integer zero.
+  if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+    return true;
+
+  // Block arguments have no defining op. Bail out explicitly: TypeSwitch
+  // dispatches via `dyn_cast`, which must not be given a null pointer.
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        // `op` is non-null here: TypeSwitch only invokes a case after a
+        // successful cast.
+        return op.getInputs().size() == 1 &&
+               isDefinedAsZero(op.getInputs()[0]);
+      })
+      .Default([&](auto) { return false; });
+}
+
+/// Replace a linalg.add with one operand the single user of a contraction,
+/// which has a zero-filled, "identity-mapped" destination and is dominated by
+/// the `other` operand, by the contraction with `other` as its dest.
+///
+/// As an example, the following pseudo-code will be rewritten
+///   %cst = arith.constant 0.000000e+00
+///   %empty = tensor.empty()
+///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+///   %empty2 = tensor.empty()
+///   %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
+///   %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
+///   %out = linalg.add ins(%C, %F) outs(%empty)
+/// to:
+///   %cst = arith.constant 0.000000e+00
+///   %empty = tensor.empty()
+///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+///   %out = linalg.matmul ins(%D, %E) outs(%C)
+///
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    {
+      auto firstOperand = addOp.getOperand(0);
+      auto secondOperand = addOp.getOperand(1);
----------------
banach-space wrote:

[nit] I always advocate for long descriptive names, but in this case short `rhs` and `lhs` might be even more descriptive :)

Could you spell out the types here instead of using `auto`?

https://github.com/llvm/llvm-project/pull/110514


More information about the Mlir-commits mailing list