[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)

Rolf Morel llvmlistbot at llvm.org
Thu Oct 3 03:08:20 PDT 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/110514

>From 8636f5f2dea1383a20f719c85516815a4d3c5655 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 30 Sep 2024 06:44:45 -0700
Subject: [PATCH] [MLIR][Linalg] Pattern to fold AddOp to accumulation via
 contraction op's dest

Replaces a linalg.add, one of whose operands is the single-use result of a
contraction that has a zero-filled, "identity-mapped" destination and is
dominated by the add's `other` operand, by that contraction with `other` as its
dest.

Benefits include elision of an elementwise op, namely the linalg.add,
and removing a tensor.empty as a destination which is likely to require an
allocation upon bufferization.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  11 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 .../TransformOps/LinalgTransformOps.cpp       |   5 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/FoldAddIntoDest.cpp     | 150 ++++++++
 .../Dialect/Linalg/fold-add-into-dest.mlir    | 329 ++++++++++++++++++
 6 files changed, 500 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
 create mode 100644 mlir/test/Dialect/Linalg/fold-add-into-dest.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.fold_add_into_dest",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that replace a linalg.add by accumulating one summand
+    into the dest of the contraction op that produces the other summand.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
 void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                           const ControlFusionFn &controlFn);
 
+/// Populate `patterns` with a pattern that folds a `linalg.add` into the dest
+/// of a contraction op producing one of its operands (tensor semantics only).
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+    RewritePatternSet &patternSet) {
+  ::mlir::linalg::populateFoldAddIntoDestPatterns(patternSet);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ElementwiseToLinalg.cpp
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
+  FoldAddIntoDest.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..e940b0787043eb
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,183 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+/// Determine whether `val` is known to be additive zero: a constant scalar /
+/// vector splat / tensor splat float or integer zero, or the result of a
+/// `linalg.fill` or `linalg.copy` whose single input is (recursively) known to
+/// be zero. Returns false for null values and for block arguments.
+static bool isDefinedAsZero(Value val) {
+  if (!val)
+    return false;
+
+  // Check whether val is a constant scalar / vector splat / tensor splat float
+  // or integer zero.
+  if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+    return true;
+
+  // Block arguments have no defining op. Bail out before TypeSwitch, which
+  // dyn_casts its subject and asserts when given a null Operation pointer.
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        // A fill broadcasts, and a copy forwards, its single input; hence the
+        // result is zero whenever that input is.
+        return op.getInputs().size() == 1 &&
+               isDefinedAsZero(op.getInputs()[0]);
+      })
+      .Default([&](auto) { return false; });
+}
+
+/// Fold a linalg.add into the contraction that produces one of its operands.
+///
+/// Applies when one of the add's operands is the single-use result of a
+/// single-result contraction whose dest is zero-filled and "identity-mapped",
+/// and the add's `other` operand properly dominates that contraction. The
+/// contraction is then made to accumulate onto `other` (its new dest) and its
+/// result replaces the add; all involved types must match exactly.
+///
+/// As an example, the following pseudo-code will be rewritten
+///   %cst = arith.constant 0.000000e+00
+///   %empty = tensor.empty()
+///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+///   %empty2 = tensor.empty()
+///   %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
+///   %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
+///   %out = linalg.add ins(%C, %F) outs(%empty)
+/// to:
+///   %cst = arith.constant 0.000000e+00
+///   %empty = tensor.empty()
+///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+///   %out = linalg.matmul ins(%D, %E) outs(%C)
+///
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    // For now, pattern only applies on tensor types (memref support is TODO).
+    if (!addOp.hasPureTensorSemantics())
+      return failure();
+
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    { // We will forget about which operand was left or right after this block.
+      Value lhs = addOp.getInputs()[0];
+      Value rhs = addOp.getInputs()[1];
+
+      // Can only put one of addOp's operands in the dest/out arg of the other's
+      // defining op based on suitable dominance.
+      // TODO: Can be generalized to move ops around as long as that still
+      //       respects use-def chains and doesn't affect side-effects.
+      if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
+        DominanceInfo domInfo(rhsOp);
+        if (domInfo.properlyDominates(lhs, rhsOp)) {
+          dominatingOperand = lhs;
+          dominatedOp = rhsOp;
+        }
+      }
+      if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
+        DominanceInfo domInfo(lhsOp);
+        if (domInfo.properlyDominates(rhs, lhsOp)) {
+          dominatingOperand = rhs;
+          dominatedOp = lhsOp;
+        }
+      }
+      if (!dominatingOperand || !dominatedOp)
+        return failure();
+      // NB: As linalg.add's generalisation ignores the out argument in its
+      //     region there is no need to perform checks on addOp's out argument.
+    }
+
+    // When dominated op is a contraction we know it accumulates on its out arg.
+    // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+    // TODO: Generalize check to also pass in case of other LinalgOps that
+    //       accumulate on their out arg but are not (binary) contraction ops.
+    auto dominatedDestOp =
+        dyn_cast<DestinationStyleOpInterface>(dominatedOp.getOperation());
+    if (dominatedOp->getNumResults() != 1 ||
+        !linalg::isaContractionOpInterface(dominatedOp) ||
+        (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op to be single-result "
+                       "destination-passing contraction");
+
+    // To change the contraction's result, `addOp` must be its only user.
+    Value contractionResult = dominatedOp->getResult(0);
+    if (!contractionResult.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dominatedOp,
+          "expected linalg.add to be single user of contraction's result");
+
+    // The contraction's result is about to replace addOp's result and
+    // `dominatingOperand` to become the contraction's dest. Both are only
+    // valid if the types match exactly, which nothing above guarantees (e.g.
+    // linalg.add presumably also admits operands of differing element types).
+    Type resultType = contractionResult.getType();
+    if (addOp->getNumResults() != 1 ||
+        addOp->getResult(0).getType() != resultType ||
+        dominatingOperand.getType() != resultType)
+      return rewriter.notifyMatchFailure(
+          addOp, "expected linalg.add's operand and result types to match the "
+                 "contraction's result type");
+
+    // As `dominatedOp` was already accumulating on its out argument, it is only
+    // safe to no longer use its current out arg when it is the additive ident.
+    OpOperand *destOperand = dominatedDestOp.getDpsInitOperand(0);
+    if (!isDefinedAsZero(destOperand->get()))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op's dest to be additive zero");
+    // TODO: If the other op is a contraction and has additive ident as dest, we
+    // can swap the dests and achieve the proper sum, given suitable dominance.
+
+    // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+    // Hence, we can only substitute `dominatingOperand` for the dest of the
+    // contraction when dest's indexing_map corresponds to an identity map
+    // w.r.t. just the dimensions of dest, i.e. is an ordered projection. The
+    // dim positions must strictly increase: a repeated dim is no projection.
+    AffineMap destMap = dominatedOp.getMatchingIndexingMap(destOperand);
+    int prevDimPos = -1;
+    for (AffineExpr expr : destMap.getResults()) {
+      auto dim = dyn_cast<AffineDimExpr>(expr);
+      if (!dim || prevDimPos >= static_cast<int>(dim.getPosition()))
+        return rewriter.notifyMatchFailure(
+            dominatedOp, "expected index_map for contraction's dest to be an "
+                         "ordered projection");
+      prevDimPos = dim.getPosition();
+    }
+
+    // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
+    // dominating summand. This makes the dominated op's result the sum of both
+    // of addOp's arguments - therefore we replace addOp (and its uses) by it.
+    // Go through `destOperand` rather than a hard-coded operand index so this
+    // does not depend on how the contraction lays out its operands.
+    rewriter.modifyOpInPlace(dominatedOp,
+                             [&]() { destOperand->set(dominatingOperand); });
+    rewriter.replaceOp(addOp, contractionResult);
+    return success();
+  }
+};
+
+void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
+  // Replace linalg.add when destination passing suffices for achieving the sum.
+  patterns.add<FoldAddIntoDest>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..d8e92e40739dce
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,329 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+// Both matmuls have zero-filled dests and one use each; the add folds into %6.
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_two_matmuls(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = tensor.empty() : !type
+  %5 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %6 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%5 : !type) -> !type
+  %7 = linalg.add ins(%3, %6 : !type, !type) outs(%1 : !type) -> !type
+  return %7 : !type
+}
+// NB: %4 is unused (%5 fills %1); -cse then merges the two identical fills.
+// CHECK-LABEL: func.func @fold_add_on_two_matmuls(
+// CHECK-SAME: %[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
+// CHECK-NEXT: %[[DENSE:.*]] = arith.constant dense<1.11
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00
+// CHECK-NEXT: %[[EMPTY:.*]] = tensor.empty()
+// CHECK-NEXT: %[[FILLED:.*]] = linalg.fill ins(%[[ZERO]] : {{.*}}) outs(%[[EMPTY]] : {{.*}})
+// CHECK-NEXT: %[[ACC:.+]] = linalg.matmul ins(%[[ARG0]], %[[DENSE]] : {{.*}}) outs(%[[FILLED]] : {{.*}})
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins(%[[ARG1]], %[[DENSE]] : {{.*}}) outs(%[[ACC]] : {{.*}})
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+// -----
+// The second matmul accumulates onto the non-zero constant %0: no fold.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// %4's dest is zero-filled, but %4 also feeds linalg.mul, so the add must stay.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  %6 = linalg.mul ins(%4, %arg0 : !type, !type) outs(%1 : !type) -> !type
+  return %5, %6 : !type, !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_contraction_result_has_multiple_users
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: linalg.mul
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// %arg1 is a block argument, so it dominates the matmul and becomes its dest.
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_matmul_and_func_arg(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_matmul_and_func_arg
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Both add operands are %3, so neither properly dominates the other's producer.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Transposed matmul variants are contractions too; the add folds into %4.
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// %4 is produced by linalg.sub, which is not a contraction: no fold.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// A generic contraction whose dest map (d1, d0) is transposed: no fold.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0)>  // NB: not an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_dest_accumulation_is_not_identity_mapped(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+                        iterator_types = ["parallel", "parallel", "reduction"] }
+    ins(%arg0, %0: !type, !type) outs(%2: !type) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %5 = arith.mulf %a, %b : f32
+        %6 = arith.addf %c, %5 : f32
+        linalg.yield %6 : f32
+  } -> !type
+  %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_dest_accumulation_is_not_identity_mapped
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.generic
+// CHECK: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// The same generic contraction with ordered dest map (d0, d1) does fold.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>  // NB: is an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_a_generic_and_an_argument(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+                        iterator_types = ["parallel", "parallel", "reduction"] }
+    ins(%arg0, %0: !type, !type) outs(%2: !type) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %5 = arith.mulf %a, %b : f32
+        %6 = arith.addf %c, %5 : f32
+        linalg.yield %6 : f32
+  } -> !type
+  %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_a_generic_and_an_argument
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.add
+// CHECK: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Buffer (memref) semantics are not supported by the pattern yet: no fold.
+memref.global "private" constant @big_const : memref<2048x2048xf32> = dense<1.11111104> {alignment = 64 : i64}
+func.func @expect_no_fold_due_to_no_memref_support(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>) -> memref<2048x2048xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.get_global @big_const  : memref<2048x2048xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2048x2048xf32>
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2048x2048xf32>
+  linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<2048x2048xf32>)
+  linalg.matmul ins(%arg0, %0 : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc_0 : memref<2048x2048xf32>)
+  linalg.fill ins(%cst : f32) outs(%alloc : memref<2048x2048xf32>)
+  linalg.matmul ins(%arg1, %0 : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc : memref<2048x2048xf32>)
+  linalg.add ins(%alloc_0, %alloc : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc : memref<2048x2048xf32>)
+  memref.dealloc %alloc_0 : memref<2048x2048xf32>
+  return %alloc : memref<2048x2048xf32>
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_due_to_no_memref_support
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: linalg.add
+// CHECK: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list