[Mlir-commits] [mlir] 1c22802 - [MLIR][Linalg] Change insertion point for `bubbleUpPackOpThroughElemGenericOp`
Lorenzo Chelini
llvmlistbot at llvm.org
Thu Feb 23 00:25:00 PST 2023
Author: Lorenzo Chelini
Date: 2023-02-23T09:24:54+01:00
New Revision: 1c2280264058c26aee2bcd8fe4ca90d0f843cc5c
URL: https://github.com/llvm/llvm-project/commit/1c2280264058c26aee2bcd8fe4ca90d0f843cc5c
DIFF: https://github.com/llvm/llvm-project/commit/1c2280264058c26aee2bcd8fe4ca90d0f843cc5c.diff
LOG: [MLIR][Linalg] Change insertion point for `bubbleUpPackOpThroughElemGenericOp`
Currently, the insertion point for `bubbleUpPackOpThroughElemGenericOp`
is after the tensor.pack; this means that the new generic will be created
right after the tensor.pack. This is inconvenient because it moves the
position of the generic, whereas the idea is to move pack/unpack ops
around, not linalg.generics. This PR changes the insertion point to
preserve the position of the generic.
Additionally, it restricts the pattern to fire only if the generic has a
single user (`tensor.pack`), to avoid introducing recomputation.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D144246
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 3848510cee598..bc9098037d7c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -298,10 +299,7 @@ static FailureOr<GenericOp>
bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
tensor::PackOp packOp) {
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
- if (!genericOp)
- return failure();
-
- if (!isElementwise(genericOp))
+ if (!genericOp || !isElementwise(genericOp))
return failure();
// TODO: Relax the restriction. We are able to bubble up the pack op through
@@ -309,6 +307,34 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
if (genericOp.getNumResults() != 1)
return failure();
+ // Bail-out if the result of the generic has multiple uses, as bubbling up
+ // creates recomputation if the generic has multiple users.
+ if (!genericOp->getResult(0).hasOneUse())
+ return failure();
+
+ // We want to move the pack not the generic.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(genericOp);
+
+ // We need to handle two cases:
+ // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
+ // create a new tensor.empty to avoid breaking dominance, as we are moving the
+ // tensor.pack above the linalg.generic.
+ // 2) The destination is not a tensor.empty. In this case we can replace only
+ // if the destination of the tensor.pack dominates the linalg.generic.
+ Value packOpDest = packOp.getDest();
+ if (!packOpDest.hasOneUse())
+ return failure();
+ if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
+ packOpDest = rewriter.create<tensor::EmptyOp>(
+ genericOp->getLoc(), emptyOp.getMixedSizes(),
+ emptyOp.getType().getElementType());
+ } else {
+ DominanceInfo dom(genericOp);
+ if (!dom.properlyDominates(packOpDest, genericOp))
+ return failure();
+ }
+
// TODO: Add an option for allowing padding values. It could introduce
// undefined behavior if we unconditionally propagate pack op through all
// the ops. E.g., if the padding value is zero and there are division ops in
@@ -330,7 +356,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
// If it has users we need to pack the init operand too and replace the init
// with the packing result.
Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
- ? packOp.getDest()
+ ? packOpDest
: packedOutOperand;
return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 546e268c83742..7d54e28969ad3 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -621,3 +621,34 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
+ %init = tensor.empty() : tensor<128x256xi32>
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %dest = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
+ %pack = tensor.pack %elem
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 32]
+ into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
+ return %pack : tensor<4x16x16x32xi32>
+}
+
+// CHECK: func.func @would_break_dominance(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xi32>)
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x256xi32>
+// CHECK-NEXT: %[[GEN:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
+// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ALLOC]]
More information about the Mlir-commits
mailing list