[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)
Rolf Morel
llvmlistbot at llvm.org
Mon Sep 30 07:23:58 PDT 2024
https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/110514
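Replaces a linalg.add, one of whose operands is the result of a contraction, by that contraction with the add's `other` operand as its destination. This requires that the linalg.add is the only user of the contraction's result, that the contraction's destination is zero-filled and "identity-mapped", and that the `other` operand dominates the contraction.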
Benefits include eliding an elementwise op (the linalg.add) and removing a tensor.empty destination, which would likely require an allocation upon bufferization.
>From ea4e02bbbeb13462f400a46d6fa67b0800c2b173 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 30 Sep 2024 06:44:45 -0700
Subject: [PATCH] [MLIR][Linalg] Pattern to fold AddOp to accumulation via
contraction op's dest
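Replaces a linalg.add, one of whose operands is the result of a
contraction, by that contraction with the add's `other` operand as its
dest. This requires that the linalg.add is the only user of the
contraction's result, that the contraction's dest is zero-filled and
"identity-mapped", and that the `other` operand dominates the contraction.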
Benefits include eliding an elementwise op (the linalg.add) and removing
a tensor.empty destination, which would likely require an allocation
upon bufferization.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 11 +
.../Dialect/Linalg/Transforms/Transforms.h | 4 +
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 6 +
.../TransformOps/LinalgTransformOps.cpp | 5 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Linalg/Transforms/FoldAddIntoDest.cpp | 108 +++++++
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 75 +++++
.../Dialect/Linalg/fold-add-into-dest.mlir | 288 ++++++++++++++++++
8 files changed, 498 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
create mode 100644 mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.fold_add_into_dest",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that replace a `linalg.add` by the contraction that
+    produces one of its operands, with the add's other operand used as that
+    contraction's destination. This applies when the add is the only user of
+    the contraction's result, the contraction's destination is zero-filled,
+    and the other operand dominates the contraction: the contraction then
+    accumulates the sum directly.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
const ControlFusionFn &controlFn);
+/// Adds a pattern that folds a `linalg.add` into the contraction producing one
+/// of its operands: when that contraction's destination is zero-filled and the
+/// add's other operand dominates it, the other operand becomes the
+/// contraction's destination and the contraction replaces the add.
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1e4f3004dec7e7..1d2759d2a91db1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -38,6 +38,12 @@ namespace linalg {
// General utilities
//===----------------------------------------------------------------------===//
+/// Returns true if `val` is known to hold only zeros, e.g. a constant zero or
+/// the result of a zero-filling op (see `isZeroOp`). The check is
+/// conservative: `false` means "not known to be zero". Returns false for a
+/// null value.
+bool isZeroTensor(Value val);
+
+/// Returns true if `op` is known to define an all-zero value, e.g. a zero
+/// `arith.constant` or a `linalg.fill` of a zero scalar. Returns false for a
+/// null op or any unrecognized op.
+bool isZeroOp(Operation *op);
+
/// Check if all indexing maps are projected permutations.
bool allIndexingsAreProjectedPermutation(LinalgOp op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // The pattern itself is defined in the Linalg transforms library; this op
+  // only forwards to its populate function.
+  ::mlir::linalg::populateFoldAddIntoDestPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ElementwiseToLinalg.cpp
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
+ FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..d8c4e338fddbbc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,108 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Replaces a `linalg.add` by the contraction that produces one of its
+/// operands, after making that contraction accumulate into the add's other
+/// operand instead of into its zero-filled destination.
+///
+/// For `%sum = linalg.add ins(%a, %b)` the rewrite applies when:
+///   * the add has pure tensor semantics;
+///   * `%b` is the single result of a single-init, destination-passing
+///     contraction and the add is that result's only user;
+///   * `%a` properly dominates the contraction, so it may become its init;
+///   * the contraction's init is known to be zero (`linalg::isZeroTensor`);
+///   * the init's indexing map is an ordered projection of the loop dims;
+///   * `%a` has the init's type and `%sum` has the contraction's result type.
+/// The roles of `%a` and `%b` may be swapped. The contraction then computes
+/// `%a + %b` by itself and replaces the add.
+///
+/// Example:
+///   %zero = linalg.fill ins(%cst0) outs(%empty)
+///   %c = linalg.matmul ins(%x, %y) outs(%zero)
+///   %sum = linalg.add ins(%a, %c) outs(%empty)
+/// becomes
+///   %sum = linalg.matmul ins(%x, %y) outs(%a)
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    // Only SSA tensor values can be forwarded through a contraction's result;
+    // on buffers there is no result to replace the add with.
+    if (!addOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(addOp,
+                                         "expected pure tensor semantics");
+
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    {
+      Value lhs = addOp.getInputs()[0];
+      Value rhs = addOp.getInputs()[1];
+      // One DominanceInfo instance suffices for both queries below.
+      DominanceInfo domInfo(addOp);
+
+      // One of addOp's operands can only be moved into the dest/out arg of the
+      // other operand's defining op if it properly dominates that op.
+      if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(lhs, rhsOp)) {
+          dominatingOperand = lhs;
+          dominatedOp = rhsOp;
+        }
+      }
+      if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(rhs, lhsOp)) {
+          dominatingOperand = rhs;
+          dominatedOp = lhsOp;
+        }
+      }
+      if (!dominatingOperand || !dominatedOp)
+        return rewriter.notifyMatchFailure(
+            addOp, "expected one operand to dominate the other's defining op");
+      // NB: As linalg.add's generalisation ignores the out argument in its
+      // region there is no need to perform checks on addOp's out argument.
+    }
+
+    // The dominated op must be a contraction for it to accumulate onto its
+    // out arg. E.g., AddOp is not a contraction and ignores its out arg's value.
+    auto dominatedDestOp =
+        dyn_cast<DestinationStyleOpInterface>(dominatedOp.getOperation());
+    if (dominatedOp->getNumResults() != 1 ||
+        !linalg::isaContractionOpInterface(dominatedOp) || !dominatedDestOp ||
+        dominatedDestOp.getNumDpsInits() != 1)
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op to be single-result "
+                       "destination-passing contraction");
+
+    // To change the contraction's result, `addOp` must be its only user.
+    Value contractionResult = dominatedOp->getResult(0);
+    if (!contractionResult.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dominatedOp,
+          "expected linalg.add to be single user of contraction's result");
+
+    // As `dominatedOp` was already accumulating on its out argument, it is only
+    // safe to no longer use its current out arg when it is the additive zero.
+    OpOperand *destOperand = dominatedDestOp.getDpsInitOperand(0);
+    if (!linalg::isZeroTensor(destOperand->get()))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op's dest to be additive zero");
+    // TODO: If the other op is a contraction and has additive zero as dest, we
+    // can swap the dests and achieve the proper sum, given suitable dominance.
+
+    // A destination-passing op's tensor result has the type of its init, so
+    // the substituted init must have exactly the old init's type. Likewise the
+    // add's result must have the contraction's result type to be replaced by
+    // it. linalg.add accepts e.g. mixed static/dynamic shapes, so this can
+    // genuinely differ.
+    if (dominatingOperand.getType() != destOperand->get().getType() ||
+        addOp->getResult(0).getType() != contractionResult.getType())
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected linalg.add's operand and result types to "
+                       "match the contraction's dest type");
+
+    // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+    // Hence, we can only substitute `dominatingOperand` for the dest of the
+    // contraction when dest's indexing_map corresponds to an identity map
+    // w.r.t. just the dimensions of dest, i.e. is an ordered projection
+    // (strictly increasing dim positions).
+    AffineMap destMap = dominatedOp.getMatchingIndexingMap(destOperand);
+    int64_t prevDimPos = -1;
+    for (AffineExpr expr : destMap.getResults()) {
+      auto dim = dyn_cast<AffineDimExpr>(expr);
+      if (!dim || prevDimPos >= static_cast<int64_t>(dim.getPosition()))
+        return rewriter.notifyMatchFailure(
+            dominatedOp, "expected index_map for contraction's dest to be an "
+                         "ordered projection");
+      prevDimPos = dim.getPosition();
+    }
+
+    // Replace the additive-zero out argument of the dominated op by the
+    // dominating summand. The dominated op's result then is the sum of both of
+    // addOp's operands, so it replaces addOp (which is erased).
+    rewriter.modifyOpInPlace(dominatedOp,
+                             [&]() { destOperand->set(dominatingOperand); });
+    rewriter.replaceOp(addOp, contractionResult);
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
+  // Registers the single pattern that folds linalg.add into a contraction.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FoldAddIntoDest>(context);
+}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 38e427af1c4846..a6a9ca5fd66330 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
return reassociation;
}
+/// Returns true if `val` is a constant floating-point or integer zero, as
+/// recognized by the `m_AnyZeroFloat` / `m_Zero` matchers. File-local helper:
+/// it is not declared in any header.
+static bool isValConstZero(Value val) {
+  return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+/// Returns true if `attribute` represents "all zeros": a float or integer zero
+/// attribute, or a splat dense int/float elements attribute whose splat value
+/// is zero. Null attributes, non-splat dense attributes and any other kind of
+/// attribute conservatively yield false.
+static bool isZeroAttr(Attribute attribute) {
+  // TypeSwitch dyn_casts its subject, which must not be null.
+  if (!attribute)
+    return false;
+  return TypeSwitch<Attribute, bool>(attribute)
+      // APFloat::isZero() holds for both +0.0 and -0.0.
+      .Case<FloatAttr>([](FloatAttr attr) { return attr.getValue().isZero(); })
+      // Use the APInt directly: getInt() asserts on signed/unsigned types.
+      .Case<IntegerAttr>(
+          [](IntegerAttr attr) { return attr.getValue().isZero(); })
+      .Case<DenseElementsAttr>([](DenseElementsAttr attr) {
+        if (!attr.getElementType().isIntOrFloat() || !attr.isSplat())
+          return false;
+        return isZeroAttr(attr.getSplatValue<Attribute>());
+      })
+      .Default([](Attribute) { return false; });
+}
+
+/// Returns true if `val` is known to hold only zeros.
+///
+/// Constant zeros are recognized directly; otherwise the defining op is
+/// inspected via `isZeroOp`. A block argument has no defining op: if it
+/// belongs to the body of a `linalg.generic`, the generic operand with the
+/// same index is inspected instead. Arguments corresponding to the generic's
+/// inits are only looked through when the generic has no reduction loops, as
+/// in a reduction they carry the running accumulator, not the init value.
+bool isZeroTensor(Value val) {
+  if (!val)
+    return false;
+  if (isValConstZero(val))
+    return true;
+
+  auto arg = dyn_cast<BlockArgument>(val);
+  if (!arg)
+    return isZeroOp(val.getDefiningOp());
+
+  // The parent op may be null for a detached block.
+  auto genericOp =
+      dyn_cast_or_null<linalg::GenericOp>(arg.getOwner()->getParentOp());
+  if (!genericOp)
+    return false;
+  // linalg.generic body arguments map 1:1, in order, onto its operands.
+  unsigned index = arg.getArgNumber();
+  if (index >= genericOp->getNumOperands())
+    return false;
+  bool isInitArg = index >= genericOp.getNumDpsInputs();
+  if (isInitArg && genericOp.getNumReductionLoops() != 0)
+    return false;
+  return isZeroTensor(genericOp->getOperand(index));
+}
+
+/// Returns true if `defOp` is known to produce an all-zero value:
+///   * `arith.constant` whose value attribute is zero (see `isZeroAttr`);
+///   * `linalg.fill` / `linalg.copy` whose single input is zero;
+///   * `memref.copy`, `memref.subview`, `tensor.cast` or
+///     `tensor.extract_slice` whose source is zero;
+///   * `memref.get_global` of a `memref.global` with a zero initial value.
+/// Returns false for a null op or any other op.
+bool isZeroOp(Operation *defOp) {
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<arith::ConstantOp>([&](arith::ConstantOp op) {
+        // Dense attributes don't match APFloat.isZero(), so inspect the
+        // attribute itself.
+        return isZeroAttr(op.getValue());
+      })
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        return op.getInputs().size() == 1 && isZeroTensor(op.getInputs()[0]);
+      })
+      .Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
+            tensor::ExtractSliceOp>(
+          [&](auto op) { return isZeroTensor(op.getSource()); })
+      .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp op) {
+        // The referenced global may not be found, and an external or
+        // uninitialized global has no (or a unit) initial value.
+        auto global = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
+            op, op.getNameAttr());
+        if (!global)
+          return false;
+        Attribute initialValue = global.getInitialValueAttr();
+        return initialValue && isZeroAttr(initialValue);
+      })
+      .Default([](Operation *) { return false; });
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..4dbe253fd3221b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,288 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+// Both summands are matmuls accumulating onto the same zero fill. The first
+// matmul's result dominates the second matmul, so the add is folded into the
+// second matmul, which then accumulates onto the first one's result.
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// One summand is a function argument, which dominates the matmul producing the
+// other summand, so the add is folded into that matmul's destination.
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Transposed matmul variants are contractions too: the add is folded into the
+// later (dominated) matmul_transpose_b, which accumulates onto the earlier one.
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Both add operands are the same matmul result; a value never properly
+// dominates its own defining op, so there is no operand to move into the
+// matmul's destination and nothing is folded.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The only op whose destination could absorb the other summand is linalg.sub,
+// which is not a contraction (it ignores its init), so nothing is folded.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The candidate (second) matmul accumulates onto the non-zero constant %0, so
+// replacing its init would drop that contribution; nothing is folded.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The candidate (second) matmul accumulates onto the zero fill %2, so the only
+// precondition that fails is that its result is also used by linalg.mul;
+// nothing is folded.
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ %6 = linalg.mul ins(%4, %arg0 : !type, !type) outs(%1 : !type) -> !type
+ return %5, %6 : !type, !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_contraction_result_has_multiple_users
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: linalg.mul
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The generic contraction's dest map (d1, d0) is not an ordered projection of
+// the loop dims, which the pattern requires before substituting the identity-
+// mapped add operand for the dest; nothing is folded.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0)> // NB: not an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dest_accumulation_is_not_identity_mapped(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"] }
+ ins(%arg0, %0: !type, !type) outs(%2: !type) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %5 = arith.mulf %a, %b : f32
+ %6 = arith.addf %c, %5 : f32
+ linalg.yield %6 : f32
+ } -> !type
+ %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dest_accumulation_is_not_identity_mapped
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.generic
+// CHECK: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The same generic contraction with an ordered-projection dest map (d0, d1):
+// the function argument dominates it, so the add is folded into its dest.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> // NB: is an ordered projection
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"] }
+ ins(%arg0, %0: !type, !type) outs(%2: !type) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %5 = arith.mulf %a, %b : f32
+ %6 = arith.addf %c, %5 : f32
+ linalg.yield %6 : f32
+ } -> !type
+ %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.add
+// CHECK: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list