[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Oct 3 00:37:20 PDT 2024
================
@@ -0,0 +1,152 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+// Determine whether the value is defined to be zero.
+static bool isDefinedAsZero(Value val) {
+ if (!val)
+ return false;
+
+ // Check whether val is a constant scalar / vector splat / tensor splat float
+ // or integer zero.
+ if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+ return true;
+
+ return TypeSwitch<Operation *, bool>(val.getDefiningOp())
+ .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+ return op && op.getInputs().size() == 1 &&
+ isDefinedAsZero(op.getInputs()[0]);
+ })
+ .Default([&](auto) { return false; });
+}
+
+/// Replace a linalg.add with one operand the single user of a contraction,
+/// which has a zero-filled, "identity-mapped" destination and is dominated by
+/// the `other` operand, by the contraction with `other` as its dest.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// %cst = arith.constant 0.000000e+00
+/// %empty = tensor.empty()
+/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+/// %empty2 = tensor.empty()
+/// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
+/// %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
+/// %out = linalg.add ins(%C, %F) outs(%empty)
+/// to:
+/// %cst = arith.constant 0.000000e+00
+/// %empty = tensor.empty()
+/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+/// %out = linalg.matmul ins(%D, %E) outs(%C)
+///
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+ using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::AddOp addOp,
+ PatternRewriter &rewriter) const override {
+ Value dominatingOperand = nullptr;
+ linalg::LinalgOp dominatedOp = nullptr;
+ {
+ auto firstOperand = addOp.getOperand(0);
+ auto secondOperand = addOp.getOperand(1);
+
+ // For now, pattern only applies to tensor types (memref support is TODO).
+ if (!isa<TensorType>(addOp->getResult(0).getType()) ||
+ !isa<TensorType>(firstOperand.getType()) ||
+ !isa<TensorType>(secondOperand.getType()))
+ return failure();
+
+ // Can only put one of addOp's operands in the dest/out arg of the other's
+ // defining op based on suitable dominance.
+ // TODO: Can be generalized to move ops around as long as that still
+ // respects use-def chains and doesn't affect side-effects.
+ if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
+ DominanceInfo domInfo(secondOp);
+ if (domInfo.properlyDominates(firstOperand, secondOp)) {
+ dominatingOperand = firstOperand;
+ dominatedOp = secondOp;
+ }
+ }
+ if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
+ DominanceInfo domInfo(firstOp);
+ if (domInfo.properlyDominates(secondOperand, firstOp)) {
+ dominatingOperand = secondOperand;
+ dominatedOp = firstOp;
+ }
+ }
+ if (!dominatingOperand || !dominatedOp)
+ return failure();
+ // NB: As linalg.add's generalisation ignores the out argument in its
+ // region there is no need to perform checks on addOp's out argument.
+ }
+
+ // When dominated op is a contraction we know it accumulates on its out arg.
+ // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+ // TODO: Generalize check to also pass in case of other LinalgOps that
+ // accumulate on their out arg but are not (binary) contraction ops.
+ auto dominatedDestOp =
+ dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
+ if (dominatedOp->getNumResults() != 1 ||
+ !linalg::isaContractionOpInterface(dominatedOp) ||
+ (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
+ return rewriter.notifyMatchFailure(
+ dominatedOp, "expected dominated op to be single-result "
+ "destination-passing contraction");
+
+ // To change the contraction's result, `addOp` must be its only user.
+ if (!dominatedOp->getResult(0).hasOneUse())
+ return rewriter.notifyMatchFailure(
+ dominatedOp,
+ "expected linalg.add to be single user of contraction's result");
+
+ // As `dominatedOp` was already accumulating on its out argument, it is only
+ // safe to no longer use its current out arg when it is the additive ident.
+ auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
+ if (!isDefinedAsZero(destOperand->get()))
+ return rewriter.notifyMatchFailure(
+ dominatedOp, "expected dominated op's dest to be additive zero");
+ // TODO: If the other op is a contraction and has additive ident as dest, we
+ // can swap the dests and achieve the proper sum, given suitable dominance.
+
+ // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+ // Hence, we can only substitute `dominatingOperand` for the dest of the
+ // contraction when dest's indexing_map corresponds to an identity map
+ // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
+ SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
+ int prevDimPos = -1;
+ for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
+ auto dim = dyn_cast<AffineDimExpr>(expr);
+ if (!dim || prevDimPos >= (int)dim.getPosition())
----------------
adam-smnk wrote:
nit: static_cast
https://github.com/llvm/llvm-project/pull/110514
More information about the Mlir-commits
mailing list