[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)

Rolf Morel llvmlistbot at llvm.org
Mon Sep 30 07:23:58 PDT 2024


https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/110514

Replaces a linalg.add, one of whose operands is the single-use result of a contraction that has a zero-filled, "identity-mapped" destination, with that contraction using the add's `other` operand as its dest. This requires `other` to dominate the contraction.

Benefits include eliding an elementwise op, namely the linalg.add, and removing a tensor.empty destination, which would likely require an allocation upon bufferization.

>From ea4e02bbbeb13462f400a46d6fa67b0800c2b173 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 30 Sep 2024 06:44:45 -0700
Subject: [PATCH] [MLIR][Linalg] Pattern to fold AddOp to accumulation via
 contraction op's dest

Replaces a linalg.add, one of whose operands is the single-use result of
a contraction that has a zero-filled, "identity-mapped" destination, with
that contraction using the add's `other` operand as its dest. This
requires `other` to dominate the contraction.

Benefits include eliding an elementwise op, namely the linalg.add, and
removing a tensor.empty destination, which would likely require an
allocation upon bufferization.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  11 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   6 +
 .../TransformOps/LinalgTransformOps.cpp       |   5 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/FoldAddIntoDest.cpp     | 108 +++++++
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  75 +++++
 .../Dialect/Linalg/fold-add-into-dest.mlir    | 288 ++++++++++++++++++
 8 files changed, 498 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
 create mode 100644 mlir/test/Dialect/Linalg/fold-add-into-dest.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+// Transform-dialect pattern descriptor whose populatePatterns forwards to
+// linalg::populateFoldAddIntoDestPatterns.
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.fold_add_into_dest",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to replace linalg.add when destination passing suffices
+    for achieving the sum.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
 void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                           const ControlFusionFn &controlFn);
 
+/// Populates `patterns` with a pattern folding `linalg.add` into the dest of a
+/// contraction op when destination passing suffices for achieving the sum.
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1e4f3004dec7e7..1d2759d2a91db1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -38,6 +38,12 @@ namespace linalg {
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `val` is known to be all zeros, judging by its defining op.
+bool isZeroTensor(Value val);
+
+/// Returns true if the given operation is known to define an all-zero value.
+bool isZeroOp(Operation *);
+
 /// Check if all indexing maps are projected permutations.
 bool allIndexingsAreProjectedPermutation(LinalgOp op);
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // Forward to the Linalg entry point that registers the FoldAddIntoDest
+  // pattern.
+  ::mlir::linalg::populateFoldAddIntoDestPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ElementwiseToLinalg.cpp
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
+  FoldAddIntoDest.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..d8c4e338fddbbc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,108 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Replaces a linalg.add, one of whose operands is the single-use result of a
+/// contraction, by that contraction accumulating directly onto the add's other
+/// operand. The pattern applies only when:
+///  * the other operand properly dominates the contraction, so it is available
+///    to serve as the contraction's dest;
+///  * the contraction has one result and one DPS init, and linalg.add is the
+///    only user of that result;
+///  * the other operand, the contraction result and the add result all have
+///    the same type, so the substitutions below keep the IR well-typed;
+///  * the contraction's current dest is an additive zero (per isZeroTensor);
+///  * the dest's indexing map is an ordered projection of the loop dims, so it
+///    lines up with linalg.add's identity maps.
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    {
+      Value firstOperand = addOp.getOperand(0);
+      Value secondOperand = addOp.getOperand(1);
+
+      // Can only put one of addOp's operands in the dest/out arg of the other's
+      // defining op based on suitable dominance. Proper dominance is
+      // antisymmetric, so at most one of the two orderings below can match.
+      // A single DominanceInfo serves both queries.
+      DominanceInfo domInfo;
+      if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(firstOperand, secondOp)) {
+          dominatingOperand = firstOperand;
+          dominatedOp = secondOp;
+        }
+      }
+      if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(secondOperand, firstOp)) {
+          dominatingOperand = secondOperand;
+          dominatedOp = firstOp;
+        }
+      }
+      if (!dominatingOperand || !dominatedOp)
+        return failure();
+      // NB: As linalg.add's generalisation ignores the out argument in its
+      //     region there is no need to perform checks on addOp's out argument.
+    }
+
+    // Dominated op must be a contraction for it to accumulate on its out arg.
+    // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+    auto dominatedDestOp =
+        dyn_cast<DestinationStyleOpInterface>(dominatedOp.getOperation());
+    if (dominatedOp->getNumResults() != 1 ||
+        !linalg::isaContractionOpInterface(dominatedOp) || !dominatedDestOp ||
+        dominatedDestOp.getNumDpsInits() != 1)
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op to be single-result "
+                       "destination-passing contraction");
+
+    // To change the contraction's result, `addOp` must be its only user.
+    Value contractionResult = dominatedOp->getResult(0);
+    if (!contractionResult.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dominatedOp,
+          "expected linalg.add to be single user of contraction's result");
+
+    // `dominatingOperand` becomes the contraction's dest and the contraction's
+    // result replaces addOp's result. Both substitutions are only valid when
+    // the types coincide, which linalg.add itself does not guarantee (e.g. its
+    // operands and result may differ in static vs. dynamic shape).
+    if (addOp->getNumResults() != 1 ||
+        dominatingOperand.getType() != contractionResult.getType() ||
+        addOp->getResult(0).getType() != contractionResult.getType())
+      return rewriter.notifyMatchFailure(
+          addOp, "expected linalg.add operands and result to have the same "
+                 "type as the contraction's result");
+
+    // As `dominatedOp` was already accumulating on its out argument, it is only
+    // safe to no longer use its current out arg when it is the additive zero.
+    OpOperand *destOperand = dominatedDestOp.getDpsInitOperand(0);
+    if (!linalg::isZeroTensor(destOperand->get()))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op's dest to be additive zero");
+    // TODO: If the other op is a contraction and has additive zero as dest, we
+    // can swap the dests and achieve the proper sum, given suitable dominance.
+
+    // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+    // Hence, we can only substitute `dominatingOperand` for the dest of the
+    // contraction when dest's indexing_map corresponds to an identity map
+    // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
+    AffineMap destMap = dominatedOp.getMatchingIndexingMap(destOperand);
+    int64_t prevDimPos = -1;
+    for (AffineExpr expr : destMap.getResults()) {
+      auto dim = dyn_cast<AffineDimExpr>(expr);
+      if (!dim || prevDimPos >= static_cast<int64_t>(dim.getPosition()))
+        return rewriter.notifyMatchFailure(
+            dominatedOp, "expected index_map for contraction's dest to be an "
+                         "ordered projection");
+      prevDimPos = dim.getPosition();
+    }
+
+    // Replace the additive-zero out argument of the dominated op by the
+    // dominating summand. This makes the dominated op's result the sum of both
+    // of addOp's arguments, so addOp is replaced by (and erased in favour of)
+    // that result. Setting the matched init operand, rather than a hard-coded
+    // operand index, keeps this independent of the op's operand layout.
+    rewriter.modifyOpInPlace(dominatedOp,
+                             [&]() { destOperand->set(dominatingOperand); });
+    rewriter.replaceOp(addOp, contractionResult);
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
+  // Register the single pattern that folds linalg.add into a contraction's
+  // destination when destination passing suffices for achieving the sum.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FoldAddIntoDest>(context);
+}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 38e427af1c4846..a6a9ca5fd66330 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
   return reassociation;
 }
 
+/// Returns true if `val` is a constant zero, as recognised by either the
+/// m_AnyZeroFloat() or the m_Zero() matcher.
+/// File-local helper: it is not declared in any header, so give it internal
+/// linkage instead of exporting it from namespace mlir::linalg.
+static bool isValConstZero(Value val) {
+  return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+/// Returns true if `attribute` represents "all zeros": a zero FloatAttr or
+/// IntegerAttr, or a splat DenseElementsAttr of int/float type whose splat
+/// value is such a zero. A null attribute is conservatively reported as
+/// non-zero.
+static bool isZeroAttr(Attribute attribute) {
+  // TypeSwitch dyn_casts its operand, which must not be null.
+  if (!attribute)
+    return false;
+  return TypeSwitch<Attribute, bool>(attribute)
+      // Inspect the exact APFloat/APInt values: converting to double can flush
+      // tiny values of wider float types to 0.0, and IntegerAttr::getInt()
+      // only supports signless integers that fit in 64 bits.
+      .Case<FloatAttr>([](FloatAttr attr) { return attr.getValue().isZero(); })
+      .Case<IntegerAttr>(
+          [](IntegerAttr attr) { return attr.getValue().isZero(); })
+      .Case<DenseElementsAttr>([](DenseElementsAttr attr) {
+        if (!attr.getElementType().isIntOrFloat() || !attr.isSplat())
+          return false;
+        return isZeroAttr(attr.getSplatValue<Attribute>());
+      })
+      .Default([](Attribute) { return false; });
+}
+
+/// Returns true if `val` is known to be all zeros. Constant zeros are matched
+/// directly; otherwise the decision is delegated to isZeroOp on the defining
+/// op. A block argument of a linalg.generic body is looked through to the
+/// corresponding operand of that generic (init arguments only when safe, see
+/// below). Any other block argument is conservatively reported as non-zero.
+bool isZeroTensor(Value val) {
+  if (!val)
+    return false;
+  if (isValConstZero(val))
+    return true;
+
+  auto arg = dyn_cast<BlockArgument>(val);
+  if (!arg)
+    return isZeroOp(val.getDefiningOp());
+
+  // Block arguments don't have a defining op, but the body arguments of a
+  // linalg.generic correspond, in order, to the generic's operands. The owning
+  // block may be detached, in which case there is no parent op to inspect.
+  auto genericOp =
+      dyn_cast_or_null<linalg::GenericOp>(arg.getOwner()->getParentOp());
+  if (!genericOp)
+    return false;
+  unsigned index = arg.getArgNumber();
+  // Inside the body, an init (out) argument carries the running accumulator
+  // when the generic has reduction loops, so it need not equal the init
+  // operand's value. Only look through it for all-parallel generics.
+  if (index >= genericOp.getNumDpsInputs() &&
+      genericOp.getNumReductionLoops() != 0)
+    return false;
+  return isZeroOp(genericOp->getOperand(index).getDefiningOp());
+}
+
+/// Returns true if `defOp` is known to produce an all-zero value. Recognised
+/// cases: arith.constant with a zero attribute; linalg.fill / linalg.copy of a
+/// zero; memref.copy, memref.subview, tensor.cast and tensor.extract_slice of a
+/// zero source; and memref.get_global of a *constant* memref.global whose
+/// initial value is all zeros. Recurses into isZeroTensor for operands and
+/// isZeroAttr for attributes.
+bool isZeroOp(Operation *defOp) {
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<arith::ConstantOp>([&](auto op) {
+        // Dense attributes don't match APFloat.isZero().
+        Attribute attr = op.getValue();
+        return isZeroAttr(attr);
+      })
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        if (op.getInputs().size() != 1)
+          return false;
+        return isZeroTensor(op.getInputs()[0]);
+      })
+      .Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
+            tensor::ExtractSliceOp>(
+          [&](auto op) { return isZeroTensor(op.getSource()); })
+      .Case<memref::GetGlobalOp>([](memref::GetGlobalOp op) {
+        auto module = op->getParentOfType<ModuleOp>();
+        if (!module)
+          return false;
+        auto global = module.lookupSymbol<memref::GlobalOp>(op.getName());
+        // A non-constant global may have been written since initialization,
+        // so its initial value says nothing about its current contents.
+        if (!global || !global.getConstant())
+          return false;
+        // External/declaration-only globals carry no initial value.
+        Attribute attr = global.getInitialValueAttr();
+        return attr && isZeroAttr(attr);
+      })
+      .Default([](Operation *) { return false; });
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..4dbe253fd3221b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,288 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+// Both summands are contractions with a zero-filled dest: the later matmul is
+// expected to accumulate directly onto the result of the earlier one.
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// One summand is a function argument, which dominates the matmul and can
+// therefore become the matmul's dest, eliminating the add.
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Transposed matmul variants are contractions as well and fold the same way:
+// the later one accumulates onto the earlier one's result.
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Both add operands are the same value, so neither properly dominates the op
+// defining the other; the add has to stay.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// The op dominated by the other summand is linalg.sub, which is not a
+// contraction and ignores its out value, so the add has to stay.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// The dominated matmul accumulates onto a non-zero dest, so replacing that
+// dest with the other summand would drop its contribution; the add has to stay.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// The dominated matmul has a zero-filled dest, so the only reason not to fold
+// is that its result is also consumed by linalg.mul; the add has to stay.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  %6 = linalg.mul ins(%4, %arg0 : !type, !type) outs(%1 : !type) -> !type
+  return %5, %6 : !type, !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_contraction_result_has_multiple_users
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: linalg.mul
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// The generic contraction's dest map (d1, d0) is not an ordered projection of
+// the loop dims, which the pattern requires, so the add has to stay.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0)>  // NB: not an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dest_accumulation_is_not_identity_mapped(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+                        iterator_types = ["parallel", "parallel", "reduction"] }
+    ins(%arg0, %0: !type, !type) outs(%2: !type) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %5 = arith.mulf %a, %b : f32
+        %6 = arith.addf %c, %5 : f32
+        linalg.yield %6 : f32
+  } -> !type
+  %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dest_accumulation_is_not_identity_mapped
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.generic
+// CHECK: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// A linalg.generic contraction whose dest map (d0, d1) is an ordered
+// projection folds the add just like a named matmul does.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>  // NB: is an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+                        iterator_types = ["parallel", "parallel", "reduction"] }
+    ins(%arg0, %0: !type, !type) outs(%2: !type) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %5 = arith.mulf %a, %b : f32
+        %6 = arith.addf %c, %5 : f32
+        linalg.yield %6 : f32
+  } -> !type
+  %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.add
+// CHECK: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list