[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)
Rolf Morel
llvmlistbot at llvm.org
Thu Oct 3 03:08:20 PDT 2024
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/110514
>From 8636f5f2dea1383a20f719c85516815a4d3c5655 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 30 Sep 2024 06:44:45 -0700
Subject: [PATCH] [MLIR][Linalg] Pattern to fold AddOp to accumulation via
contraction op's dest
Replaces a linalg.add, one of whose operands is the result of a contraction
that has the add as its single user, a zero-filled, "identity-mapped"
destination, and is dominated by the add's `other` operand, by that
contraction with `other` as its dest.
Benefits include elision of an elementwise op, namely the linalg.add,
and removing a tensor.empty as a destination which is likely to require an
allocation upon bufferization.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 11 +
.../Dialect/Linalg/Transforms/Transforms.h | 4 +
.../TransformOps/LinalgTransformOps.cpp | 5 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Linalg/Transforms/FoldAddIntoDest.cpp | 150 ++++++++
.../Dialect/Linalg/fold-add-into-dest.mlir | 329 ++++++++++++++++++
6 files changed, 500 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
create mode 100644 mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.fold_add_into_dest",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to replace linalg.add when destination passing suffices
+    for achieving the sum.
+
+    Specifically, when one operand of a linalg.add is the result of a
+    contraction that has the add as its single user, whose destination is
+    defined as zero and is indexed by an ordered projection, and the add's
+    other operand dominates that contraction, the other operand becomes the
+    contraction's destination and the linalg.add is replaced by the
+    contraction's result. Only ops with pure tensor semantics are rewritten.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
const ControlFusionFn &controlFn);
+/// Pattern to replace `linalg.add` when destination passing on a contraction op
+/// suffices for achieving the sum.
+///
+/// The pattern rewrites a tensor-semantics `linalg.add` one of whose operands
+/// is the single-use result of a contraction with a zero-defined destination:
+/// the add's other operand becomes the contraction's destination and the add
+/// is replaced by the contraction's result.
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
+/// Adds the linalg.add-into-contraction-dest folding patterns to `patterns`,
+/// as requested by `transform.apply_patterns.linalg.fold_add_into_dest`.
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  ::mlir::linalg::populateFoldAddIntoDestPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ElementwiseToLinalg.cpp
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
+ FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..e940b0787043eb
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,150 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+// Determine whether the value is defined to be zero.
+//
+// Returns true when `val` is a constant scalar / vector splat / tensor splat
+// float or integer zero, or when it is produced by a linalg.fill or
+// linalg.copy whose single input is itself (recursively) defined as zero.
+// Null values, block arguments and results of any other op are reported as
+// not (provably) zero.
+static bool isDefinedAsZero(Value val) {
+  if (!val)
+    return false;
+
+  // Check whether val is a constant scalar / vector splat / tensor splat float
+  // or integer zero.
+  if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+    return true;
+
+  // Block arguments have no defining op. Bail out before the TypeSwitch: its
+  // cases dyn_cast the switched-on pointer, and dyn_cast asserts on null.
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        return op.getInputs().size() == 1 &&
+               isDefinedAsZero(op.getInputs()[0]);
+      })
+      .Default([&](auto) { return false; });
+}
+
+/// Replace a linalg.add, one of whose operands is the result of a contraction
+/// that has the add as its single user, a zero-filled, "identity-mapped"
+/// destination, and is dominated by the add's `other` operand, by that
+/// contraction with `other` as its dest.
+///
+/// As an example, the following pseudo-code will be rewritten
+///   %cst = arith.constant 0.000000e+00
+///   %empty = tensor.empty()
+///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+///   %empty2 = tensor.empty()
+///   %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
+///   %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
+///   %out = linalg.add ins(%C, %F) outs(%empty)
+/// to:
+///   %cst = arith.constant 0.000000e+00
+///   %empty = tensor.empty()
+///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
+///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
+///   %out = linalg.matmul ins(%D, %E) outs(%C)
+///
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    // For now, pattern only applies on tensor types (memref support is TODO).
+    if (!addOp.hasPureTensorSemantics())
+      return failure();
+
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    { // We will forget about which operand was left or right after this block.
+      Value lhs = addOp.getInputs()[0];
+      Value rhs = addOp.getInputs()[1];
+
+      // A single DominanceInfo serves both queries below.
+      DominanceInfo domInfo(addOp);
+
+      // Can only put one of addOp's operands in the dest/out arg of the other's
+      // defining op based on suitable dominance.
+      // TODO: Can be generalized to move ops around as long as that still
+      // respects use-def chains and doesn't affect side-effects.
+      if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(lhs, rhsOp)) {
+          dominatingOperand = lhs;
+          dominatedOp = rhsOp;
+        }
+      }
+      if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(rhs, lhsOp)) {
+          dominatingOperand = rhs;
+          dominatedOp = lhsOp;
+        }
+      }
+      if (!dominatingOperand || !dominatedOp)
+        return failure();
+      // NB: As linalg.add's generalisation ignores the out argument in its
+      // region there is no need to check the value of addOp's out argument
+      // (addOp's result type is still checked below).
+    }
+
+    // When dominated op is a contraction we know it accumulates on its out arg.
+    // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+    // TODO: Generalize check to also pass in case of other LinalgOps that
+    // accumulate on their out arg but are not (binary) contraction ops.
+    auto dominatedDestOp =
+        dyn_cast<DestinationStyleOpInterface>(dominatedOp.getOperation());
+    if (dominatedOp->getNumResults() != 1 ||
+        !linalg::isaContractionOpInterface(dominatedOp) ||
+        (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op to be single-result "
+                       "destination-passing contraction");
+
+    // To change the contraction's result, `addOp` must be its only user.
+    if (!dominatedOp->getResult(0).hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dominatedOp,
+          "expected linalg.add to be single user of contraction's result");
+
+    // As `dominatedOp` was already accumulating on its out argument, it is only
+    // safe to no longer use its current out arg when it is the additive ident.
+    OpOperand *destOperand = dominatedDestOp.getDpsInitOperand(0);
+    if (!isDefinedAsZero(destOperand->get()))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op's dest to be additive zero");
+    // TODO: If the other op is a contraction and has additive ident as dest, we
+    // can swap the dests and achieve the proper sum, given suitable dominance.
+
+    // The rewrite substitutes `dominatingOperand` for the contraction's dest
+    // and the contraction's result for addOp's result. Nothing checked so far
+    // guarantees that these pairs have identical types (e.g. one side may have
+    // dynamic dims where the other is static), and a mismatch would produce
+    // invalid IR, so only proceed when they match exactly.
+    if (dominatingOperand.getType() != destOperand->get().getType() ||
+        addOp->getResult(0).getType() != dominatedOp->getResult(0).getType())
+      return rewriter.notifyMatchFailure(
+          addOp, "expected linalg.add's operand and result types to match "
+                 "the contraction's dest and result types");
+
+    // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+    // Hence, we can only substitute `dominatingOperand` for the dest of the
+    // contraction when dest's indexing_map corresponds to an identity map
+    // w.r.t. just the dimensions of dest, i.e. is an ordered projection. That
+    // requires strictly increasing dim positions: a repeated dim is rejected.
+    AffineMap destMap = dominatedOp.getMatchingIndexingMap(destOperand);
+    int prevDimPos = -1;
+    for (AffineExpr expr : destMap.getResults()) {
+      auto dim = dyn_cast<AffineDimExpr>(expr);
+      if (!dim || prevDimPos >= static_cast<int>(dim.getPosition()))
+        return rewriter.notifyMatchFailure(
+            dominatedOp, "expected index_map for contraction's dest to be an "
+                         "ordered projection");
+      prevDimPos = dim.getPosition();
+    }
+
+    // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
+    // dominating summand. This makes the dominated op's result the sum of both
+    // of addOp's arguments - therefore we replace addOp and its uses by it.
+    rewriter.modifyOpInPlace(dominatedOp,
+                             [&]() { destOperand->set(dominatingOperand); });
+    rewriter.replaceOp(addOp, dominatedOp->getResult(0));
+    return success();
+  }
+};
+
+/// Registers FoldAddIntoDest, which replaces a linalg.add by a contraction that
+/// accumulates directly onto the add's other operand.
+void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FoldAddIntoDest>(ctx);
+}
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..d8e92e40739dce
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,329 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!type = tensor<2048x2048xf32>
+// Both summands are matmuls accumulating into their own zero-filled
+// destinations, so the second matmul can accumulate onto the first one's
+// result and the linalg.add disappears.
+func.func @fold_add_on_two_matmuls(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = tensor.empty() : !type
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : !type) -> !type
+  %6 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%5 : !type) -> !type
+  %7 = linalg.add ins(%3, %6 : !type, !type) outs(%1 : !type) -> !type
+  return %7 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_two_matmuls(
+// CHECK-SAME: %[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
+// CHECK-NEXT: %[[DENSE:.*]] = arith.constant dense<1.11
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00
+// CHECK-NEXT: %[[EMPTY:.*]] = tensor.empty()
+// CHECK-NEXT: %[[FILLED:.*]] = linalg.fill ins(%[[ZERO]] : {{.*}}) outs(%[[EMPTY]] : {{.*}})
+// CHECK-NEXT: %[[ACC:.+]] = linalg.matmul ins(%[[ARG0]], %[[DENSE]] : {{.*}}) outs(%[[FILLED]] : {{.*}})
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins(%[[ARG1]], %[[DENSE]] : {{.*}}) outs(%[[ACC]] : {{.*}})
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// The second matmul accumulates onto %0, a non-zero constant. Replacing that
+// dest with the other summand would drop %0's contribution, so the add must
+// be kept.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+// The second matmul accumulates into the zero-filled %2 and is dominated by
+// %3, so the only thing preventing the fold is that its result is also used
+// by linalg.mul.
+func.func @expect_no_fold_of_add_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  %6 = linalg.mul ins(%4, %arg0 : !type, !type) outs(%1 : !type) -> !type
+  return %5, %6 : !type, !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_contraction_result_has_multiple_users
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: linalg.mul
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// One summand is a function argument. As a block argument it dominates the
+// matmul, so it can become the matmul's dest and the add folds away.
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_matmul_and_func_arg(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_matmul_and_func_arg
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Both summands are the same value %3. A value cannot properly dominate its
+// own defining op, so neither operand can become the matmul's dest.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Transposed-input matmul variants are still contractions accumulating into a
+// zero-filled dest. The second one should accumulate onto the first's result.
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The only op the matmul result dominates is linalg.sub. It is not a
+// contraction, so it does not accumulate onto its dest, and the add must stay.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// A matmul-like generic whose dest indexing map (d1, d0) is not an ordered
+// projection of the iteration space. The pattern rejects such dest maps, so
+// the add must stay.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0)> // NB: not an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_of_add_as_dest_accumulation_is_not_identity_mapped(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"] }
+ ins(%arg0, %0: !type, !type) outs(%2: !type) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %5 = arith.mulf %a, %b : f32
+ %6 = arith.addf %c, %5 : f32
+ linalg.yield %6 : f32
+ } -> !type
+ %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_of_add_as_dest_accumulation_is_not_identity_mapped
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.generic
+// CHECK: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The same matmul-like generic, but with dest map (d0, d1), which is an
+// ordered projection. The function argument becomes the generic's dest and
+// the add folds away.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> // NB: is an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_a_generic_and_an_argument(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"] }
+ ins(%arg0, %0: !type, !type) outs(%2: !type) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %5 = arith.mulf %a, %b : f32
+ %6 = arith.addf %c, %5 : f32
+ linalg.yield %6 : f32
+ } -> !type
+ %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_a_generic_and_an_argument
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.add
+// CHECK: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The pattern only handles linalg.add with pure tensor semantics. This
+// memref-based version of the two-matmul case must be left untouched.
+memref.global "private" constant @big_const : memref<2048x2048xf32> = dense<1.11111104> {alignment = 64 : i64}
+func.func @expect_no_fold_due_to_no_memref_support(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>) -> memref<2048x2048xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = memref.get_global @big_const : memref<2048x2048xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<2048x2048xf32>
+ %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2048x2048xf32>
+ linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<2048x2048xf32>)
+ linalg.matmul ins(%arg0, %0 : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc_0 : memref<2048x2048xf32>)
+ linalg.fill ins(%cst : f32) outs(%alloc : memref<2048x2048xf32>)
+ linalg.matmul ins(%arg1, %0 : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc : memref<2048x2048xf32>)
+ linalg.add ins(%alloc_0, %alloc : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc : memref<2048x2048xf32>)
+ memref.dealloc %alloc_0 : memref<2048x2048xf32>
+ return %alloc : memref<2048x2048xf32>
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_due_to_no_memref_support
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: linalg.add
+// CHECK: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list