[Mlir-commits] [mlir] [BugFix] : Move DimOp canonicalization from memref to tensor. (PR #84225)
Sayan Saha
llvmlistbot at llvm.org
Wed Mar 6 12:15:14 PST 2024
https://github.com/sahas3 created https://github.com/llvm/llvm-project/pull/84225
The current canonicalization that rewrites `memref.dim` of a `memref.reshape` result into a `memref.load` is incorrect because it doesn't check whether the `index` operand of `memref.dim` dominates the source `memref.reshape` op. It always inserts the `memref.load` right after the `memref.reshape`, to ensure the shape memref is not mutated before the load. As a result, when the index is defined after the reshape, the following error is observed:
```
$> mlir-opt --canonicalize input.mlir
func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
%c4 = arith.constant 4 : index
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = memref.dim %reshape, %0 : memref<*xf32>
return %dim : index
}
```
results in:
```
dominator.mlir:22:12: error: operand #1 does not dominate this use
%dim = memref.dim %reshape, %0 : memref<*xf32>
^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
%0 = arith.muli %arg2, %c4 : index
```
Properly fixing this issue requires a dominance analysis, which is expensive to run within a canonicalization pattern. So, this patch moves the canonicalization pattern to `tensor.dim`. Since tensors are immutable, we don't need to worry about where the `tensor.extract` is inserted during canonicalization.
>From 2e7aa930c5e99112d69e90c6cafaf7659e693389 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH] [BugFix] : Move DimOp canonicalization from memref to tensor.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 -
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 1 -
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 33 --------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 27 ++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 42 ----------
mlir/test/Dialect/Tensor/canonicalize.mlir | 80 +++++++++++++++++++
6 files changed, 106 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
Speculation::Speculatability getSpeculatability();
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
MLIRContext *context = enclosingOp->getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context);
- memref::DimOp::getCanonicalizationPatterns(patterns, context);
tensor::DimOp::getCanonicalizationPatterns(patterns, context);
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
}
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dim,
- PatternRewriter &rewriter) const override {
- auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
- if (!reshape)
- return failure();
-
- // Place the load directly after the reshape to ensure that the shape memref
- // was not mutated.
- rewriter.setInsertionPointAfter(reshape);
- Location loc = dim.getLoc();
- Value load =
- rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
- if (load.getType() != dim.getType())
- load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
- rewriter.replaceOp(dim, load);
- return success();
- }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<DimOfMemRefReshape>(context);
-}
-
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,42 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Fold `tensor.dim` of a `tensor.reshape` result into a `tensor.extract` from
+/// the reshape's shape operand:
+///   tensor.dim(tensor.reshape %src(%shape)), %idx
+///     -> tensor.extract %shape[%idx]
+/// followed by an `arith.index_cast` when the shape's element type is not
+/// `index`.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+    if (!reshape)
+      return failure();
+
+    // Tensors are immutable, so the shape operand holds the same values at the
+    // dim as at the reshape. Insert the new ops right before `dim`: both the
+    // shape operand and the dim index dominate that point, whereas the index
+    // need not dominate a point right after the reshape.
+    rewriter.setInsertionPoint(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
// -----
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
- // Update the shape to test that he load ends up in the right place.
- memref.store %c3, %arg1[%c3] : memref<?xindex>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.insert
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // tensor.insert creates a new tensor; the extract must still read from %arg1.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) inside a loop is folded into tensor.extract %shp[%idx] in the loop body
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx] even when %idx is defined after the reshape
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+  return %dim : index
+}
More information about the Mlir-commits
mailing list