[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 30 07:24:36 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
Replaces a linalg.add where one operand is the result of a contraction whose single user is that add, and where the contraction has a zero-filled, "identity-mapped" destination and is dominated by the `other` operand, by that contraction with `other` as its dest.
Benefits include elision of an elementwise op, namely the linalg.add, and removing a tensor.empty as a destination which is likely to require an allocation upon bufferization.
---
Patch is 23.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110514.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+11)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4)
- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+6)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+5)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp (+108)
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+75)
- (added) mlir/test/Dialect/Linalg/fold-add-into-dest.mlir (+288)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+// Transform-dialect pattern-descriptor op, spelled
+// `transform.apply_patterns.linalg.fold_add_into_dest` in IR. It takes no
+// operands and prints only its attribute dictionary; the patterns it
+// contributes are provided by its `populatePatterns` implementation in
+// LinalgTransformOps.cpp.
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.linalg.fold_add_into_dest",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collects patterns to replace linalg.add when destination passing suffices
+ for achieving the sum.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
const ControlFusionFn &controlFn);
+/// Pattern to replace `linalg.add` when destination passing on a contraction op
+/// suffices for achieving the sum.
+///
+/// The rewrite targets a `linalg.add` where one summand is the single-use
+/// result of a contraction with a zero-filled destination: the contraction is
+/// made to accumulate onto the other summand and its result replaces the add.
+/// The pattern is appended to `patterns`.
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1e4f3004dec7e7..1d2759d2a91db1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -38,6 +38,12 @@ namespace linalg {
// General utilities
//===----------------------------------------------------------------------===//
+/// Returns true if `val` represents a zero-filled tensor, per its defining op.
+/// A null `val` yields false. Constant-zero values are accepted directly, and
+/// a block argument of a `linalg.generic` is judged by the defining op of the
+/// generic's operand at the same index.
+bool isZeroTensor(Value val);
+
+/// Returns true if the operation defines a zero-filled tensor.
+/// A null operation yields false, as does any operation kind the
+/// implementation does not explicitly recognise.
+bool isZeroOp(Operation *);
+
/// Check if all indexing maps are projected permutations.
bool allIndexingsAreProjectedPermutation(LinalgOp op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
+/// Pattern-population hook of `apply_patterns.linalg.fold_add_into_dest`:
+/// forwards to linalg::populateFoldAddIntoDestPatterns so that the
+/// fold-add-into-dest rewrite is added to `patterns`.
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ linalg::populateFoldAddIntoDestPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ElementwiseToLinalg.cpp
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
+ FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..d8c4e338fddbbc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,108 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+/// Folds a `linalg.add` into the destination of the contraction that produces
+/// one of its summands.
+///
+/// The rewrite applies when all of the following hold:
+///   * one operand of the add is the result of a single-result,
+///     destination-passing contraction, and the add is its only use;
+///   * the `other` operand properly dominates that contraction;
+///   * the types of `other`, the contraction's dest, the contraction's result
+///     and the add's result allow a drop-in substitution;
+///   * the contraction's current dest is zero-filled (the additive zero);
+///   * the dest's indexing map is an ordered projection, i.e. an identity map
+///     w.r.t. just the dimensions of dest.
+/// The contraction is then updated in place to accumulate onto `other` and its
+/// result replaces all uses of the add.
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    {
+      Value firstOperand = addOp.getOperand(0);
+      Value secondOperand = addOp.getOperand(1);
+
+      // A single dominance analysis answers both queries below.
+      DominanceInfo domInfo(addOp);
+
+      // Can only put one of addOp's operands in the dest/out arg of the other's
+      // defining op based on suitable dominance.
+      if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(firstOperand, secondOp)) {
+          dominatingOperand = firstOperand;
+          dominatedOp = secondOp;
+        }
+      }
+      if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
+        if (domInfo.properlyDominates(secondOperand, firstOp)) {
+          dominatingOperand = secondOperand;
+          dominatedOp = firstOp;
+        }
+      }
+      if (!dominatingOperand || !dominatedOp)
+        return failure();
+      // NB: As linalg.add's generalisation ignores the out argument in its
+      //     region there is no need to perform checks on addOp's out argument.
+    }
+
+    // Dominated op must be a contraction for it to accumulate on its out arg.
+    // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+    auto dominatedDestOp =
+        dyn_cast<DestinationStyleOpInterface>(dominatedOp.getOperation());
+    if (dominatedOp->getNumResults() != 1 ||
+        !linalg::isaContractionOpInterface(dominatedOp) ||
+        (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op to be single-result "
+                       "destination-passing contraction");
+
+    // To change the contraction's result, `addOp` must be its only user.
+    if (!dominatedOp->getResult(0).hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dominatedOp,
+          "expected linalg.add to be single user of contraction's result");
+
+    // The dominating operand becomes the contraction's dest and the
+    // contraction's result stands in for the add's result. Both substitutions
+    // are only well-typed when the involved types agree exactly (e.g. no mix of
+    // static and dynamic shapes).
+    auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
+    if (addOp->getNumResults() != 1 ||
+        dominatingOperand.getType() != destOperand->get().getType() ||
+        addOp->getResult(0).getType() != dominatedOp->getResult(0).getType())
+      return rewriter.notifyMatchFailure(
+          addOp, "expected types of summand, contraction's dest and results "
+                 "to match");
+
+    // As `dominatedOp` was already accumulating on its out argument, it is only
+    // safe to no longer use its current out arg when it is the additive zero.
+    if (!linalg::isZeroTensor(destOperand->get()))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op's dest to be additive zero");
+    // TODO: If the other op is a contraction and has additive zero as dest, we
+    // can swap the dests and achieve the proper sum, given suitable dominance.
+
+    // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+    // Hence, we can only substitute `dominatingOperand` for the dest of the
+    // contraction when dest's indexing_map corresponds to an identity map
+    // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
+    SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
+    int prevDimPos = -1;
+    for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
+      auto dim = dyn_cast<AffineDimExpr>(expr);
+      if (!dim || prevDimPos >= static_cast<int>(dim.getPosition()))
+        return rewriter.notifyMatchFailure(
+            dominatedOp, "expected index_map for contraction's dest to be an "
+                         "ordered projection");
+      prevDimPos = static_cast<int>(dim.getPosition());
+    }
+
+    // Replace the additive-zero out argument of the dominated op by the
+    // dominating summand. This makes the dominated op's result the sum of both
+    // of addOp's arguments - therefore we replace addOp and its uses by it.
+    // The operand slot is taken from `destOperand` rather than hard-coded so it
+    // always agrees with the operand inspected above.
+    rewriter.modifyOpInPlace(dominatedOp, [&]() {
+      dominatedOp->setOperand(destOperand->getOperandNumber(),
+                              dominatingOperand);
+    });
+    rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
+    return success();
+  }
+};
+
+void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
+  // Register the rewrite that drops a linalg.add whenever destination passing
+  // on the producing contraction already yields the sum.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FoldAddIntoDest>(context);
+}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 38e427af1c4846..a6a9ca5fd66330 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
return reassociation;
}
+// Returns true if `val` is a constant equal to zero, i.e. it matches either
+// the floating-point zero matcher (m_AnyZeroFloat) or the integer zero matcher
+// (m_Zero). File-local helper, hence `static` like isZeroAttr below.
+static bool isValConstZero(Value val) {
+  return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+// Returns true if the attribute represents "all zeros".
+//
+// Recognised forms are a float attribute holding (positive or negative) zero,
+// an integer attribute holding zero, and a splat dense-elements attribute with
+// int-or-float element type whose splat value is itself zero. A null attribute
+// and every other attribute kind yield false.
+static bool isZeroAttr(Attribute attribute) {
+  // TypeSwitch must not be handed a null attribute.
+  if (!attribute)
+    return false;
+
+  return TypeSwitch<Attribute, bool>(attribute)
+      // Query the arbitrary-precision values directly: no lossy conversion to
+      // double, and no restriction on the integer type's signedness.
+      .Case<FloatAttr>([](auto attr) { return attr.getValue().isZero(); })
+      .Case<IntegerAttr>([](auto attr) { return attr.getValue().isZero(); })
+      .Case<DenseElementsAttr>([](auto attr) {
+        if (!attr.getElementType().isIntOrFloat() || !attr.isSplat())
+          return false;
+        return isZeroAttr(attr.template getSplatValue<Attribute>());
+      })
+      .Default([](auto) { return false; });
+}
+
+// Decides whether `val` is known to be zero-filled.
+// Constant zeros are accepted immediately; otherwise the decision is delegated
+// to isZeroOp on the relevant defining op. Block arguments have no defining op
+// of their own, so for those we look through to the enclosing linalg.generic.
+bool isZeroTensor(Value val) {
+  if (!val)
+    return false;
+  if (isValConstZero(val))
+    return true;
+
+  auto blockArg = dyn_cast<BlockArgument>(val);
+  if (!blockArg)
+    return isZeroOp(val.getDefiningOp());
+
+  // Only arguments of a linalg.generic's region are looked through: the
+  // generic's operand at the same position as the block argument is inspected.
+  Operation *owner = blockArg.getParentRegion()->getParentOp();
+  if (!isa<linalg::GenericOp>(owner))
+    return false;
+  Value matchingOperand = owner->getOperand(blockArg.getArgNumber());
+  return isZeroOp(matchingOperand.getDefiningOp());
+}
+
+// Returns true if `defOp` is recognised as producing zero-filled data.
+// Recurses into isZeroTensor for operands and isZeroAttr for attributes.
+//
+// Handled operations:
+//   * arith.constant        - its value attribute must be all zeros;
+//   * linalg.fill / copy    - their single input must be zero;
+//   * memref.copy / subview, tensor.cast / extract_slice
+//                           - their source must be zero;
+//   * memref.get_global     - the referenced global must be constant and have
+//                             an all-zeros initial value.
+// A null op and every other operation yield false.
+bool isZeroOp(Operation *defOp) {
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<arith::ConstantOp>([&](auto op) {
+        // Dense attributes don't match APFloat.isZero().
+        Attribute attr = op.getValue();
+        return isZeroAttr(attr);
+      })
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        if (op.getInputs().size() != 1)
+          return false;
+        return isZeroTensor(op.getInputs()[0]);
+      })
+      .Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
+            tensor::ExtractSliceOp>(
+          [&](auto op) { return isZeroTensor(op.getSource()); })
+      .Case<memref::GetGlobalOp>([&](auto op) {
+        // The op may live outside any module, and the symbol may not resolve
+        // to a memref.global; neither case tells us anything about the data.
+        auto module = defOp->getParentOfType<ModuleOp>();
+        if (!module)
+          return false;
+        auto global = module.lookupSymbol<memref::GlobalOp>(op.getName());
+        if (!global)
+          return false;
+        // A mutable global may have been overwritten since initialisation, so
+        // its initial value only describes its contents when it is constant.
+        if (!global.getConstant())
+          return false;
+        // External/uninitialised globals carry no usable initial value.
+        Attribute attr = global.getInitialValueAttr();
+        return attr && isZeroAttr(attr);
+      })
+      .Default([&](Operation *) { return false; });
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..4dbe253fd3221b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,288 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+ %5 = lin...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/110514
More information about the Mlir-commits
mailing list