[Mlir-commits] [mlir] [MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (PR #84225)

Sayan Saha llvmlistbot at llvm.org
Fri Mar 8 13:19:17 PST 2024


https://github.com/sahas3 updated https://github.com/llvm/llvm-project/pull/84225

>From 2e7aa930c5e99112d69e90c6cafaf7659e693389 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH 1/5] [BugFix] : Move DimOp canonicalization from memref to
 tensor.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  1 -
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp  |  1 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 33 --------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 27 ++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 42 ----------
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 80 +++++++++++++++++++
 6 files changed, 106 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract from the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable, we don't need to worry about where to place
+    // the extract op.
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   memref.store
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   tensor.store
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT: tensor.reshape
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Unused insert: tensors are immutable, so it must not affect the extract.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+//       CHECK:  tensor.extract
+//  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
+//   CHECK-NOT:  tensor.dim
+//   CHECK-NOT:  tensor.reshape
+//       CHECK:  return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+    %c3 = arith.constant 3 : index
+    %0 = tensor.reshape %arg0(%arg1)
+        : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    %1 = tensor.dim %0, %c3 : tensor<*xf32>
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+//       CHECK: scf.for
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+//       CHECK: arith.muli
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }

>From 0de5e5997181a2fefb5a86dda1193b1c37a18382 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH 2/5] [BugFix] : Move DimOp canonicalization from memref to
 tensor.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  1 -
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp  |  1 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 33 --------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 28 ++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 42 ----------
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 80 +++++++++++++++++++
 6 files changed, 107 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..038b8c3122b527 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract from the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable, we don't need to worry about where to place
+    // the extract op.
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   memref.store
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..1ecf076b6ca2bc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   tensor.store
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT: tensor.reshape
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Unused insert: tensors are immutable, so it must not affect the extract.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+//       CHECK:  tensor.extract
+//  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
+//   CHECK-NOT:  tensor.dim
+//   CHECK-NOT:  tensor.reshape
+//       CHECK:  return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+    %c3 = arith.constant 3 : index
+    %0 = tensor.reshape %arg0(%arg1)
+        : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    %1 = tensor.dim %0, %c3 : tensor<*xf32>
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+//       CHECK: scf.for
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+//       CHECK: arith.muli
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }

>From 1e9c352b427c369f88b74bf148d6f6ab59434b1b Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Fri, 8 Mar 2024 08:33:47 -0500
Subject: [PATCH 3/5] [Task] : Add back memref.dim canonicalization with
 dominator fix.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  1 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 46 +++++++++
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 95 +++++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  4 +-
 4 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2333c92fd7b12c..c71517666b609c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,6 +629,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 00b7fa122a6c96..f69a10334050b9 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,6 +1069,52 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+namespace {
+/// Fold dim of a memref reshape operation to a load into the reshape's shape
+/// operand.
+struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp))
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+      } // else dim.getIndex is a block argument to reshape->getBlock
+    } else if (!dim.getIndex().getParentRegion()->isProperAncestor(
+                   reshape->getParentRegion()))
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+
+    // Place the load directly after the reshape to ensure that the shape memref
+    // was not mutated.
+    rewriter.setInsertionPointAfter(reshape);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
+
+} // namespace
+
+void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<DimOfMemRefReshape>(context);
+}
+
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 0054a8ac785a89..0c6157445fbfae 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,6 +242,101 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   memref.store
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  memref.store %c3, %arg1[%c3] : memref<?xindex>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[CAST]] : index
+func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+//  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:   %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+//  CHECK-SAME:   %[[IDX:[0-9a-z]+]]: index
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+//       CHECK: memref.reshape
+//       CHECK: memref.dim
+//   CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = memref.dim %0, %arg2 : memref<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+//       CHECK: memref.reshape
+//       CHECK: memref.dim
+//   CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = memref.dim %reshape, %0 : memref<*xf32>
+    return %dim : index
+  }
+
+// -----
+
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1ecf076b6ca2bc..7a4cf0dd2fc50a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2294,7 +2294,7 @@ func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
 
 // -----
 
-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
 // CHECK-LABEL: func @dim_of_reshape_for(
 //       CHECK: scf.for
 //  CHECK-NEXT: tensor.extract
@@ -2317,7 +2317,7 @@ func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) ->
 
 // -----
 
-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
 // CHECK-LABEL: func @dim_of_reshape_undominated(
 //       CHECK: arith.muli
 //  CHECK-NEXT: tensor.extract

>From e50232affc5a111db92b63df30031031b6ce7bc4 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Fri, 8 Mar 2024 09:40:32 -0500
Subject: [PATCH 4/5] [Task] : Add back memref.dim canonicalization to
 Loops.cpp.

---
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index e1cb5b477debbc..b0a4de2da1e869 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,6 +317,7 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
+  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);

>From d7f3bd23f7e4c18d0c5598deb950d76ed002b73f Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Fri, 8 Mar 2024 15:53:57 -0500
Subject: [PATCH 5/5] [Task] : Add comments + enhance check for index in parent
 block of reshape.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 25 ++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f69a10334050b9..3594b9669e3c6d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1083,17 +1083,34 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
       return rewriter.notifyMatchFailure(
           dim, "Dim op is not defined by a reshape op.");
 
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly), we
+    // cheaply check that either of the following conditions holds:
+    //      1. dim.getIndex() is defined in the same block as reshape but before
+    //      reshape.
+    //      2. dim.getIndex() is defined in a region that is a proper ancestor
+    //      of reshape's parent region.
+
+    // Check condition 1
     if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
       if (auto *definingOp = dim.getIndex().getDefiningOp()) {
-        if (reshape->isBeforeInBlock(definingOp))
+        if (reshape->isBeforeInBlock(definingOp)) {
           return rewriter.notifyMatchFailure(
               dim,
               "dim.getIndex is not defined before reshape in the same block.");
-      } // else dim.getIndex is a block argument to reshape->getBlock
-    } else if (!dim.getIndex().getParentRegion()->isProperAncestor(
-                   reshape->getParentRegion()))
+        }
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+    } // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`
       return rewriter.notifyMatchFailure(
           dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.



More information about the Mlir-commits mailing list