[Mlir-commits] [mlir] 26722f5 - [MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (#84225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 11 19:37:37 PDT 2024


Author: Sayan Saha
Date: 2024-03-11T19:37:33-07:00
New Revision: 26722f5b61575fb0e58ff2933e7bea03353ff441

URL: https://github.com/llvm/llvm-project/commit/26722f5b61575fb0e58ff2933e7bea03353ff441
DIFF: https://github.com/llvm/llvm-project/commit/26722f5b61575fb0e58ff2933e7bea03353ff441.diff

LOG: [MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (#84225)

The current canonicalization that rewrites `memref.dim` of a
`memref.reshape` result into a `memref.load` from the shape operand is
incorrect: it does not check whether the `index` operand of `memref.dim`
dominates the source `memref.reshape` op, yet it always inserts the
`memref.load` right after the `memref.reshape` (to ensure the shape
`memref` is not mutated before the load). As a result, the following
error is observed:

```
$> mlir-opt --canonicalize input.mlir

func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
    %c4 = arith.constant 4 : index
    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
    %0 = arith.muli %arg2, %c4 : index
    %dim = memref.dim %reshape, %0 : memref<*xf32>
    return %dim : index
  }
```

results in:

```
dominator.mlir:22:12: error: operand #1 does not dominate this use
    %dim = memref.dim %reshape, %0 : memref<*xf32>
           ^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
    %0 = arith.muli %arg2, %c4 : index
```

Properly fixing this issue requires a dominance analysis, which is too
expensive to run within a canonicalization pattern. Instead, this patch
makes the canonicalization pattern stricter (more conservative) about the
legality conditions under which it fires.
The more general pattern is also added for `tensor.dim`. Since tensors are
immutable, there is no need to worry about where the `tensor.extract` is
inserted after canonicalization.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94e0ed319cae83..836dcb8f329e70 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
 
     if (!reshape)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+    // cheaply check that either of the following conditions hold:
+    //      1. dim.getIndex() is defined in the same block as reshape but before
+    //      reshape.
+    //      2. dim.getIndex() is defined in a parent block of
+    //      reshape.
+
+    // Check condition 1
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp)) {
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+        }
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+    }   // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a854da466c3130..dc8843aa4e1e13 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold `tensor.dim` of a `tensor.reshape` result into a `tensor.extract` from
+/// the reshape's shape operand, index-casting when the element type differs.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Tensors are immutable, so the shape cannot change after the reshape; the
+    // extract can go right after `dim`, where `dim`'s index already dominates.
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b1e92e54d561da..506ed1f1c10b10 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 
 // -----
 
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+//  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:   %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+//  CHECK-SAME:   %[[IDX:[0-9a-z]+]]: index
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+//       CHECK: memref.reshape
+//       CHECK: memref.dim
+//   CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = memref.dim %0, %arg2 : memref<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+//       CHECK: memref.reshape
+//       CHECK: memref.dim
+//   CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = memref.dim %reshape, %0 : memref<*xf32>
+    return %dim : index
+  }
+
+// -----
+
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 70f5d61bd802fd..e5374f031be553 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2287,3 +2287,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   tensor.store
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT: tensor.reshape
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+//       CHECK:  tensor.extract
+//  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
+//   CHECK-NOT:  tensor.dim
+//   CHECK-NOT:  tensor.reshape
+//       CHECK:  return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+    %c3 = arith.constant 3 : index
+    %0 = tensor.reshape %arg0(%arg1)
+        : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    %1 = tensor.dim %0, %c3 : tensor<*xf32>
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+//       CHECK: scf.for
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+//       CHECK: arith.muli
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }


        


More information about the Mlir-commits mailing list