[Mlir-commits] [mlir] 26722f5 - [MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (#84225)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 11 19:37:37 PDT 2024
Author: Sayan Saha
Date: 2024-03-11T19:37:33-07:00
New Revision: 26722f5b61575fb0e58ff2933e7bea03353ff441
URL: https://github.com/llvm/llvm-project/commit/26722f5b61575fb0e58ff2933e7bea03353ff441
DIFF: https://github.com/llvm/llvm-project/commit/26722f5b61575fb0e58ff2933e7bea03353ff441.diff
LOG: [MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (#84225)
The current canonicalization that rewrites a `memref.dim` of the result of
`memref.reshape` into a `memref.load` is incorrect because it doesn't check
whether the `index` operand of `memref.dim` dominates the source
`memref.reshape` op. It always inserts the `memref.load` right after the
`memref.reshape` to ensure the shape `memref` is not mutated before the
`memref.load` executes. As a result, canonicalizing the following input:
```
$> mlir-opt --canonicalize input.mlir
func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
%c4 = arith.constant 4 : index
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = memref.dim %reshape, %0 : memref<*xf32>
return %dim : index
}
```
results in:
```
dominator.mlir:22:12: error: operand #1 does not dominate this use
%dim = memref.dim %reshape, %0 : memref<*xf32>
^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
%0 = arith.muli %arg2, %c4 : index
```
Properly fixing this issue requires a dominator analysis, which is
expensive to run within a canonicalization pattern. Instead, this patch
makes the pattern more conservative about the conditions under which the
canonicalization is applied.
A more general version of the pattern is also added for `tensor.dim`. Since
tensors are immutable, we don't need to worry about where to insert the
`tensor.extract` op during canonicalization.
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94e0ed319cae83..836dcb8f329e70 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
if (!reshape)
- return failure();
+ return rewriter.notifyMatchFailure(
+ dim, "Dim op is not defined by a reshape op.");
+
+ // dim of a memref reshape can be folded if dim.getIndex() dominates the
+ // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+ // cheaply check that either of the following conditions hold:
+ // 1. dim.getIndex() is defined in the same block as reshape but before
+ // reshape.
+ // 2. dim.getIndex() is defined in a parent block of
+ // reshape.
+
+ // Check condition 1
+ if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+ if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+ if (reshape->isBeforeInBlock(definingOp)) {
+ return rewriter.notifyMatchFailure(
+ dim,
+ "dim.getIndex is not defined before reshape in the same block.");
+ }
+ } // else dim.getIndex is a block argument to reshape->getBlock and
+ // dominates reshape
+ } // Check condition 2
+ else if (dim->getBlock() != reshape->getBlock() &&
+ !dim.getIndex().getParentRegion()->isProperAncestor(
+ reshape->getParentRegion())) {
+ // If dim and reshape are in the same block but dim.getIndex() isn't, we
+ // already know dim.getIndex() dominates reshape without calling
+ // `isProperAncestor`
+ return rewriter.notifyMatchFailure(
+ dim, "dim.getIndex does not dominate reshape.");
+ }
// Place the load directly after the reshape to ensure that the shape memref
// was not mutated.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a854da466c3130..dc8843aa4e1e13 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Fold dim of a tensor reshape operation to an extract from the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+ if (!reshape)
+ return failure();
+
+ // Since tensors are immutable, we don't need to worry about where to place
+ // the extract op; it is simply created right after the dim op.
+ rewriter.setInsertionPointAfter(dim);
+ Location loc = dim.getLoc();
+ Value extract =
+ rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+ if (extract.getType() != dim.getType())
+ extract =
+ rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+ rewriter.replaceOp(dim, extract);
+ return success();
+ }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+ results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b1e92e54d561da..506ed1f1c10b10 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
// -----
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+ %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+ %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+ return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+ %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+ %2 = memref.dim %0, %arg2 : memref<*xf32>
+ %3 = arith.muli %arg3, %2 : index
+ scf.yield %3 : index
+ }
+ return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+ %c4 = arith.constant 4 : index
+ %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+ %0 = arith.muli %arg2, %c4 : index
+ %dim = memref.dim %reshape, %0 : memref<*xf32>
+ return %dim : index
+ }
+
+// -----
+
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 70f5d61bd802fd..e5374f031be553 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2287,3 +2287,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // tensor.insert yields a new tensor and leaves %arg1 unchanged, so the folded extract must still read from %arg1.
+ tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+ %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+ %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+ %3 = arith.muli %arg3, %2 : index
+ scf.yield %3 : index
+ }
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+ %c4 = arith.constant 4 : index
+ %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %0 = arith.muli %arg2, %c4 : index
+ %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+ return %dim : index
+ }
More information about the Mlir-commits
mailing list