[Mlir-commits] [mlir] [MLIR] Fix incorrect memref::DimOp canonicalization, move it to tensor::DimOp (PR #84225)
Sayan Saha
llvmlistbot at llvm.org
Fri Mar 8 05:34:40 PST 2024
https://github.com/sahas3 updated https://github.com/llvm/llvm-project/pull/84225
>From 2e7aa930c5e99112d69e90c6cafaf7659e693389 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH 1/3] [BugFix] : Move DimOp canonicalization from memref to
tensor.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 -
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 1 -
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 33 --------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 27 ++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 42 ----------
mlir/test/Dialect/Tensor/canonicalize.mlir | 80 +++++++++++++++++++
6 files changed, 106 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
Speculation::Speculatability getSpeculatability();
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
MLIRContext *context = enclosingOp->getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context);
- memref::DimOp::getCanonicalizationPatterns(patterns, context);
tensor::DimOp::getCanonicalizationPatterns(patterns, context);
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
}
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dim,
- PatternRewriter &rewriter) const override {
- auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
- if (!reshape)
- return failure();
-
- // Place the load directly after the reshape to ensure that the shape memref
- // was not mutated.
- rewriter.setInsertionPointAfter(reshape);
- Location loc = dim.getLoc();
- Value load =
- rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
- if (load.getType() != dim.getType())
- load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
- rewriter.replaceOp(dim, load);
- return success();
- }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<DimOfMemRefReshape>(context);
-}
-
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Fold dim of a tensor reshape operation to an extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+ if (!reshape)
+ return failure();
+
+ // Since tensors are immutable we don't need to worry about where to place
+ // the extract call
+ rewriter.setInsertionPointAfter(dim);
+ Location loc = dim.getLoc();
+ Value load =
+ rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+ if (load.getType() != dim.getType())
+ load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+ rewriter.replaceOp(dim, load);
+ return success();
+ }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+ results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
// -----
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
- // Update the shape to test that he load ends up in the right place.
- memref.store %c3, %arg1[%c3] : memref<?xindex>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // Update the shape to test that the load ends up in the right place.
+ tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+ %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+ %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+ %3 = arith.muli %arg3, %2 : index
+ scf.yield %3 : index
+ }
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+ %c4 = arith.constant 4 : index
+ %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %0 = arith.muli %arg2, %c4 : index
+ %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+ return %dim : index
+ }
>From 0de5e5997181a2fefb5a86dda1193b1c37a18382 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH 2/3] [BugFix] : Move DimOp canonicalization from memref to
tensor.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 -
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 1 -
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 33 --------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 28 ++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 42 ----------
mlir/test/Dialect/Tensor/canonicalize.mlir | 80 +++++++++++++++++++
6 files changed, 107 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
Speculation::Speculatability getSpeculatability();
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
MLIRContext *context = enclosingOp->getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context);
- memref::DimOp::getCanonicalizationPatterns(patterns, context);
tensor::DimOp::getCanonicalizationPatterns(patterns, context);
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
}
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dim,
- PatternRewriter &rewriter) const override {
- auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
- if (!reshape)
- return failure();
-
- // Place the load directly after the reshape to ensure that the shape memref
- // was not mutated.
- rewriter.setInsertionPointAfter(reshape);
- Location loc = dim.getLoc();
- Value load =
- rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
- if (load.getType() != dim.getType())
- load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
- rewriter.replaceOp(dim, load);
- return success();
- }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<DimOfMemRefReshape>(context);
-}
-
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..038b8c3122b527 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Fold dim of a tensor reshape operation to an extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+ if (!reshape)
+ return failure();
+
+ // Since tensors are immutable we don't need to worry about where to place
+ // the extract call
+ rewriter.setInsertionPointAfter(dim);
+ Location loc = dim.getLoc();
+ Value extract =
+ rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+ if (extract.getType() != dim.getType())
+ extract =
+ rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+ rewriter.replaceOp(dim, extract);
+ return success();
+ }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+ results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
// -----
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
- // Update the shape to test that he load ends up in the right place.
- memref.store %c3, %arg1[%c3] : memref<?xindex>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..1ecf076b6ca2bc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // Update the shape to test that the load ends up in the right place.
+ tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+ %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+ %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+ %3 = arith.muli %arg3, %2 : index
+ scf.yield %3 : index
+ }
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+ %c4 = arith.constant 4 : index
+ %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %0 = arith.muli %arg2, %c4 : index
+ %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+ return %dim : index
+ }
>From 1e9c352b427c369f88b74bf148d6f6ab59434b1b Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Fri, 8 Mar 2024 08:33:47 -0500
Subject: [PATCH 3/3] [Task] : Add back memref.dim canonicalization with
dominator fix.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 +
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 46 +++++++++
mlir/test/Dialect/MemRef/canonicalize.mlir | 95 +++++++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 4 +-
4 files changed, 144 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2333c92fd7b12c..c71517666b609c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,6 +629,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
Speculation::Speculatability getSpeculatability();
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 00b7fa122a6c96..f69a10334050b9 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,6 +1069,52 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
}
+namespace {
+/// Fold dim of a memref reshape operation to a load into the reshape's shape
+/// operand.
+struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+ if (!reshape)
+ return rewriter.notifyMatchFailure(
+ dim, "Dim op is not defined by a reshape op.");
+
+ if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+ if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+ if (reshape->isBeforeInBlock(definingOp))
+ return rewriter.notifyMatchFailure(
+ dim,
+ "dim.getIndex is not defined before reshape in the same block.");
+ } // else dim.getIndex is a block argument to reshape->getBlock
+ } else if (!dim.getIndex().getParentRegion()->isProperAncestor(
+ reshape->getParentRegion()))
+ return rewriter.notifyMatchFailure(
+ dim, "dim.getIndex does not dominate reshape.");
+
+ // Place the load directly after the reshape to ensure that the shape memref
+ // was not mutated.
+ rewriter.setInsertionPointAfter(reshape);
+ Location loc = dim.getLoc();
+ Value load =
+ rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
+ if (load.getType() != dim.getType())
+ load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+ rewriter.replaceOp(dim, load);
+ return success();
+ }
+};
+
+} // namespace
+
+void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<DimOfMemRefReshape>(context);
+}
+
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 0054a8ac785a89..0c6157445fbfae 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,6 +242,101 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
// -----
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: memref.store
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = memref.reshape %arg0(%arg1)
+ : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+ // Update the shape to test that the load ends up in the right place.
+ memref.store %c3, %arg1[%c3] : memref<?xindex>
+ %1 = memref.dim %0, %c3 : memref<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = memref.reshape %arg0(%arg1)
+ : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+ %1 = memref.dim %0, %c3 : memref<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+ %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+ %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+ return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+ %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+ %2 = memref.dim %0, %arg2 : memref<*xf32>
+ %3 = arith.muli %arg3, %2 : index
+ scf.yield %3 : index
+ }
+ return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+ %c4 = arith.constant 4 : index
+ %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+ %0 = arith.muli %arg2, %c4 : index
+ %dim = memref.dim %reshape, %0 : memref<*xf32>
+ return %dim : index
+ }
+
+// -----
+
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1ecf076b6ca2bc..7a4cf0dd2fc50a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2294,7 +2294,7 @@ func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
// -----
-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_for(
// CHECK: scf.for
// CHECK-NEXT: tensor.extract
@@ -2317,7 +2317,7 @@ func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) ->
// -----
-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_undominated(
// CHECK: arith.muli
// CHECK-NEXT: tensor.extract
More information about the Mlir-commits
mailing list