[Mlir-commits] [mlir] [BugFix] : Move DimOp canonicalization from memref to tensor. (PR #84225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 6 12:16:04 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Sayan Saha (sahas3)

<details>
<summary>Changes</summary>

The current canonicalization of `memref.dim` operating on the result of `memref.reshape` into `memref.load` is incorrect because it doesn't check whether the `index` operand of `memref.dim` dominates the source `memref.reshape` op. The pattern always inserts the `memref.load` right after the `memref.reshape` to ensure the shape `memref` is not mutated before it is read, so an `index` defined after the reshape ends up being used before its definition. As a result, the following error is observed:

```
$> mlir-opt --canonicalize input.mlir

func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
    %c4 = arith.constant 4 : index
    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
    %0 = arith.muli %arg2, %c4 : index
    %dim = memref.dim %reshape, %0 : memref<*xf32>
    return %dim : index
  }
```

results in:

```
dominator.mlir:22:12: error: operand #1 does not dominate this use
    %dim = memref.dim %reshape, %0 : memref<*xf32>
           ^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
    %0 = arith.muli %arg2, %c4 : index
```

Properly fixing this issue requires a dominance analysis, which is expensive to run within a canonicalization pattern. So, this patch moves the canonicalization pattern to `tensor.dim`. Since tensors are immutable, we don't need to worry about where to insert the `tensor.extract` op after canonicalization.

---
Full diff: https://github.com/llvm/llvm-project/pull/84225.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Loops.cpp (-1) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (-33) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+26-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (-42) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+80) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract from the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable, we don't need to worry about where to place
+    // the extract op.
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   memref.store
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   tensor.store
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT: tensor.reshape
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Insert into the shape tensor (result unused) to test that the extract still reads the original shape operand.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+//       CHECK:  tensor.extract
+//  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
+//   CHECK-NOT:  tensor.dim
+//   CHECK-NOT:  tensor.reshape
+//       CHECK:  return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+    %c3 = arith.constant 3 : index
+    %0 = tensor.reshape %arg0(%arg1)
+        : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    %1 = tensor.dim %0, %c3 : tensor<*xf32>
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) inside an scf.for body is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+//       CHECK: scf.for
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx] even when %idx is defined after the reshape
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+//       CHECK: arith.muli
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }

``````````

</details>


https://github.com/llvm/llvm-project/pull/84225


More information about the Mlir-commits mailing list